72 lines
2.7 KiB
Python
72 lines
2.7 KiB
Python
from typing import Any, List
|
|
import torch
|
|
from torch.utils.data import Dataset, DataLoader
|
|
from read_data import TagIdConverter
|
|
from preprocessing import readPreporcssedDataAll
|
|
from transformers import PreTrainedTokenizer
|
|
|
|
class DatasetArray(Dataset):
    """Expose an indexable, sized container through the ``Dataset`` interface."""

    def __init__(self, data):
        """Store ``data`` as ``self.x``; the container is referenced, not copied."""
        self.x = data

    def __getitem__(self, idx):
        """Return whatever ``self.x`` yields for ``idx``."""
        return self.x[idx]

    def __len__(self):
        """Return the number of items held in ``self.x``."""
        return len(self.x)
|
|
|
|
def get_max_length(data: List[List[Any]]):
    """Return the length of the longest inner sequence in ``data``.

    An empty ``data`` yields 0 rather than raising ``ValueError``, so padding
    an empty batch produces an empty result instead of crashing.
    """
    # A generator avoids building a temporary list; ``default`` covers the
    # empty-input case that a bare ``max()`` rejects.
    return max((len(lst) for lst in data), default=0)
|
|
|
|
def make_attention_mask(data: List[List[Any]], max_length = None):
    """Build one 0/1 mask row per sequence: 1 per existing item, 0 per pad slot.

    When ``max_length`` is None, the length of the longest sequence in
    ``data`` is used. Sequences longer than ``max_length`` are not truncated;
    their row consists of ones only.
    """
    width = get_max_length(data) if max_length is None else max_length
    mask = []
    for seq in data:
        real = len(seq)
        # A negative repeat count yields an empty list, so no zeros are added
        # for sequences that already reach or exceed ``width``.
        mask.append([1] * real + [0] * (width - real))
    return mask
|
|
|
|
def padding_array(data: List[List[Any]], padding_value = 0, max_length = None):
    """
    Right-pad every inner list of ``data`` with ``padding_value``.

    When ``max_length`` is None, the length of the longest inner list is the
    target. Lists longer than ``max_length`` are returned unchanged in
    content (no truncation). New lists are built; the inputs are not mutated.

    >>> padding_array([[1, 2], [3]])
    [[1, 2], [3, 0]]
    """
    target = get_max_length(data) if max_length is None else max_length
    padded = []
    for row in data:
        # A negative repeat count produces an empty filler.
        filler = [padding_value] * (target - len(row))
        padded.append(row + filler)
    return padded
|
|
|
|
def wrap_sentence(tokenizer: PreTrainedTokenizer, sentence):
    """Return ``sentence`` framed by ``tokenizer.cls_token_id`` and ``tokenizer.sep_token_id``.

    The result is a new sequence; ``sentence`` itself is left untouched.
    """
    head = [tokenizer.cls_token_id]
    tail = [tokenizer.sep_token_id]
    return head + sentence + tail
|
|
def wrap_entities(tagIdConverter: TagIdConverter, entities):
    """Return ``entities`` with ``tagIdConverter.O_id`` added at both ends.

    Presumably ``O_id`` is the id of the "outside" tag, so the two extra
    labels line up with the special tokens added by ``wrap_sentence`` --
    confirm against ``read_data.TagIdConverter``.
    """
    prefix = [tagIdConverter.O_id]
    suffix = [tagIdConverter.O_id]
    return prefix + entities + suffix
|
|
|
|
def make_collate_fn(tokenizer: PreTrainedTokenizer, tagIdConverter: TagIdConverter):
    """Create a collate function that turns a batch of examples into tensors.

    The returned callable reads ``item["ids"]`` and ``item["entity_ids"]``
    from every batch item and returns a pair ``(inputs, labels)``:
    ``inputs`` is a dict with ``input_ids``, ``attention_mask`` and
    ``token_type_ids`` tensors, and ``labels`` is the padded entity-id tensor.
    """
    def collate_batch(batch):
        token_rows = [wrap_sentence(tokenizer, example["ids"]) for example in batch]
        label_rows = [wrap_entities(tagIdConverter, example["entity_ids"]) for example in batch]
        # Every row is padded to the longest wrapped sentence of this batch.
        longest = get_max_length(token_rows)

        mask_rows = make_attention_mask(token_rows, max_length=longest)

        token_rows = padding_array(token_rows, padding_value=tokenizer.pad_token_id, max_length=longest)
        label_rows = padding_array(label_rows, padding_value=tagIdConverter.pad_id, max_length=longest)
        # Token type ids are all zero for every position of every row.
        segment_ids = torch.zeros((len(batch), longest), dtype=torch.long)

        inputs = {
            "input_ids": torch.tensor(token_rows),
            "attention_mask": torch.tensor(mask_rows),
            "token_type_ids": segment_ids,
        }
        return inputs, torch.tensor(label_rows)

    return collate_batch
|
|
|
|
if __name__ == "__main__":
    # Smoke test: collate the first two training examples and print the result.
    train, dev, test = readPreporcssedDataAll()

    print("load transformers...")
    from transformers import BertTokenizer
    print("load bert tokenizer...")
    PRETRAINED_MODEL_NAME = 'bert-base-multilingual-cased'
    tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)
    print("test")
    # make_collate_fn requires the tag converter as its second argument;
    # calling it with the tokenizer alone raises TypeError.
    # NOTE(review): assumes TagIdConverter() can be constructed without
    # arguments -- confirm against read_data.
    tag_id_converter = TagIdConverter()
    my_collate_fn = make_collate_fn(tokenizer, tag_id_converter)
    print(my_collate_fn(train[0:2]))
    # Example of wiring the collate function into a DataLoader:
    #train_loader = DataLoader(
    #    DatasetArray(train),
    #    batch_size=1,
    #    shuffle=True,
    #    collate_fn=my_collate_fn
    #)