diff --git a/ndataset.py b/ndataset.py
index c06b7ba..a7469c0 100644
--- a/ndataset.py
+++ b/ndataset.py
@@ -23,7 +23,7 @@ class NsmcDataset(Dataset):
     def __getitem__(self, idx):
         return self.x[idx]
 
-def make_collate_fn(tokenzier: PreTrainedTokenizer):
+def make_collate_fn(tokenizer: PreTrainedTokenizer):
     def collate_fn(batch: List[NsmcRawData]):
         labels = [s.label for s in batch]
         return tokenizer([s.document for s in batch], return_tensors='pt', padding='longest', truncation=True), torch.tensor(labels)