From 46f8d08fd7442e5655c3aaa5fc5d9a319d5a3518 Mon Sep 17 00:00:00 2001 From: monoid Date: Tue, 22 Feb 2022 23:36:24 +0900 Subject: [PATCH] feat: training english --- EngTraning.ipynb | 1115 ++++++++++++++++++++++++++++++++++++++++++++++ Training.ipynb | 1 - 2 files changed, 1115 insertions(+), 1 deletion(-) create mode 100644 EngTraning.ipynb diff --git a/EngTraning.ipynb b/EngTraning.ipynb new file mode 100644 index 0000000..2199cd9 --- /dev/null +++ b/EngTraning.ipynb @@ -0,0 +1,1115 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'C:\\\\Users\\\\Monoid\\\\anaconda3\\\\envs\\\\nn\\\\python.exe'" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import sys\n", + "sys.executable" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "파이썬 환경 확인.\n", + "envs\\\\nn\\\\python.exe 으로 끝나기를 기대합니다" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from preprocessing import readPreporcssedDataAll\n", + "import torch\n", + "from torch.utils.data import Dataset, DataLoader\n", + "from dataset import make_collate_fn, DatasetArray\n", + "from transformers import BertTokenizer\n", + "import torch.nn as nn\n", + "from read_data import TagIdConverter" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "TAGS_PATH = \"eng_tags.json\"\n", + "DATASET_PATH = \"engpre\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "변수 설정" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "tagIdConverter = TagIdConverter(TAGS_PATH)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "PRETAINED_MODEL_NAME = 'bert-base-multilingual-cased'\n", + "tokenizer = 
BertTokenizer.from_pretrained(PRETAINED_MODEL_NAME)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Tokenizer 로딩" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "from transformers import BertModel" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Some weights of the model checkpoint at bert-base-multilingual-cased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.seq_relationship.weight']\n", + "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. 
# %%
class MyModel(nn.Module):
    """Per-token tag classifier: an encoder followed by dropout and one linear layer.

    Args:
        output_feat: Number of output classes (one score per tag id).
        bert: Encoder module. It is called with the keyword arguments passed to
            ``forward`` and its result must provide ``'last_hidden_state'``.
    """

    def __init__(self, output_feat: int, bert):
        super().__init__()
        self.bert = bert
        self.dropout = nn.Dropout(p=0.1)
        # Take the head's input width from the encoder's own config so encoders
        # with a different hidden size work; fall back to 768 (the width that
        # used to be hard-coded) when the encoder exposes no config.
        hidden_size = getattr(getattr(bert, "config", None), "hidden_size", 768)
        # [batch_size, word_size, hidden_size] -> [batch_size, word_size, output_feat]
        self.lin = nn.Linear(hidden_size, output_feat)
        # Softmax over dimension 2 (counting from 0), i.e. over the class scores.
        # forward() does not apply it; it is kept so existing references to
        # `model.softmax` keep working.
        self.softmax = nn.Softmax(2)

    def forward(self, **kargs):
        """Run the encoder and return the linear layer's raw (un-normalised) scores.

        Args:
            **kargs: Forwarded unchanged to the encoder.

        Returns:
            Tensor whose last dimension has ``output_feat`` entries.
        """
        emb = self.bert(**kargs)
        e = self.dropout(emb['last_hidden_state'])
        return self.lin(e)


# Notebook driver: `__name__` is "__main__" inside a Jupyter kernel, so this runs
# as a normal cell, but it is skipped when the class above is imported elsewhere.
if __name__ == "__main__":
    # Load BERT.
    PRETAINED_MODEL_NAME = 'bert-base-multilingual-cased'
    bert = BertModel.from_pretrained(PRETAINED_MODEL_NAME)
