ner-study/EngTraning.ipynb

1116 lines
85 KiB
Plaintext
Raw Permalink Normal View History

2022-02-22 23:36:24 +09:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'C:\\\\Users\\\\Monoid\\\\anaconda3\\\\envs\\\\nn\\\\python.exe'"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import sys\n",
"sys.executable"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"파이썬 환경 확인.\n",
"envs\\\\nn\\\\python.exe 으로 끝나기를 기대합니다"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from preprocessing import readPreporcssedDataAll\n",
"import torch\n",
"from torch.utils.data import Dataset, DataLoader\n",
"from dataset import make_collate_fn, DatasetArray\n",
"from transformers import BertTokenizer\n",
"import torch.nn as nn\n",
"from read_data import TagIdConverter"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"TAGS_PATH = \"eng_tags.json\"\n",
"DATASET_PATH = \"engpre\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"변수 설정"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"tagIdConverter = TagIdConverter(TAGS_PATH)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"PRETAINED_MODEL_NAME = 'bert-base-multilingual-cased'\n",
"tokenizer = BertTokenizer.from_pretrained(PRETAINED_MODEL_NAME)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Tokenizer 로딩"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"from transformers import BertModel"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Some weights of the model checkpoint at bert-base-multilingual-cased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.seq_relationship.weight']\n",
"- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
"- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
]
}
],
"source": [
"PRETAINED_MODEL_NAME = 'bert-base-multilingual-cased'\n",
"bert = BertModel.from_pretrained(PRETAINED_MODEL_NAME)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"버트 로딩"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"class MyModel(nn.Module):\n",
" def __init__(self,output_feat: int,bert):\n",
" super().__init__()\n",
" self.bert = bert\n",
" self.dropout = nn.Dropout(p=0.1)\n",
" self.lin = nn.Linear(768,output_feat) #[batch_size,word_size,768] -> [batch_size,word_size,output_feat]\n",
" self.softmax = nn.Softmax(2) #[batch_size,word_size,output_feat] -> [batch_size,word_size,output_feat]\n",
" #0부터 시작해서 2 번째 차원에 softmax.\n",
"\n",
" def forward(self,**kargs):\n",
" emb = self.bert(**kargs)\n",
" e = self.dropout(emb['last_hidden_state'])\n",
" w = self.lin(e)\n",
" return w"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"MyModel(\n",
" (bert): BertModel(\n",
" (embeddings): BertEmbeddings(\n",
" (word_embeddings): Embedding(119547, 768, padding_idx=0)\n",
" (position_embeddings): Embedding(512, 768)\n",
" (token_type_embeddings): Embedding(2, 768)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (encoder): BertEncoder(\n",
" (layer): ModuleList(\n",
" (0): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (1): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (2): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (3): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (4): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (5): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (6): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (7): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (8): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (9): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (10): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (11): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" )\n",
" )\n",
" (pooler): BertPooler(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (activation): Tanh()\n",
" )\n",
" )\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (lin): Linear(in_features=768, out_features=10, bias=True)\n",
" (softmax): Softmax(dim=2)\n",
")\n"
]
}
],
"source": [
"model = MyModel(tagIdConverter.size,bert)\n",
"model.cuda()\n",
"bert.cuda()\n",
"print(model)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`tagIdConverter.size` 만큼의 종류가 있음"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"bert current device : cuda:0\n"
]
}
],
"source": [
"print(\"bert current device :\",bert.device)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"device(type='cuda')"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"device = torch.device(\"cuda\")\n",
"device"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"datasetTrain, datasetDev, datasetTest = readPreporcssedDataAll(DATASET_PATH)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"my_collate_fn = make_collate_fn(tokenizer, tagIdConverter)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"BATCH_SIZE = 4\n",
"train_loader = DataLoader(\n",
" DatasetArray(datasetTrain),\n",
" batch_size=BATCH_SIZE,\n",
" shuffle=True,\n",
" collate_fn=my_collate_fn\n",
")\n",
"dev_loader = DataLoader(\n",
" DatasetArray(datasetDev),\n",
" batch_size=BATCH_SIZE,\n",
" shuffle=True,\n",
" collate_fn=my_collate_fn\n",
")\n",
"test_loader = DataLoader(\n",
" DatasetArray(datasetTest),\n",
" batch_size=BATCH_SIZE,\n",
" shuffle=True,\n",
" collate_fn=my_collate_fn\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"from tqdm import tqdm"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|████████████████████████████████████████████████████████████████████████| 14041/14041 [00:00<00:00, 468033.78it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"212812/272841 = 0.7799854127495501\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"total_l = 0\n",
"total_o = 0\n",
"\n",
"for item in tqdm(datasetTrain):\n",
" entities = item[\"entity\"]\n",
" l = len(entities)\n",
" o = sum(map(lambda x: 1 if x == \"O\" else 0,entities))\n",
" total_l += l\n",
" total_o += o\n",
"\n",
"print(f\"{total_o}/{total_l} = {total_o/total_l}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"O token 이 77%를 차지함."
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"from torch.optim import AdamW"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"optimizer = AdamW(model.parameters(), lr=1.0e-5)\n",
"CELoss = nn.CrossEntropyLoss(ignore_index=tagIdConverter.pad_id)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"from groupby_index import groupby_index"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 0 start:\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 0: 100%|████████████████████████████████████| 3511/3511 [02:21<00:00, 24.82minibatch/s, accuracy=0.954, loss=1.2]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 1 start:\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 1: 100%|██████████████████████████████████| 3511/3511 [02:20<00:00, 24.97minibatch/s, accuracy=0.986, loss=0.288]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 2 start:\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 2: 100%|██████████████████████████████████| 3511/3511 [02:22<00:00, 24.65minibatch/s, accuracy=0.995, loss=0.192]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 3 start:\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 3: 100%|███████████████████████████████████| 3511/3511 [02:25<00:00, 24.20minibatch/s, accuracy=0.99, loss=0.313]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 4 start:\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 4: 100%|██████████████████████████████████| 3511/3511 [02:23<00:00, 24.47minibatch/s, accuracy=0.988, loss=0.345]\n"
]
}
],
"source": [
"TRAIN_EPOCH = 5\n",
"\n",
"result = []\n",
"iteration = 0\n",
"\n",
"t = []\n",
"\n",
"model.zero_grad()\n",
"\n",
"for epoch in range(TRAIN_EPOCH):\n",
" model.train()\n",
" print(f\"epoch {epoch} start:\")\n",
" with tqdm(train_loader, unit=\"minibatch\") as tepoch:\n",
" tepoch.set_description(f\"Epoch {epoch}\")\n",
" \n",
" for batch in groupby_index(tepoch,8):\n",
" corrects = 0\n",
" totals = 0\n",
" losses = 0\n",
" \n",
" optimizer.zero_grad()\n",
" for mini_i,mini_l in batch:\n",
" batch_inputs = {k: v.cuda(device) for k, v in list(mini_i.items())}\n",
" batch_labels = mini_l.cuda(device)\n",
" attention_mask = batch_inputs[\"attention_mask\"]\n",
" \n",
" output = model(**batch_inputs)\n",
" loss = CELoss(output.view(-1, output.size(-1)), batch_labels.view(-1))\n",
" \n",
" prediction = output.view(-1, output.size(-1)).argmax(dim=-1)\n",
" corrects += ((prediction == batch_labels.view(-1)) * attention_mask.view(-1)).sum().item()\n",
" totals += attention_mask.view(-1).sum().item()\n",
" losses += loss.item()\n",
" loss.backward()\n",
"\n",
" optimizer.step()\n",
" accuracy = corrects / totals\n",
" result.append({\"iter\":iteration,\"loss\":losses,\"accuracy\":accuracy})\n",
" tepoch.set_postfix(loss=losses, accuracy= accuracy)\n",
" iteration += 1"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [],
"source": [
"del batch_inputs\n",
"del batch_labels\n",
"del loss\n",
"del optimizer"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"gpu allocated : 1355 MB\n",
"gpu reserved : 1470MB\n"
]
}
],
"source": [
"torch.cuda.empty_cache()\n",
"print(f\"gpu allocated : {torch.cuda.memory_allocated() // 1024**2} MB\")\n",
"print(f\"gpu reserved : {torch.cuda.memory_reserved() // 1024 ** 2}MB\")"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZEAAAD4CAYAAAAtrdtxAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8/fFQqAAAACXBIWXMAAAsTAAALEwEAmpwYAAA8bUlEQVR4nO3dd5hU1fnA8e/LLktZqiKCgAoIiooiIqJGg0YUiYnYIpiosSEqsUQssQGW2CJWFFEQ9BchFlQCWAgWUHrvyIIIC0ivC+zu7L6/P87MTtnpW5l9P89zn7nl3HPP3J2dd845954rqooxxhiTjGoVXQBjjDGHLgsixhhjkmZBxBhjTNIsiBhjjEmaBRFjjDFJS6/oAoRTrVo1rVWrVkUXwxhjDhn79+9XVS33ikGlDCK1atUiJyenoothjDGHDBE5UBHHteYsY4wxSbMgYowxJmkWRIwxxiQtZp+IiIwALgW2qOrJ3nX/AY73JmkA7FLVDmH2XQvsBQoAj6p2KpVSG2OMqRTi6VgfCbwOvOdboarX+OZF5EVgd5T9z1fVbckW0BhjTOUVszlLVacAO8JtExEB/gSMLuVyGWOMKWUiMkJEtojIkgjbRUReFZEsEVkkIh1j5VnSPpFzgc2quirCdgW+FpG5ItInWkYi0kdE5ojIHI/HU8JiGWOMCWMk0D3K9kuANt6pD/BmrAxLep9Ib6LXQs5R1Y0i0hiYJCIrvDWbYlR1GDAMIDMzM6nx6Z/8/kk6N+vMxcddnMzuld+kSdCqFbRuDYWFsGIFbNoEv/tdyfNetQrWrfPn5XtEgEhwuk2bYPp0uOIKKCiAvDyoVQu++gratoWWLf3pZs2Cyy6DffsgM9PltX+/Sx+Y7+LFsHMnnHeeW96wAebNgz/8weWfkeG2Z2ZC9epu36+/dueiSROoU8edj5Ur3Tm59FK3nJHhXrduhZ9+gjVroGlT6NYNqnl/P33zjcujRQv46CM4+2yYMwd69YJ0779HYSF4PC6/ggIYNgwuuQSOPhpGjIC//AVq1oQDB9w+1av7z+F778FVV7kyi7j3Di7tzp3w3/9Cnz6wezd8/jlcf71Ll5Pj9q9Tx3+e3nwTXnkF/vpXd8xhw1w+2dnuXDz5JPz4I6SlwfLlcPXV7j3VrQudOkGzZu691akD7du7czB+vDvuKafAnXfCwIFw8CD06OH+Hjt2wIwZsHChO0crV0K/fvD9966c/fu79zBnjjsfGRnw29/Cu+/Crl3uczJ2LEyb5tJfdx00aODKM2QIfPstNG4MXbrAli0wYAA895wr45//DDNnunOxY4crkwg0bAiLFrl169bBE0+4c9C0qfs7HXGEO1d5edCokTvnCxa4srz0kvt8LV3qjpmX5z6zr77q3veYMe7z+Le/wfPPu3Kcey4MGgRZWVC7NuzZ485tz57ufU6YAI8+Cscc48779u2ubCNHwosvuvTXXefWf/SRy//aa915ql4dhg6FJUvgpJNg71544IES/jNHpqpTROTYKEkuA95T94yQGSLSQESaquqmaJnGnIBjgSUh69KBzUDzOPMYCPSPJ23t2rU1GbWfrq39v+qf1L5latUq1by8xPa59lrVzEw3v3ev6u7dqu5rpfi0bp1/vwMH3Lrf/MbNd+umOnSo6rx5qtu3hz/Wxo3+vHx8yxMnqv76q+rw4apdu/rXn3GGf37yZP98QYHb//jj3fKnn7rXZ55RnT3bzR92mCvf2rXBx7r9dtUePfzLw4YVf6+PPhq8D6j+8Y+Rz0206bvvYqe55Rb//HnnBW+rVi34vIFqy5aqP/+sevfd/m2NGwfvV7Nm5ON98YXq9df7lw8eVP3gA9WZM5N7jxUxBZ4zmxKfkgTkAnMCpj6q8X2fB2wbD/wmYHky0Clc2qI00TZGOyiuSvR9lH0ygboB89OA7vEcL9kgkvl0pv79y78ntW9C9u1THTzY/4Xp4/G4bYGys91pvvde/7rvvlO96SbVwsLw+e/f7/9A/fRT/B++996LneaSS9zrPfeorl6tet11xdNcdFHF/yPZ5KYrr6z4MthUvlOSgBzV5L7PA7ZNCBNETo+aXxwHHA1sAvKBbOBm7/qRQN+QtEcBE73zrYCF3mkp8Eg8b1A1+SBS95919d4v742dMBm7d6v26qV6wgmq7du7U/fpp8FpbrvNrV+7VvWJJ1TXrw/+cGzY4NK1aOGWv/rK1RBCNW3q32fgwIr/UNtk06EyHXVU6eQTWnsszylJpRRE3gJ6ByyvBJpGzS+eg5b3lGwQqfdMPb37i7uT2jemcH/sMWNUP/rIBY0NG4pvP/LI4uvGj1c97rjiHxrf/JNPVtyH16ZDa/rrXyu+DLGmv/zFNVOW5TE6dFBdsUJ1wQL3vzR4cPE0gc2k8Uyqqt98419+7TX//BlnqB5xRPx5tW2revjh8aV94okSfEWVShD5PfAFIEAXYFbM/OI5aHlPyQaRBs820Lsm3pXUvqqq+q9/qS5dGn5buD94t25l+89hU/lMZ55Z8WWIZxo9WvXBB/3LkT6Xvikjw+1z882x887J8c8HfmGCalpaYuWcN8/f3+XTu7d/+333ub6tkpyLvDzV3Nzg/sBACxeqvvSSP/24ceHzCdc/1b27P5+xY/19ibfeqvrWW/5tdepEL+PVV/vTHjyoWqNG+HS/+517/c1vEvq6Kv4VFTuIhGtZAvr6WpW8wWMIsBpYHKs/RL2fxEo3JRtEGj7bUPtN6JfUvpqf705HvXpueckS1Z073fzrr5fsA29T8tPJJ0feNmlS7P1nzIidprAw8rZdu9xro0aqtWu7+TZtVB97rHjawC/K0On77/3zoR304aaCguDlgQPdZzGwrKrhj9+9u3tt186lCe3ofvVV1SuuUL3wwuC8fF9yqqpHH+3mFy1yZRk/3i2/9prqaaf597vjDv/8DTe4L8RImjRx6UaNcvn69rviisjnIbB/MPRvFg9f+gULwufz97+rrlnj+hNvuSX+fFVV9+xx3xFDhrhWh88/D877r38tvs8nnwSnOXDArS8oKN7HmqB4ayKlPZX7AeOZkg0ihz93uN454c6k9i36JZae7pZBtWFD1R9+iP0PX9mmdu3KNv9ITQPxfDkmMtWsqbp1q3/Zd5ECqP7zn+7v9OyzwfusWeNe+wdcpRftgoM2bVyakSP96z780D+/f7/qL7+obtumunmzu9DBZ+tW/9Vw4P/chJtmzXKvHTu6dDt3uvMY7sKG0Lx69Qr+gglMs2+fq2n07u2Wd+92P4heftlfG/B9ufn6CyZMiJzX3r1ufvNm99n3KSxUnTvX/yWbkeH2y81Vbd7czUeqFQSerzvvdL/K58/3H3vIkPDNuIMGBQfN4cNdmf7xj+jHCeTbN/B44H4kZma6QFCaJkzwN9+98ELx7bm5/jJceWWpHtqCSMCUbBBp9HwjvX387Ynv+OWX/j9sRoZbV5ZfwmU9rV1bNpdZTpvmmiL27Am//fnn3cUGM2e6ZgDf+gED/PP9+8d/PN+vtMBlcMHK43HbQpspVFWXLXNfpIF8gaFDB9Ubb/Sn9wUjVfeFtWWLm2/Z0m0PzSecwGNv2aL6wAPF34vv8mZfEPHZtat4Debyy4Pzfe65yMeL18GDLsBed52bL0leqq7J6uGH3Tm76iqXx7Zt8e8/bZr/2MuWuXVr1qg+/rgW+4L1/b3eeSfxcq5Z465CXLfOf7wuXRLPJ1E//hi+ZhFYwxw8uFQPaUEkYEo2iDR+obH2/W/fxHcM/AeuWbP4ukNh8jXHBX4hrF2r2qpVYvn07Bl5m09eXvD6Cy5wry++WPzcbtrkXn3t7PPmuWaojz8Of4wlS1Rr1XLBKvTvU1BQvLnB9wPgkkui/419zQg9e7pf1OGCSKD1613giceAAe5XcqB4g4hq8aB8003BefgCZmDeDRvGV7ZYkg0igXJyVOfMSWyfX391xx05Mnj9Rx9psSCyapUL/okEqXBWr3YBNPR8ljffOZ87t5SzrZggUimfbJgsQSjUwsR2+uWXkEzE3Slb0e68093RG2/a9DB/ymOOcXdwp6f779D2eeQ
"text/plain": [
"<Figure size 432x288 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"iters = [item[\"iter\"] for item in result]\n",
"fig, ax1 = plt.subplots()\n",
"ax1.plot(iters,[item[\"loss\"] for item in result],'g')\n",
"ax2 = ax1.twinx()\n",
"ax2.plot(iters,[item[\"accuracy\"] for item in result],'r')\n",
"plt.xlabel(\"iter\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|█████████████████████████████████████████████████████████████████████████████| 864/864 [00:10<00:00, 83.69batch/s]\n"
]
}
],
"source": [
"model.eval()\n",
"collect_list = []\n",
"with torch.no_grad():\n",
" with tqdm(test_loader, unit=\"batch\") as tepoch:\n",
" for batch_i,batch_l in tepoch:\n",
" batch_inputs = {k: v.cuda(device) for k, v in list(batch_i.items())}\n",
" batch_labels = batch_l.cuda(device)\n",
" output = model(**batch_inputs)\n",
" loss = CELoss(output.view(-1, output.size(-1)), batch_labels.view(-1))\n",
" \n",
" prediction = output.view(-1, output.size(-1)).argmax(dim=-1)\n",
" correct = (prediction == batch_labels.view(-1)).sum().item()\n",
" accuracy = correct / batch_inputs[\"attention_mask\"].view(-1).sum()\n",
" \n",
" collect_list.append({\"loss\":loss.item(),\"accuracy\":accuracy, \"batch_size\":batch_labels.size(0),\n",
" \"predict\":output.argmax(dim=-1).cpu(),\n",
" \"actual\":batch_labels.cpu(),\n",
" \"attention_mask\":batch_inputs[\"attention_mask\"].cpu()})"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
"def getConfusionMatrix(predict,actual,attention_mask):\n",
" ret = torch.zeros((tagIdConverter.size,tagIdConverter.size),dtype=torch.long)\n",
" for i,(p_s,a_s) in enumerate(zip(predict,actual)):\n",
" for j,(p,a) in enumerate(zip(p_s,a_s)):\n",
" ret[p,a] += attention_mask[i,j]\n",
" return ret"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"average_loss : 0.13179491325355058, average_accuracy : 0.9739284515380859, size :3453\n"
]
}
],
"source": [
"total_loss = 0\n",
"total_accuracy = 0\n",
"total_size = 0\n",
"confusion = torch.zeros((tagIdConverter.size,tagIdConverter.size),dtype=torch.long)\n",
"\n",
"for item in collect_list:\n",
" batch_size = item[\"batch_size\"]\n",
" total_loss += batch_size * item[\"loss\"]\n",
" total_accuracy += batch_size * item[\"accuracy\"]\n",
" total_size += batch_size\n",
" confusion += getConfusionMatrix(item[\"predict\"],item[\"actual\"],item[\"attention_mask\"])\n",
"print(f\"\"\"average_loss : {total_loss/total_size}, average_accuracy : {total_accuracy/total_size}, size :{total_size}\"\"\")"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [ 0, 1566, 19, 78, 19, 1, 3, 1, 0, 22],\n",
" [ 0, 30, 599, 53, 3, 1, 11, 1, 1, 76],\n",
" [ 0, 48, 32, 1479, 21, 2, 1, 3, 0, 53],\n",
" [ 0, 7, 14, 17, 1557, 0, 6, 1, 2, 27],\n",
" [ 0, 1, 0, 0, 0, 1510, 47, 101, 26, 45],\n",
" [ 0, 2, 8, 0, 0, 39, 583, 54, 4, 298],\n",
" [ 0, 7, 2, 9, 0, 39, 63, 2560, 32, 122],\n",
" [ 0, 0, 0, 1, 8, 11, 38, 29, 3355, 33],\n",
" [ 0, 7, 28, 24, 9, 25, 86, 45, 15, 55141]])"
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"confusion"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [],
"source": [
"confusion = confusion[0:(tagIdConverter.size - 1)][0:(tagIdConverter.size - 1)]"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAUkAAAEICAYAAADSjgZhAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8/fFQqAAAACXBIWXMAAAsTAAALEwEAmpwYAABO9UlEQVR4nO2ddVxW1x/H319AsKdOVAQ7QDFo7Nmzu2fNmv5Wrjvc5ow5NzvmQqezpptdsxuxZ7eiGBhTQKXO74/nAVGJB3gu5Xn7ui/uc+653++5l8cvpz+ilEKj0Wg0CWOT0QXQaDSazIwOkhqNRpMEOkhqNBpNEuggqdFoNEmgg6RGo9EkgQ6SGo1GkwQ6SGYDxMSvInJbRALSYKeuiJywZtkyChEpKSKhImKb0WXRZG1Ez5PM+ohIXWAu4KqUCsvo8hiNiJwHBiil/snosmiyP7ommT0oBZx/FgKkJYiIXUaXQZN90EEynRGREiKyWERuiMhNEZloTrcRkU9F5IKIXBeRWSLynPlaaRFRItJHRC6KSIiIfGK+1h+YAdQ0Ny+HiUhfEdn2hF8lIuXN5y1E5KiI3BORyyLyrjm9vogExbunkohsEpE7InJERNrEu/abiEwSkRVmO7tFpFwizxxb/pdF5JK5W2CwiPiKyCGz/Ynx8pcTkQ3m9xMiInNEpID52u9ASWCZ+Xnfj2e/v4hcBDbES7MTkUIiEiQirc028orIaRHpndbfp+YZQCmlj3Q6AFvgIPADkAfICdQxX+sHnAbKAnmBxcDv5mulAQX8BOQCqgMPgUrm632BbfH8PPbZnKaA8ubzYKCu+bwg4GU+rw8Emc9zmMvzMWAPNATuYWrSA/wG3AL8ADtgDjAvkeeOLf9U8zM3BR4AfwNFAGfgOvCCOX95oAngADgCW4Af49k7DzROwP4s83vNFS/NzpynKXDV7O8n4M+M/j7oI2scuiaZvvgBxYH3lFJhSqkHSqnYGt9LwFil1FmlVCjwEdDtiabjMKXUfaXUQUzBtnoqyxEJVBaR/Eqp20qpfQnkqYEpWI9USkUopTYAy4Hu8fIsVkoFKKWiMAVJj2T8fm1+5rVAGDBXKXVdKXUZ2Ap4AiilTiul1imlHiqlbgBjgRcseK4vze/1/pMXzD4XAuuBlsArFtjTaHSQTGdKABfMQeVJigMX4n2+gKmGVjRe2tV45+GYglhq6Ai0AC6IyGYRqZlIeS4ppWKeKJNzGspzLd75/QQ+5wUQkSIiMs/cFXAXmA0UTsY2wKVkrk8HqgC/KqVuWmBPo9FBMp25BJRMZGDhCqYBmFhKAlE8HkgsJQzIHftBRIrFv6iU2qOUaoup6fk3sCCR8pQQkfjfkZLA5VSUJ6WMwNRUrqaUyg/0BCTe9cSmZCQ6VcM8FWgapib5kNj+WY0mOXSQTF8CMPUHjhSRPCKSU0Rqm6/NBd4SkTIikhf4FpifSK0zOQ4C7iLiISI5gS9jL4iIvYi8JCLPKaUigbtAdAI2dmMKtu+LSA4RqQ+0BualojwpJR8QCtwREWfgvSeuX8PUd5sSPjb/7AeMAWbpOZQaS9BBMh1RSkVjCjTlgYtAENDVfPkX4HdMgxTnMA1svJ5KPyeBr4B/gFPAtiey9ALOm5uygzHV1J60EQG0AZoDIcBkoLdS6nhqypRChgFewH/ACkyDWPEZAXxqHhV/NzljIuINvI2p/NHAKEy1zg+tWmpNtkRPJtdoNJok0DVJjUajSQIdJDUajSYJdJDUaDSaJNBBUqPRaJLAkI0AChcurEqVKm2EaY1Gkw5cuHCekJAQST5n4tjmL6VU1FOLnxJE3b+xRinVLC3+jMKQIFmqVGm27w40wrRGo0kHavv7pNmGinqAg1s3i/I+2D/BkhVVGYLeUkqj0RiDAJKmymimQAdJjUZjHJL1hz10kNRoNMaRDWqSGRrm165ZTTV3V9zdyvPd6JHafgb40Pazt/308pEwAja2lh2ZGSM2qfTy8lb3I1WSR+iDKFWmbFl19MQZ9V/YQ1W1ajW17+CRZO+z9Mjq9rPDM2j7Wfc75OXlrdIaByRPUZXT/32LDiAwozfXzXSb7u4JCKBcufKUKVsWe3t7OnftxvJlS7T9dPSh7Wdv++nlI3HE1Ny25MjEZFiQvHLlMi4uJeI+Ozu7cPmy9bYqzOr208OHtp+97aeXjyQRG8uOTIxFpRORZiJywiyeZJXtpRLafUis+Bclq9tPDx/afva2n14+kuRZqEmaNyadhGlfwcpAdxGpnFbHzs4uBAU92m3/8uUgihcvnlaz2cZ+evjQ9rO3/fTykTjyzNQk/YDTZoGqCEw7U7dNq2MfX19Onz7F+XPniIiIYOH8ebRs1Sb5G58R++nhQ9vP3vbTy0eiCNlidNuSeZLOPC6wFAT4P5lJRAYBgwBKlCyZvGM7O34YN5HWLV8kOjqaPn37Udnd3bJSW0BWt58ePrT97G0/vXwkjmT6WqIlJLszuYh0Bl5USg0wf+4F+CmlEpUW8Pb2UXrttkaTdant78PevYFp6iy0yeesHHwGW5T3wabP9yql0r5g3AAsqUkGYZJCjcUFk5KeRqPRJI6QLWqSljzBHqCCWcXPHugGLDW2WBqNJluQDUa3k61JKqWiROQ1YA1gC/yilDpieMk0Gk0WRzL9oIwlWLTBhVJqJbDS4LJoNJrsRjZobutdgDQajTFkgaa0JeggqdFojCMb1CSz/hNoNJrMi5UGbkQkp4gEiMhBETkiIsPM6YVEZJ2InDL/LBjvno/MS6lPiMiL8dK9ReSw+dp4SWadpg6SGo3GIKy6LPEh0FApVR3wAJqJSA3gQ2C9UqoCsN78GfPS6W6AO9AMmGxeYg0wBdPClwrmI0kBMh0kNRqNMVhxWaIyEWr+mMN8KExLpGea02cC7cznbYF5SqmHSqlzwGnAT0ScgPxKqZ3KtJJmVrx7EkQHSY1GYxApqkkWFpHAeMegp6yJ2IrIAeA6sE4ptRsoqpQKBjD/LGLOntByamfzEZRAeqJkyYGbYWtPGGr/8yYVDbUPEB2T9HLQtGJrY+yoYrput6VJkOSWFKfJtrUMWf49CUluWaJSKhrwEJECwF8iUiUpzwmZSCI9UXRNUqPRGIcBW6Uppe4AmzD1JV4zN6Ex/7xuzpbYcuog8/mT6Ymig6RGozEO641uO5prkIhILqAxcBzTEuk+5mx9gFhtiqVANxFxEJEymAZoAsxN8nsiUsM8qt073j0JkiWb2xqNJgsgVt0qzQmYaR6htgEWKKWWi8hOYIGI9AcuAp0BlFJHRGQBcBSIAl41N9cBhgC/AbmAVeYjUXSQ1Gg0hiE21gmSSqlDgGcC6TeBRoncMxwYnkB6IJBUf+ZjZEnd7RU/fsz4HrWY8b/WcWlb50xgYu96/PJaO355rR1n9myOu3b93AlmvdOVGUNa8fP/WhMV8RCA6MgIVo3/jGkDX2T6K805vn1Nsr5fGdiPUs5F8fGoGpd26OBB6tetha9nNTq2a8Pdu3ctfpYnOXnyBLX8vOKO4o4FmDRhHIcOHqBBvVrU8vOiXi0/AvcEpMq+0eV/zNeAfpQsXgRvD4u/jyniwYMH1Knph59Xdbyqu/P1sC+s7sPoZ7C2JnbQpUs0a9IQz6qV8a5ehUkTxgHw8Yfv4VGlEn5e1enaqQN37txJs6/kEEwDfJYcmZkMC5LR0dEMfeNVlixbxf5DR1k4by7Hjh616N6qjdvT5aufnkr3bduHfhP/pt/Evynn+wIAMdFRLBvzHi++OowBU5bTY+QsbGxNFegd86eSp8DzvPLTGgZOWUHJKn7J+u7Vuy9/L3+8dv6/wQP5evgI9uw/RJt27fjh++8seo6EqFjRlR0B+9gRsI+tO/eQK3duWrdpx2cff8BHn3zGjoB9fPL5l3z2cer02Iwu/2O++vRlyfLVVrGVEA4ODqxet4GAfQfZHXiAtWtWs3vXLqv6MPIZ0vJ/IDFs7ewYMXoM+w8fZdO2nUybMpljR4/SsFETAg8cJmDfQSpUqMCYUSOs9BRJICk4MjFZUne7ZBVfcuZ7zqK85/Ztp0hpV4qWdQMgV/6
"text/plain": [
"<Figure size 432x288 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"import itertools\n",
"\n",
"plt.title(\"confusion matrix\")\n",
"plt.imshow(confusion,cmap='Blues')\n",
"\n",
"plt.colorbar()\n",
"for i,j in itertools.product(range(confusion.shape[0]),range(confusion.shape[1])):\n",
" plt.text(j,i,\"{:}\".format(confusion[i,j]),horizontalalignment=\"center\",color=\"black\" if i == j else \"black\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [],
"source": [
"def getF1Score(confusion,c):\n",
" TP = confusion[c,c]\n",
" FP = confusion[c].sum() - TP\n",
" FN = confusion[:,c].sum() - TP\n",
" precision = TP / (TP + FP)\n",
" recall = TP / (TP + FN)\n",
"\n",
" f1Score = (2*precision*recall)/(precision + recall)\n",
" return f1Score"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"class 0 f1 score : nan\n",
"class 1 f1 score : 0.9293768405914307\n",
"class 2 f1 score : 0.8267770409584045\n",
"class 3 f1 score : 0.9029303789138794\n",
"class 4 f1 score : 0.9614078402519226\n",
"class 5 f1 score : 0.9060906171798706\n",
"class 6 f1 score : 0.6701149344444275\n",
"class 7 f1 score : 0.9169055223464966\n",
"class 8 f1 score : 0.9731689691543579\n"
]
}
],
"source": [
"for i in range(tagIdConverter.size - 1):\n",
" f1 = getF1Score(confusion,i)\n",
" print(f\"class {i} f1 score : {f1}\")"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [],
"source": [
"import collections"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"20744it [00:00, 170034.81it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"token \t count frequency%\n",
"O(9) \t 314867 77.841%\n",
"I-PER(8) \t 21118 5.221%\n",
"I-ORG(7) \t 16329 4.037%\n",
"I-LOC(5) \t 10922 2.700%\n",
"B-LOC(1) \t 10645 2.632%\n",
"B-PER(4) \t 10059 2.487%\n",
"B-ORG(3) \t 9322 2.305%\n",
"I-MISC(6) \t 6176 1.527%\n",
"B-MISC(2) \t 5062 1.251%\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"counter = collections.Counter()\n",
"total_l = 0\n",
"\n",
"for item in tqdm(itertools.chain(datasetTrain,datasetDev,datasetTest)):\n",
" entities = item[\"entity\"]\n",
" for entity in entities:\n",
" counter[entity] += 1\n",
" total_l += len(entities)\n",
"print(f\"{'token':<12}\\t{'count':>12} {'frequency%':>12}\")\n",
"for token,count in counter.most_common():\n",
" tid = tagIdConverter.convert_tokens_to_ids([token])[0]\n",
" print(f\"{f'{token}({tid})':<12}\\t{count:>12}{count*100/total_l:>12.3f}%\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.11"
}
},
"nbformat": 4,
"nbformat_minor": 2
}