nsmc-study/ndata.py

49 lines
1.4 KiB
Python
Raw Normal View History

2022-02-23 19:22:30 +09:00
from io import TextIOWrapper
from typing import List, Union
2022-02-23 17:48:39 +09:00
import os
2022-02-23 19:22:30 +09:00
import csv
from dataclasses import dataclass
import tqdm
@dataclass()
class NsmcRawData:
    """A single data row from an NSMC ratings TSV file.

    Field order is significant: rows are constructed positionally as
    ``NsmcRawData(id, document, label)``.
    """

    id: int  # value of the "id" column, converted to int
    document: str  # value of the "document" column, kept as text
    label: int  # value of the "label" column, converted to int
class NsmcRawDataReader:
    """Stream ``NsmcRawData`` records out of a tab-separated NSMC file.

    Accepts either a filesystem path or an already-open text stream. A path
    is opened here and closed again by ``close()``; a stream supplied by the
    caller is never closed by this class. Usable as a context manager.

    Args:
        file: path to the TSV file, or an open text stream.
        quoting: ``csv`` quoting mode for the underlying reader. Defaults to
            ``csv.QUOTE_MINIMAL`` (the previous, hard-coded behaviour), under
            which a field that begins with ``"`` is parsed as a quoted field.
            NOTE(review): if the data is plain, unquoted TSV, pass
            ``csv.QUOTE_NONE`` so quote characters inside documents are kept
            verbatim -- confirm against the dataset.
    """

    def __init__(self, file: Union[str, TextIOWrapper], quoting: int = csv.QUOTE_MINIMAL):
        self.fp = file
        # Only close what we opened ourselves; a passed-in stream belongs to the caller.
        self.need_close = isinstance(file, str)
        if self.need_close:
            # The csv module documents that files must be opened with newline=''
            # so that it can recognise line endings (and newlines inside quoted
            # fields) itself.
            self.fp = open(file, "r", encoding="utf-8", newline="")
        # The first row is consumed as the header; each later row becomes a
        # dict keyed by those column names.
        self.rd = csv.DictReader(self.fp, delimiter="\t", quoting=quoting)

    def __iter__(self):
        """Yield one ``NsmcRawData`` per data row, converting id and label to int.

        The underlying csv reader is consumed as rows are yielded, so the data
        can only be walked through once per reader.
        """
        for row in self.rd:
            yield NsmcRawData(int(row["id"]), row["document"], int(row["label"]))

    def close(self):
        """Close the file if this reader opened it from a path; otherwise do nothing."""
        if self.need_close:
            self.fp.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Always release the file; exceptions are not suppressed (returns None).
        self.close()
def readNsmcRawData(file: Union[str, TextIOWrapper], use_tqdm: bool = False, total: int = 0) -> List[NsmcRawData]:
    """Read every row of an NSMC TSV file into a list of ``NsmcRawData``.

    Args:
        file: path to the file, or an already-open text stream (which is
            left open afterwards).
        use_tqdm: show a ``tqdm`` progress bar while reading.
        total: expected number of rows, used only as the progress bar's
            total. When it is not positive, the bar is still shown, just
            without a known total.

    Returns:
        All records, in file order.
    """
    with NsmcRawDataReader(file) as reader:
        rows = reader
        if use_tqdm:
            # tqdm accepts total=None for an unknown length. Previously the bar
            # was silently dropped whenever total was left at its default of 0.
            rows = tqdm.tqdm(reader, total=total if total > 0 else None)
        # Materialise inside the `with` so the file is still open while reading.
        return list(rows)
2022-02-23 17:48:39 +09:00
# Directory containing the NSMC data files; a relative path, so it is
# resolved against the current working directory when files are opened.
BASE_PATH = 'nsmc/nsmc-master'
2022-02-23 19:22:30 +09:00
if __name__ == "__main__":
    # Load the full ratings file; 200000 only sizes the progress bar and does
    # not limit how many rows are read.
    raw = readNsmcRawData(f"{BASE_PATH}/ratings.txt", use_tqdm=True, total=200000)