refactor: free tagIdConverter

This commit is contained in:
monoid 2022-02-22 17:23:14 +09:00
parent 142ad917bc
commit 84761d23be
2 changed files with 99 additions and 84 deletions

File diff suppressed because one or more lines are too long

View File

@ -5,8 +5,6 @@ from read_data import TagIdConverter
from preprocessing import readPreporcssedDataAll from preprocessing import readPreporcssedDataAll
from transformers import PreTrainedTokenizer from transformers import PreTrainedTokenizer
tagIdConverter = TagIdConverter()
class DatasetArray(Dataset): class DatasetArray(Dataset):
def __init__(self, data): def __init__(self, data):
self.x = data self.x = data
@ -39,7 +37,7 @@ def wrap_sentence(tokenizer: PreTrainedTokenizer, sentence):
def wrap_entities(tagIdConverter: TagIdConverter, entities): def wrap_entities(tagIdConverter: TagIdConverter, entities):
return [tagIdConverter.O_id] + entities + [tagIdConverter.O_id] return [tagIdConverter.O_id] + entities + [tagIdConverter.O_id]
def make_collate_fn(tokenizer: PreTrainedTokenizer): def make_collate_fn(tokenizer: PreTrainedTokenizer, tagIdConverter: TagIdConverter):
def ret_fn(batch): def ret_fn(batch):
words = [wrap_sentence(tokenizer,item["ids"]) for item in batch] words = [wrap_sentence(tokenizer,item["ids"]) for item in batch]
entities = [wrap_entities(tagIdConverter,item["entity_ids"]) for item in batch] entities = [wrap_entities(tagIdConverter,item["entity_ids"]) for item in batch]