feat: raw nsmc data reader

This commit is contained in:
monoid 2022-02-23 19:22:30 +09:00
parent 406ff30715
commit 54024fa7fd
2 changed files with 62 additions and 0 deletions

View File

@@ -1,4 +1,49 @@
from io import TextIOWrapper
from typing import List, Union
import os
import csv
from dataclasses import dataclass
import tqdm
@dataclass
class NsmcRawData:
    """One record of the raw NSMC data: an id, the document text and a label."""

    # Numeric identifier taken from the "id" column.
    id: int
    # Text taken from the "document" column, stored unchanged.
    document: str
    # Integer taken from the "label" column.
    label: int
class NsmcRawDataReader:
    """Iterate over a tab-separated NSMC file as ``NsmcRawData`` records.

    The input must begin with a header row naming the ``id``, ``document``
    and ``label`` columns. The reader is a context manager; on exit it closes
    the underlying file only when it opened that file itself.
    """

    def __init__(self, file: Union[str, os.PathLike, TextIOWrapper]):
        """Wrap *file* in a tab-delimited ``csv.DictReader``.

        Args:
            file: Either a filesystem path (``str`` or ``os.PathLike``), which
                is opened as UTF-8 text and owned by this reader, or an
                already-open text stream, which stays owned by the caller.
        """
        # Only a file opened here is closed by close(); a stream handed in by
        # the caller is left open for the caller to manage.
        self.need_close = isinstance(file, (str, os.PathLike))
        if self.need_close:
            # The csv module asks for newline='' so that it sees the raw line
            # endings and can handle "\r", "\n" and "\r\n" terminated files.
            self.fp = open(file, "r", encoding="utf-8", newline="")
        else:
            self.fp = file
        self.rd = csv.DictReader(self.fp, delimiter="\t")

    def __iter__(self):
        """Yield one ``NsmcRawData`` per data row.

        Raises:
            KeyError: If a row lacks the ``id``, ``document`` or ``label`` key.
            ValueError: If ``id`` or ``label`` cannot be converted to ``int``.
        """
        for row in self.rd:
            yield NsmcRawData(int(row["id"]), row["document"], int(row["label"]))

    def close(self):
        """Close the underlying file if this reader opened it."""
        if self.need_close:
            self.fp.close()

    def __enter__(self):
        """Return the reader itself so it can be used in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close the reader on leaving the ``with`` block; exceptions propagate."""
        self.close()
def readNsmcRawData(file: Union[str, TextIOWrapper], use_tqdm = False, total: int = 0) -> List[NsmcRawData]:
    """Read every record from *file* into a list.

    Args:
        file: Path or open text stream, passed straight to
            ``NsmcRawDataReader``.
        use_tqdm: When true, wrap the iteration in a ``tqdm`` progress bar.
        total: Expected number of records, used only by the progress bar.
            A value of 0 or less means "unknown".

    Returns:
        All records in file order.
    """
    with NsmcRawDataReader(file) as dataReader:
        if not use_tqdm:
            return list(dataReader)
        # tqdm accepts total=None for an iterable of unknown length and then
        # shows a running count, so the flag still works without a total.
        progress = tqdm.tqdm(dataReader, total=total if total > 0 else None)
        return list(progress)
# Directory that holds the NSMC data files; ratings.txt is read from here.
BASE_PATH = "nsmc/nsmc-master"

if __name__ == "__main__":
    # Manual smoke run: load the whole ratings file with a progress bar.
    # `total` is only the expected row count handed to the progress bar.
    raw = readNsmcRawData(f"{BASE_PATH}/ratings.txt", use_tqdm=True, total=200000)

17
ndata.test.py Normal file
View File

@@ -0,0 +1,17 @@
import unittest
from ndata import *
import io
class Testing(unittest.TestCase):
    """Checks that readNsmcRawData parses an in-memory tab-separated stream."""

    def testcase(self):
        """A header plus one row yields one record with typed fields."""
        text = """id\tdocument\tlabel
20\t사랑해요\t1"""
        # An in-memory stream stands in for a file on disk.
        records = readNsmcRawData(io.StringIO(text))
        first = records[0]
        self.assertEqual(first.id, 20)
        self.assertEqual(first.document, "사랑해요")
        self.assertEqual(first.label, 1)


if __name__ == "__main__":
    # Run the tests when this file is executed directly.
    unittest.main()