fix: variable name error

This commit is contained in:
monoid 2022-02-23 20:38:28 +09:00
parent a1f4605d8b
commit 66727770d8

View File

@ -23,7 +23,7 @@ class NsmcDataset(Dataset):
def __getitem__(self, idx): def __getitem__(self, idx):
return self.x[idx] return self.x[idx]
def make_collate_fn(tokenzier: PreTrainedTokenizer): def make_collate_fn(tokenizer: PreTrainedTokenizer):
def collate_fn(batch: List[NsmcRawData]): def collate_fn(batch: List[NsmcRawData]):
labels = [s.label for s in batch] labels = [s.label for s in batch]
return tokenizer([s.document for s in batch], return_tensors='pt', padding='longest', truncation=True), torch.tensor(labels) return tokenizer([s.document for s in batch], return_tensors='pt', padding='longest', truncation=True), torch.tensor(labels)