ner-study/dataset.py

74 lines
2.7 KiB
Python
Raw Normal View History

2022-02-13 17:34:03 +09:00
from typing import Any, List
import torch
from torch.utils.data import Dataset, DataLoader
from read_data import TagIdConverter
from preprocessing import readPreporcssedDataAll
from transformers import PreTrainedTokenizer
# Module-level converter created once at import time. The collate function
# built by make_collate_fn reads its O_id (via wrap_entities) and pad_id.
tagIdConverter = TagIdConverter()
class DatasetArray(Dataset):
    """Map-style torch ``Dataset`` backed by an in-memory indexable sequence.

    Items are handed back exactly as stored; no transformation is applied.
    The wrapped sequence is kept by reference in ``self.x`` (not copied).
    """

    def __init__(self, data):
        self.x = data

    def __len__(self):
        items = self.x
        return len(items)

    def __getitem__(self, idx):
        items = self.x
        return items[idx]
def get_max_length(data: List[List[Any]]):
    """Return the length of the longest inner sequence in ``data``.

    An empty ``data`` yields 0 instead of raising ``ValueError``. Callers
    padding an empty batch therefore get zero-width output rather than a
    crash.
    """
    # The generator avoids building a temporary list; default=0 handles [].
    return max((len(lst) for lst in data), default=0)
def make_attention_mask(data: List[List[Any]], max_length=None):
    """Build a 0/1 mask marking the real (non-padding) positions of each row.

    Each output row is ``len(row)`` ones followed by zeros up to
    ``max_length``. When ``max_length`` is None, the longest row in ``data``
    sets the width. A row already at least ``max_length`` long gets no zeros,
    so its mask is all ones with the row's own length.
    """
    width = get_max_length(data) if max_length is None else max_length
    masks = []
    for row in data:
        n_real = len(row)
        masks.append([1] * n_real + [0] * (width - n_real))
    return masks
def padding_array(data: List[List[Any]], padding_value=0, max_length=None):
    """Right-pad every row of ``data`` with ``padding_value`` up to ``max_length``.

    New lists are returned; the input rows are not modified. When
    ``max_length`` is None, the length of the longest row is used. Rows
    already at least ``max_length`` long come back as copies and are not
    truncated.

    >>> padding_array([[1, 2], [3]])
    [[1, 2], [3, 0]]
    >>> padding_array([[1], []], padding_value=-1, max_length=3)
    [[1, -1, -1], [-1, -1, -1]]
    """
    if max_length is None:
        max_length = get_max_length(data)
    # `lst + [...]` builds a fresh list, so callers' rows stay untouched.
    # A negative repeat count yields [], which is why over-long rows pass
    # through unchanged.
    return [lst + [padding_value] * (max_length - len(lst)) for lst in data]
def wrap_sentence(tokenizer: PreTrainedTokenizer, sentence):
    """Surround a list of token ids with the tokenizer's CLS and SEP ids.

    Returns a new list: ``[cls_token_id] + sentence + [sep_token_id]``.
    """
    head = [tokenizer.cls_token_id]
    tail = [tokenizer.sep_token_id]
    return head + sentence + tail
def wrap_entities(tagIdConverter: TagIdConverter, entities):
    """Bracket a list of entity-tag ids with the converter's ``O_id`` on both ends.

    The result gains one element at each end, matching how ``wrap_sentence``
    adds CLS/SEP around the token ids.
    """
    outside = tagIdConverter.O_id
    return [outside] + entities + [outside]
def make_collate_fn(tokenizer: PreTrainedTokenizer, tag_id_converter=None):
    """Build a ``DataLoader`` collate function for token-classification batches.

    Args:
        tokenizer: supplies ``cls_token_id``/``sep_token_id`` (wrapped around
            each item's ``"ids"``) and ``pad_token_id`` (pads ``input_ids``).
        tag_id_converter: supplies ``O_id`` (wrapped around each item's
            ``"entity_ids"``) and ``pad_id`` (pads the labels). When None,
            the module-level ``tagIdConverter`` is used, looked up on every
            batch exactly as before.

    Returns:
        A function mapping a list of items (each indexable by ``"ids"`` and
        ``"entity_ids"``) to ``(inputs, labels)``. ``inputs`` is a dict with
        ``"input_ids"``, ``"attention_mask"`` and ``"token_type_ids"``
        tensors of shape ``(batch, max_length)``. ``labels`` is the padded
        entity-id tensor.

    NOTE(review): the labels are padded to the width of the token ids, so
    this assumes each item's ``entity_ids`` is no longer than its ``ids``.
    A longer entity list is not truncated, and ``torch.tensor`` would then
    reject the ragged rows.
    """
    def ret_fn(batch):
        # Resolve per call so the default tracks the module global.
        converter = tagIdConverter if tag_id_converter is None else tag_id_converter
        words = [wrap_sentence(tokenizer, item["ids"]) for item in batch]
        entities = [wrap_entities(converter, item["entity_ids"]) for item in batch]
        # The width comes from the wrapped token ids; labels are padded to the
        # same width so they line up column-for-column with input_ids.
        max_length = get_max_length(words)
        attention_mask = make_attention_mask(words, max_length=max_length)
        words = padding_array(words, padding_value=tokenizer.pad_token_id, max_length=max_length)
        entities = padding_array(entities, padding_value=converter.pad_id, max_length=max_length)
        # Every position is assigned segment 0.
        token_type_ids = torch.zeros((len(batch), max_length), dtype=torch.long)
        return {
            "input_ids": torch.tensor(words),
            "attention_mask": torch.tensor(attention_mask),
            "token_type_ids": token_type_ids,
        }, torch.tensor(entities)

    return ret_fn
if __name__ == "__main__":
    # Smoke test: load the preprocessed splits, collate the first two
    # training items with the multilingual cased BERT tokenizer, and print
    # the resulting tensors.
    train, dev, test = readPreporcssedDataAll()

    print("load transformers...")
    from transformers import BertTokenizer  # only needed by this demo

    print("load bert tokenizer...")
    pretrained_model_name = 'bert-base-multilingual-cased'
    tokenizer = BertTokenizer.from_pretrained(pretrained_model_name)

    print("test")
    collate = make_collate_fn(tokenizer)
    sample_batch = train[0:2]
    print(collate(sample_batch))

    # To iterate over the whole training split instead, wrap it in a DataLoader:
    # train_loader = DataLoader(
    #     DatasetArray(train),
    #     batch_size=1,
    #     shuffle=True,
    #     collate_fn=collate,
    # )