ner-study/preprocessing.py

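"""Preprocess NER data for BERT token classification.

Tokenizes each sentence with a multilingual BERT tokenizer, aligns the
word-level entity tags to the resulting subword tokens, converts both to
ids, and writes the train/dev/test splits as JSON under `prepro/`.

Hypothetical invocation (paths depend on your checkout):
    python preprocessing.py --kind korean --tag tags.json prepro
"""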
import argparse
import json
import os
import sys
from typing import Any, List

import tqdm
from transformers import BertTokenizer, PreTrainedTokenizer

from read_data import (
    Sentence,
    TagIdConverter,
    make_long_namedEntity,
    readEnglishDataAll,
    readKoreanDataAll,
)

PRE_BASE_PATH = 'prepro'


def preprocessing(tokenizer: PreTrainedTokenizer, converter: TagIdConverter, dataset: List[Sentence]):
    """Tokenize each sentence and align its word-level tags to the subword tokens."""
    ret = []
    for item in tqdm.tqdm(dataset):
        assert len(item.word) == len(item.detail)
        tokens = tokenizer.tokenize(" ".join(item.word))
        # Expand the word-level tags so that every subword token gets a tag.
        e = make_long_namedEntity(item.word, tokens, item.detail)
        if len(e) != len(tokens):
            print(e, tokens, file=sys.stderr)
        assert len(e) == len(tokens)
        ids = tokenizer.convert_tokens_to_ids(tokens)
        entityIds = converter.convert_tokens_to_ids(e)
        ret.append({"tokens": tokens, "ids": ids, "entity": e, "entity_ids": entityIds})
    return ret
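
# Each record produced above pairs the subword tokens with aligned tag strings
# and their ids. Illustrative shape (hypothetical values):
#   {"tokens": ["EU", "rejects", ...], "ids": [21259, ...],
#    "entity": ["B-ORG", "O", ...], "entity_ids": [3, 0, ...]}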


def saveObject(path: str, data: Any):
    with open(path, "w", encoding="utf-8") as fp:
        json.dump(data, fp, ensure_ascii=False, indent=2)


def readPreprocessedData(path: str):
    with open(path, "r", encoding="utf-8") as fp:
        return json.load(fp)


def readPreprocessedDataAll(path=PRE_BASE_PATH):
    train = readPreprocessedData(os.path.join(path, "train.json"))
    dev = readPreprocessedData(os.path.join(path, "dev.json"))
    test = readPreprocessedData(os.path.join(path, "test.json"))
    return train, dev, test
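

# Example (sketch): load the splits this script writes, assuming it has
# already been run so that `prepro/` exists:
#   train, dev, test = readPreprocessedDataAll()
#   print(train[0]["tokens"], train[0]["entity_ids"])
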
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--kind", default="korean", help="dataset language: korean or english")
    parser.add_argument("path", nargs="?", default=PRE_BASE_PATH, help="directory path of processed data")
    parser.add_argument("--tag", default="tags.json", help="path of tag description")
    args = parser.parse_args()
    dirPath = args.path
    if args.kind == "korean":
        rawTrain, rawDev, rawTest = readKoreanDataAll()
    elif args.kind == "english":
        rawTrain, rawDev, rawTest = readEnglishDataAll()
    else:
        print("unknown language", file=sys.stderr)
        sys.exit(1)
    converter = TagIdConverter(args.tag)
    os.makedirs(dirPath, exist_ok=True)

    PRETRAINED_MODEL_NAME = "bert-base-multilingual-cased"

    print("load tokenizer...", file=sys.stderr)
    tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)

    print("process train...", file=sys.stderr)
    train = preprocessing(tokenizer, converter, rawTrain)
    saveObject(os.path.join(dirPath, "train.json"), train)

    print("process dev...", file=sys.stderr)
    dev = preprocessing(tokenizer, converter, rawDev)
    saveObject(os.path.join(dirPath, "dev.json"), dev)

    print("process test...", file=sys.stderr)
    test = preprocessing(tokenizer, converter, rawTest)
    saveObject(os.path.join(dirPath, "test.json"), test)