from typing import Any, List, Optional
import torch
from torch.utils.data import Dataset, DataLoader
from read_data import TagIdConverter
from preprocessing import readPreporcssedDataAll
from transformers import PreTrainedTokenizer

# Module-wide tag converter; the collate function built by make_collate_fn
# reads its O_id / pad_id values.
tagIdConverter = TagIdConverter()


class DatasetArray(Dataset):
    """Map-style Dataset that simply wraps an indexable sequence of items."""

    def __init__(self, data):
        self.x = data

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return self.x[idx]


def get_max_length(data: List[List[Any]]) -> int:
    """Return the length of the longest inner list, or 0 if *data* is empty.

    >>> get_max_length([[1, 2], [3]])
    2
    >>> get_max_length([])
    0
    """
    # default=0 keeps an empty batch from raising ValueError inside max().
    return max((len(lst) for lst in data), default=0)


def make_attention_mask(data: List[List[Any]], max_length: Optional[int] = None) -> List[List[int]]:
    """Build a mask with 1 for each real element and 0 for each padding slot.

    Every row is extended to *max_length* entries; when *max_length* is None
    the longest row in *data* is used.

    >>> make_attention_mask([[7, 8], [9]])
    [[1, 1], [1, 0]]
    """
    if max_length is None:
        max_length = get_max_length(data)
    return [[1] * len(lst) + [0] * (max_length - len(lst)) for lst in data]


def padding_array(data: List[List[Any]], padding_value=0, max_length: Optional[int] = None) -> List[List[Any]]:
    """Right-pad every inner list with *padding_value* up to *max_length*.

    When *max_length* is None the longest row in *data* is used. Rows already
    at least *max_length* long are returned unchanged (no truncation).

    >>> padding_array([[1, 2], [3]])
    [[1, 2], [3, 0]]
    >>> padding_array([[1], []], padding_value=-1, max_length=3)
    [[1, -1, -1], [-1, -1, -1]]
    """
    if max_length is None:
        max_length = get_max_length(data)
    return [lst + [padding_value] * (max_length - len(lst)) for lst in data]


def wrap_sentence(tokenizer: PreTrainedTokenizer, sentence):
    """Return *sentence* (a list of token ids) framed by the CLS and SEP ids."""
    return [tokenizer.cls_token_id] + sentence + [tokenizer.sep_token_id]


def wrap_entities(tagIdConverter: TagIdConverter, entities):
    """Return *entities* framed by the converter's O id on both ends.

    This keeps the tag sequence aligned position-by-position with the
    CLS/SEP-wrapped token sequence produced by wrap_sentence.
    """
    return [tagIdConverter.O_id] + entities + [tagIdConverter.O_id]


def make_collate_fn(tokenizer: PreTrainedTokenizer):
    """Create a DataLoader collate function for token/tag batches.

    The returned function takes a list of items, each indexable by the keys
    "ids" (token ids) and "entity_ids" (tag ids). It returns a tuple of:
      * a dict with "input_ids", "attention_mask" and "token_type_ids"
        tensors, each of shape (batch, max_length);
      * a tensor of padded tag ids.
    Tokens are padded with tokenizer.pad_token_id, and tags with the module
    converter's pad_id.
    """
    def ret_fn(batch):
        words = [wrap_sentence(tokenizer, item["ids"]) for item in batch]
        entities = [wrap_entities(tagIdConverter, item["entity_ids"]) for item in batch]

        max_length = get_max_length(words)
        attention_mask = make_attention_mask(words, max_length=max_length)
        words = padding_array(words, padding_value=tokenizer.pad_token_id, max_length=max_length)
        # NOTE(review): tags are padded to the *word* length, which assumes
        # len(entity_ids) == len(ids) for every item. Confirm this with the
        # preprocessing code; a longer tag row would make torch.tensor() fail.
        entities = padding_array(entities, padding_value=tagIdConverter.pad_id, max_length=max_length)
        # Single-segment input: every position belongs to segment 0.
        token_type_ids = torch.zeros((len(batch), max_length), dtype=torch.long)
        return {"input_ids": torch.tensor(words),
                "attention_mask": torch.tensor(attention_mask),
                "token_type_ids": token_type_ids}, torch.tensor(entities)
    return ret_fn


if __name__ == "__main__":
    train, dev, test = readPreporcssedDataAll()
    print("load transformers...")
    from transformers import BertTokenizer
    print("load bert tokenizer...")
    PRETRAINED_MODEL_NAME = 'bert-base-multilingual-cased'
    tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)
    print("test")
    my_collate_fn = make_collate_fn(tokenizer)
    print(my_collate_fn(train[0:2]))
    # Example of wiring the collate function into a DataLoader:
    # train_loader = DataLoader(
    #     DatasetArray(train),
    #     batch_size=1,
    #     shuffle=True,
    #     collate_fn=my_collate_fn
    # )