out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (1): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (2): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): 
Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (3): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (4): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, 
bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (5): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (6): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): 
Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (7): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (8): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, 
bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (9): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (10): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): 
Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (11): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (pooler): BertPooler(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (activation): Tanh()\n", + " )\n", + " )\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " (lin): Linear(in_features=768, out_features=10, bias=True)\n", + " (softmax): Softmax(dim=2)\n", + ")\n" + ] + } + ], + "source": [ + 
"model = MyModel(tagIdConverter.size,bert)\n", + "model.cuda()\n", + "bert.cuda()\n", + "print(model)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "`tagIdConverter.size` 만큼의 종류가 있음" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "bert current device : cuda:0\n" + ] + } + ], + "source": [ + "print(\"bert current device :\",bert.device)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "device(type='cuda')" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "device = torch.device(\"cuda\")\n", + "device" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "datasetTrain, datasetDev, datasetTest = readPreporcssedDataAll(DATASET_PATH)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "my_collate_fn = make_collate_fn(tokenizer, tagIdConverter)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "BATCH_SIZE = 4\n", + "train_loader = DataLoader(\n", + " DatasetArray(datasetTrain),\n", + " batch_size=BATCH_SIZE,\n", + " shuffle=True,\n", + " collate_fn=my_collate_fn\n", + ")\n", + "dev_loader = DataLoader(\n", + " DatasetArray(datasetDev),\n", + " batch_size=BATCH_SIZE,\n", + " shuffle=True,\n", + " collate_fn=my_collate_fn\n", + ")\n", + "test_loader = DataLoader(\n", + " DatasetArray(datasetTest),\n", + " batch_size=BATCH_SIZE,\n", + " shuffle=True,\n", + " collate_fn=my_collate_fn\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "from tqdm import tqdm" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + 
"outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|████████████████████████████████████████████████████████████████████████| 14041/14041 [00:00<00:00, 468033.78it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "212812/272841 = 0.7799854127495501\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "total_l = 0\n", + "total_o = 0\n", + "\n", + "for item in tqdm(datasetTrain):\n", + " entities = item[\"entity\"]\n", + " l = len(entities)\n", + " o = sum(map(lambda x: 1 if x == \"O\" else 0,entities))\n", + " total_l += l\n", + " total_o += o\n", + "\n", + "print(f\"{total_o}/{total_l} = {total_o/total_l}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "O token 이 77%를 차지함." + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "from torch.optim import AdamW" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "optimizer = AdamW(model.parameters(), lr=1.0e-5)\n", + "CELoss = nn.CrossEntropyLoss(ignore_index=tagIdConverter.pad_id)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [], + "source": [ + "from groupby_index import groupby_index" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch 0 start:\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 0: 100%|████████████████████████████████████| 3511/3511 [02:21<00:00, 24.82minibatch/s, accuracy=0.954, loss=1.2]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch 1 start:\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 1: 100%|██████████████████████████████████| 3511/3511 [02:20<00:00, 
# %%
TRAIN_EPOCH = 5

result = []
iteration = 0

t = []

model.zero_grad()

for epoch in range(TRAIN_EPOCH):
    model.train()
    print(f"epoch {epoch} start:")
    with tqdm(train_loader, unit="minibatch") as tepoch:
        tepoch.set_description(f"Epoch {epoch}")

        # NOTE(review): groupby_index is a project helper; it presumably yields
        # groups of up to 8 minibatches -- confirm. Gradients of every minibatch
        # in a group are accumulated before the single optimizer step below.
        for batch in groupby_index(tepoch, 8):
            n_correct, n_tokens, loss_sum = 0, 0, 0

            optimizer.zero_grad()
            for mini_i, mini_l in batch:
                batch_inputs = {}
                for key, tensor in mini_i.items():
                    batch_inputs[key] = tensor.cuda(device)
                batch_labels = mini_l.cuda(device)

                output = model(**batch_inputs)
                flat_scores = output.view(-1, output.size(-1))
                flat_labels = batch_labels.view(-1)
                flat_mask = batch_inputs["attention_mask"].view(-1)

                loss = CELoss(flat_scores, flat_labels)

                # Count a hit only where the attention mask is set.
                hits = (flat_scores.argmax(dim=-1) == flat_labels) * flat_mask
                n_correct += hits.sum().item()
                n_tokens += flat_mask.sum().item()
                loss_sum += loss.item()
                loss.backward()

            optimizer.step()
            accuracy = n_correct / n_tokens
            # "loss" is the sum over the group's minibatches, not their mean.
            result.append({"iter": iteration, "loss": loss_sum, "accuracy": accuracy})
            tepoch.set_postfix(loss=loss_sum, accuracy=accuracy)
            iteration += 1

# %%
# Drop the last references to training tensors and the optimizer state.
del batch_inputs, batch_labels, loss, optimizer

# %%
torch.cuda.empty_cache()
allocated_mb = torch.cuda.memory_allocated() // 1024**2
reserved_mb = torch.cuda.memory_reserved() // 1024 ** 2
print(f"gpu allocated : {allocated_mb} MB")
print(f"gpu reserved : {reserved_mb}MB")

# %%
# Plain-Python spelling of the `%matplotlib inline` line magic.
get_ipython().run_line_magic("matplotlib", "inline")

# %%
import numpy as np
import matplotlib.pyplot as plt
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "iters = [item[\"iter\"] for item in result]\n", + "fig, ax1 = plt.subplots()\n", + "ax1.plot(iters,[item[\"loss\"] for item in result],'g')\n", + "ax2 = ax1.twinx()\n", + "ax2.plot(iters,[item[\"accuracy\"] for item in result],'r')\n", + "plt.xlabel(\"iter\")\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|█████████████████████████████████████████████████████████████████████████████| 864/864 [00:10<00:00, 83.69batch/s]\n" + ] + } + ], + "source": [ + "model.eval()\n", + "collect_list = []\n", + "with torch.no_grad():\n", + " with tqdm(test_loader, unit=\"batch\") as tepoch:\n", + " for batch_i,batch_l in tepoch:\n", + " batch_inputs = {k: v.cuda(device) for k, v in list(batch_i.items())}\n", + " batch_labels = batch_l.cuda(device)\n", + " output = model(**batch_inputs)\n", + " loss = CELoss(output.view(-1, output.size(-1)), batch_labels.view(-1))\n", + " \n", + " prediction = output.view(-1, output.size(-1)).argmax(dim=-1)\n", + " correct = (prediction == batch_labels.view(-1)).sum().item()\n", + " accuracy = correct / batch_inputs[\"attention_mask\"].view(-1).sum()\n", + " \n", + " collect_list.append({\"loss\":loss.item(),\"accuracy\":accuracy, \"batch_size\":batch_labels.size(0),\n", + " \"predict\":output.argmax(dim=-1).cpu(),\n", + " \"actual\":batch_labels.cpu(),\n", + " \"attention_mask\":batch_inputs[\"attention_mask\"].cpu()})" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [], + "source": [ + "def getConfusionMatrix(predict,actual,attention_mask):\n", + " ret = torch.zeros((tagIdConverter.size,tagIdConverter.size),dtype=torch.long)\n", + " for i,(p_s,a_s) in enumerate(zip(predict,actual)):\n", + " for j,(p,a) in enumerate(zip(p_s,a_s)):\n", + " ret[p,a] += 
attention_mask[i,j]\n", + " return ret" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "average_loss : 0.13179491325355058, average_accuracy : 0.9739284515380859, size :3453\n" + ] + } + ], + "source": [ + "total_loss = 0\n", + "total_accuracy = 0\n", + "total_size = 0\n", + "confusion = torch.zeros((tagIdConverter.size,tagIdConverter.size),dtype=torch.long)\n", + "\n", + "for item in collect_list:\n", + " batch_size = item[\"batch_size\"]\n", + " total_loss += batch_size * item[\"loss\"]\n", + " total_accuracy += batch_size * item[\"accuracy\"]\n", + " total_size += batch_size\n", + " confusion += getConfusionMatrix(item[\"predict\"],item[\"actual\"],item[\"attention_mask\"])\n", + "print(f\"\"\"average_loss : {total_loss/total_size}, average_accuracy : {total_accuracy/total_size}, size :{total_size}\"\"\")" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", + " [ 0, 1566, 19, 78, 19, 1, 3, 1, 0, 22],\n", + " [ 0, 30, 599, 53, 3, 1, 11, 1, 1, 76],\n", + " [ 0, 48, 32, 1479, 21, 2, 1, 3, 0, 53],\n", + " [ 0, 7, 14, 17, 1557, 0, 6, 1, 2, 27],\n", + " [ 0, 1, 0, 0, 0, 1510, 47, 101, 26, 45],\n", + " [ 0, 2, 8, 0, 0, 39, 583, 54, 4, 298],\n", + " [ 0, 7, 2, 9, 0, 39, 63, 2560, 32, 122],\n", + " [ 0, 0, 0, 1, 8, 11, 38, 29, 3355, 33],\n", + " [ 0, 7, 28, 24, 9, 25, 86, 45, 15, 55141]])" + ] + }, + "execution_count": 34, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "confusion" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [], + "source": [ + "confusion = confusion[0:(tagIdConverter.size - 1)][0:(tagIdConverter.size - 1)]" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + 
"text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "import itertools\n", + "\n", + "plt.title(\"confusion matrix\")\n", + "plt.imshow(confusion,cmap='Blues')\n", + "\n", + "plt.colorbar()\n", + "for i,j in itertools.product(range(confusion.shape[0]),range(confusion.shape[1])):\n", + " plt.text(j,i,\"{:}\".format(confusion[i,j]),horizontalalignment=\"center\",color=\"black\" if i == j else \"black\")\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": {}, + "outputs": [], + "source": [ + "def getF1Score(confusion,c):\n", + " TP = confusion[c,c]\n", + " FP = confusion[c].sum() - TP\n", + " FN = confusion[:,c].sum() - TP\n", + " precision = TP / (TP + FP)\n", + " recall = TP / (TP + FN)\n", + "\n", + " f1Score = (2*precision*recall)/(precision + recall)\n", + " return f1Score" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "class 0 f1 score : nan\n", + "class 1 f1 score : 0.9293768405914307\n", + "class 2 f1 score : 0.8267770409584045\n", + "class 3 f1 score : 0.9029303789138794\n", + "class 4 f1 score : 0.9614078402519226\n", + "class 5 f1 score : 0.9060906171798706\n", + "class 6 f1 score : 0.6701149344444275\n", + "class 7 f1 score : 0.9169055223464966\n", + "class 8 f1 score : 0.9731689691543579\n" + ] + } + ], + "source": [ + "for i in range(tagIdConverter.size - 1):\n", + " f1 = getF1Score(confusion,i)\n", + " print(f\"class {i} f1 score : {f1}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": {}, + "outputs": [], + "source": [ + "import collections" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "20744it [00:00, 170034.81it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + 
"token \t count frequency%\n", + "O(9) \t 314867 77.841%\n", + "I-PER(8) \t 21118 5.221%\n", + "I-ORG(7) \t 16329 4.037%\n", + "I-LOC(5) \t 10922 2.700%\n", + "B-LOC(1) \t 10645 2.632%\n", + "B-PER(4) \t 10059 2.487%\n", + "B-ORG(3) \t 9322 2.305%\n", + "I-MISC(6) \t 6176 1.527%\n", + "B-MISC(2) \t 5062 1.251%\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "counter = collections.Counter()\n", + "total_l = 0\n", + "\n", + "for item in tqdm(itertools.chain(datasetTrain,datasetDev,datasetTest)):\n", + " entities = item[\"entity\"]\n", + " for entity in entities:\n", + " counter[entity] += 1\n", + " total_l += len(entities)\n", + "print(f\"{'token':<12}\\t{'count':>12} {'frequency%':>12}\")\n", + "for token,count in counter.most_common():\n", + " tid = tagIdConverter.convert_tokens_to_ids([token])[0]\n", + " print(f\"{f'{token}({tid})':<12}\\t{count:>12}{count*100/total_l:>12.3f}%\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.11" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/Training.ipynb b/Training.ipynb index 296c9f4..97deb0f 100644 --- a/Training.ipynb +++ b/Training.ipynb @@ -1862,7 +1862,6 @@ } ], "source": [ - "tagIdConverter = TagIdConverter()\n", "counter = collections.Counter()\n", "total_l = 0\n", "\n",