fix: variable name error
This commit is contained in:
parent
a1f4605d8b
commit
66727770d8
@ -23,7 +23,7 @@ class NsmcDataset(Dataset):
|
|||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
return self.x[idx]
|
return self.x[idx]
|
||||||
|
|
||||||
def make_collate_fn(tokenzier: PreTrainedTokenizer):
|
def make_collate_fn(tokenizer: PreTrainedTokenizer):
|
||||||
def collate_fn(batch: List[NsmcRawData]):
|
def collate_fn(batch: List[NsmcRawData]):
|
||||||
labels = [s.label for s in batch]
|
labels = [s.label for s in batch]
|
||||||
return tokenizer([s.document for s in batch], return_tensors='pt', padding='longest', truncation=True), torch.tensor(labels)
|
return tokenizer([s.document for s in batch], return_tensors='pt', padding='longest', truncation=True), torch.tensor(labels)
|
||||||
|
Loading…
Reference in New Issue
Block a user