# nsmc-study/ndataset.py

import sys
from typing import List
from torch.utils.data import Dataset
import torch
from transformers import PreTrainedTokenizer
from ndata import readNsmcRawData, NsmcRawData


def readNsmcDataAll():
    """
    Returns: train, test
    """
    print("read train set", file=sys.stderr)
    train = readNsmcRawData("nsmc/nsmc-master/ratings_train.txt", use_tqdm=True, total=150_000)
    print("read test set", file=sys.stderr)
    test = readNsmcRawData("nsmc/nsmc-master/ratings_test.txt", use_tqdm=True, total=50_000)
    return NsmcDataset(train), NsmcDataset(test)


class NsmcDataset(Dataset):
    """Thin Dataset wrapper over a list of raw NSMC examples."""

    def __init__(self, data: List[NsmcRawData]):
        self.x = data

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return self.x[idx]


def make_collate_fn(tokenizer: PreTrainedTokenizer):
    def collate_fn(batch: List[NsmcRawData]):
        # Tokenize the documents as one padded batch; labels become a 1-D tensor.
        labels = [s.label for s in batch]
        return tokenizer([s.document for s in batch], return_tensors='pt',
                         padding='longest', truncation=True), torch.tensor(labels)
    return collate_fn


if __name__ == "__main__":
    from transformers import BertTokenizer

    print("load bert tokenizer...", file=sys.stderr)
    PRETRAINED_MODEL_NAME = 'bert-base-multilingual-cased'
    tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)
    data = readNsmcRawData("nsmc/nsmc-master/ratings_train.txt", use_tqdm=True, total=150_000)
    collate = make_collate_fn(tokenizer)
    print(collate(data[0:2]))
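
    # Sketch (assumed usage, not in the original script): collate_fn is shaped
    # for torch's DataLoader; batch_size=2 is an arbitrary smoke-test choice.
    from torch.utils.data import DataLoader
    loader = DataLoader(NsmcDataset(data), batch_size=2, collate_fn=collate)
    encodings, labels = next(iter(loader))
    print(encodings["input_ids"].shape, labels)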