ner-study/Training.ipynb

1840 lines
112 KiB
Plaintext
Raw Normal View History

2022-02-13 17:34:03 +09:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "1cc9cec9",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'C:\\\\Users\\\\Monoid\\\\anaconda3\\\\envs\\\\nn\\\\python.exe'"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import sys\n",
"sys.executable"
]
},
{
"cell_type": "markdown",
"id": "f9c786f1",
"metadata": {},
"source": [
"파이썬 환경 확인.\n",
"envs\\\\nn\\\\python.exe 으로 끝나기를 기대합니다"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "4c0a08d7",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
2022-02-18 17:32:13 +09:00
"100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 9.31it/s]\n"
2022-02-13 17:34:03 +09:00
]
}
],
"source": [
"from time import sleep\n",
"from tqdm import tqdm, trange\n",
"\n",
"lst = [i for i in range(5)]\n",
"\n",
"for element in tqdm(lst):\n",
" sleep(0.1)"
]
},
{
"cell_type": "markdown",
"id": "56a66a44",
"metadata": {},
"source": [
"### tqdm 소개\n",
"tqdm은 위와 같이 progress bar를 그려 주는 라이브러리예요. 와! 편하다.\n",
"```\n",
"from tqdm.auto import tqdm\n",
"from tqdm.notebook import tqdm\n",
"```\n",
"이 두 가지는 0에서 멈춰 있고 작동하지 않더라고요. 왜인진 몰라요. (아마 ipywidgets가 설치·활성화되어 있지 않아서일 수 있어요. 확인 필요.)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "d00b25e6",
"metadata": {},
"outputs": [],
"source": [
"from preprocessing import readPreporcssedDataAll\n",
"import torch\n",
"from torch.utils.data import Dataset, DataLoader\n",
"from dataset import make_collate_fn, DatasetArray\n",
"from transformers import BertTokenizer\n",
"import torch.nn as nn\n",
"from read_data import TagIdConverter"
]
},
{
"cell_type": "markdown",
"id": "a1d7bb15",
"metadata": {},
"source": [
"대충 필요한 것 임포트"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "433419f4",
"metadata": {},
"outputs": [],
"source": [
"tagIdConverter = TagIdConverter()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "79fb54df",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"cuda available : True\n",
"available device count : 1\n",
"device name: NVIDIA GeForce RTX 3070\n"
]
}
],
"source": [
"print(\"cuda available :\",torch.cuda.is_available())\n",
"print(\"available device count :\",torch.cuda.device_count())\n",
"\n",
"if torch.cuda.is_available():\n",
" device_index = torch.cuda.current_device()\n",
" print(\"device name:\",torch.cuda.get_device_name(device_index))"
]
},
{
"cell_type": "markdown",
"id": "dd22dd5e",
"metadata": {},
"source": [
"cuda가 가능한지 먼저 확인해보아요."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "2cd9fe37",
"metadata": {},
"outputs": [],
"source": [
"PRETAINED_MODEL_NAME = 'bert-base-multilingual-cased'\n",
"tokenizer = BertTokenizer.from_pretrained(PRETAINED_MODEL_NAME)\n",
"\n",
"my_collate_fn = make_collate_fn(tokenizer)"
]
},
{
"cell_type": "markdown",
"id": "2177c793",
"metadata": {},
"source": [
"데이터 로딩 준비"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "e738062d",
"metadata": {},
"outputs": [],
"source": [
"from transformers import BertModel"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "70296ee3",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
2022-02-18 17:32:13 +09:00
"Some weights of the model checkpoint at bert-base-multilingual-cased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight']\n",
2022-02-13 17:34:03 +09:00
"- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
"- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
]
}
],
"source": [
"PRETAINED_MODEL_NAME = 'bert-base-multilingual-cased'\n",
"bert = BertModel.from_pretrained(PRETAINED_MODEL_NAME)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "6cbc236d",
"metadata": {},
"outputs": [],
"source": [
"class MyModel(nn.Module):\n",
" def __init__(self,output_feat: int,bert):\n",
" super().__init__()\n",
" self.bert = bert\n",
" self.dropout = nn.Dropout(p=0.1)\n",
" self.lin = nn.Linear(768,output_feat) #[batch_size,word_size,768] -> [batch_size,word_size,output_feat]\n",
" self.softmax = nn.Softmax(2) #[batch_size,word_size,output_feat] -> [batch_size,word_size,output_feat]\n",
" #0부터 시작해서 2 번째 차원에 softmax.\n",
"\n",
" def forward(self,**kargs):\n",
" emb = self.bert(**kargs)\n",
" e = self.dropout(emb['last_hidden_state'])\n",
" w = self.lin(e)\n",
" return w"
]
},
{
"cell_type": "markdown",
"id": "7dec337a",
"metadata": {},
"source": [
"모델 선언.\n",
"`nn.CrossEntropyLoss`가 소프트맥스를 겸하므로 `forward`에서는 softmax를 적용하지 않고 logit을 그대로 반환함."
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "e5214de8",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"MyModel(\n",
" (bert): BertModel(\n",
" (embeddings): BertEmbeddings(\n",
" (word_embeddings): Embedding(119547, 768, padding_idx=0)\n",
" (position_embeddings): Embedding(512, 768)\n",
" (token_type_embeddings): Embedding(2, 768)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (encoder): BertEncoder(\n",
" (layer): ModuleList(\n",
" (0): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (1): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (2): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (3): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (4): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (5): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (6): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (7): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (8): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (9): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (10): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (11): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" )\n",
" )\n",
" (pooler): BertPooler(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (activation): Tanh()\n",
" )\n",
" )\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (lin): Linear(in_features=768, out_features=22, bias=True)\n",
" (softmax): Softmax(dim=2)\n",
")\n"
]
}
],
"source": [
"model = MyModel(22,bert)\n",
"model.cuda()\n",
"bert.cuda()\n",
"print(model)"
]
},
{
"cell_type": "markdown",
"id": "72df7ac3",
"metadata": {},
"source": [
"Tag의 종류가 22가지입니다. 그래서 22를 넣었어요."
]
},
{
"cell_type": "markdown",
"id": "0bba477c",
"metadata": {},
"source": [
"생성과 동시에 gpu로 옮기자.\n",
"`cuda` 저거 호출하면 됨."
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "5aa129e3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"bert current device : cuda:0\n"
]
}
],
"source": [
"print(\"bert current device :\",bert.device)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "61d356b6",
"metadata": {},
"outputs": [],
"source": [
2022-02-18 17:32:13 +09:00
"#for param in bert.parameters():\n",
"# param.requires_grad = False"
2022-02-13 17:34:03 +09:00
]
},
{
"cell_type": "markdown",
"id": "cf6690b2",
"metadata": {},
"source": [
"bert 파라미터를 고정(freeze)하면 bert는 업데이트되지 않아 메모리를 아낄 수 있다. 단, 위 셀의 코드는 현재 주석 처리되어 있으므로 지금은 bert도 함께 학습된다."
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "f28a1a61",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"device(type='cuda')"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"device = torch.device(\"cuda\")\n",
"device"
]
},
{
"cell_type": "code",
2022-02-18 17:32:13 +09:00
"execution_count": 28,
2022-02-13 17:34:03 +09:00
"id": "8844beef",
"metadata": {},
"outputs": [],
"source": [
"inputs = {'input_ids': torch.tensor([[ 101, 39671, 8935, 73380, 30842, 9632, 125, 9998, 9251, 9559,\n",
" 9294, 8932, 28143, 9952, 8872, 127, 9489, 34907, 9952, 9279,\n",
" 12424, 102]],device=device), \n",
" 'token_type_ids': torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],device=device), \n",
" 'attention_mask': torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],device=device)}"
]
},
{
"cell_type": "markdown",
"id": "4046945a",
"metadata": {},
"source": [
"적당한 인풋을 정의한다"
]
},
{
"cell_type": "code",
2022-02-18 17:32:13 +09:00
"execution_count": 29,
2022-02-13 17:34:03 +09:00
"id": "c37c3c1b",
"metadata": {},
"outputs": [],
"source": [
"emb = model(**inputs)"
]
},
{
"cell_type": "code",
2022-02-18 17:32:13 +09:00
"execution_count": 30,
2022-02-13 17:34:03 +09:00
"id": "261d4cc7",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([1, 22, 22])"
]
},
2022-02-18 17:32:13 +09:00
"execution_count": 30,
2022-02-13 17:34:03 +09:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"emb.size()"
]
},
{
"cell_type": "markdown",
"id": "d492e37b",
"metadata": {},
"source": [
"결과가 잘 나왔어요."
]
},
{
"cell_type": "code",
2022-02-18 17:32:13 +09:00
"execution_count": 31,
2022-02-13 17:34:03 +09:00
"id": "c773fdba",
"metadata": {},
"outputs": [],
"source": [
"entity_ids = torch.tensor([21,21,7,17,17,21,21,21,21,21,21,21,21,21,21,21,21,21,21,21,21,21]\n",
" ,dtype=torch.int64,device=device)"
]
},
{
"cell_type": "markdown",
"id": "b0b69822",
"metadata": {},
"source": [
"`nn.CrossEntropyLoss`에 클래스 인덱스를 타깃으로 줄 때는 dtype이 int64(`torch.long`)여야 한다. int64면 실행이 되고 int32이면 실행이 안 된다.\n",
"```\n",
"RuntimeError: \"nll_loss_forward_reduce_cuda_kernel_2d_index\" not implemented for 'Int'\n",
"```\n",
"int32로 주면 이런 오류를 내면서 죽는다."
]
},
{
"cell_type": "code",
2022-02-18 17:32:13 +09:00
"execution_count": 32,
2022-02-13 17:34:03 +09:00
"id": "0f97b71a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([22, 22])"
]
},
2022-02-18 17:32:13 +09:00
"execution_count": 32,
2022-02-13 17:34:03 +09:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"emb.view(-1,emb.size(-1)).size()"
]
},
{
"cell_type": "code",
2022-02-18 17:32:13 +09:00
"execution_count": 36,
"id": "2a35055b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([19, 13, 21, 21, 19, 21, 13, 13, 16, 16, 21, 16, 19, 13, 2, 13, 13, 16,\n",
" 13, 20, 19, 19], device='cuda:0')"
]
},
"execution_count": 36,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"predict = emb.view(-1,emb.size(-1)).argmax(dim=-1)\n",
"predict"
]
},
{
"cell_type": "code",
"execution_count": 39,
"id": "778c99b7",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([False, False, False, False, False, True, False, False, False, False,\n",
" True, False, False, False, False, False, False, False, False, False,\n",
" False, False], device='cuda:0')"
]
},
"execution_count": 39,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"(predict == entity_ids)"
]
},
{
"cell_type": "code",
"execution_count": 38,
"id": "798091aa",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],\n",
" device='cuda:0')"
]
},
"execution_count": 38,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"(predict == entity_ids) * inputs[\"attention_mask\"]"
]
},
{
"cell_type": "code",
"execution_count": 19,
2022-02-13 17:34:03 +09:00
"id": "d7d0164a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
2022-02-18 17:32:13 +09:00
"tensor(2.9564, device='cuda:0', grad_fn=<NllLossBackward0>)"
2022-02-13 17:34:03 +09:00
]
},
2022-02-18 17:32:13 +09:00
"execution_count": 19,
2022-02-13 17:34:03 +09:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"nn.CrossEntropyLoss(ignore_index=tagIdConverter.pad_id)(emb.view(-1,emb.size(-1)),entity_ids)"
]
},
{
"cell_type": "markdown",
"id": "9fee44d6",
"metadata": {},
"source": [
"크로스 엔트로피를 계산하는 데에 성공.\n",
"ignore_index는 padding class를 넣어요."
]
},
{
"cell_type": "code",
2022-02-18 17:32:13 +09:00
"execution_count": 40,
2022-02-13 17:34:03 +09:00
"id": "197cfa62",
"metadata": {},
"outputs": [],
"source": [
"del inputs\n",
"del entity_ids\n",
"del emb"
]
},
{
"cell_type": "markdown",
"id": "733cb548",
"metadata": {},
"source": [
"본격적으로 학습시켜봅시다."
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "40a3e52c",
"metadata": {},
"outputs": [],
"source": [
"datasetTrain, datasetDev, datasetTest = readPreporcssedDataAll()"
]
},
{
"cell_type": "markdown",
"id": "66d41fee",
"metadata": {},
"source": [
"이 데이터 셋에서 적어도 어느 정도 성능이 나와야 할지 생각해봅시다.\n",
"`O` 토큰으로 범벅이 돼 있으니 전부 `O`로 찍는 것보다 좋은 성능이 나와야 하지 않겠어요?\n",
"한번 시도해봅시다."
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "80c37a04",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
2022-02-18 17:32:13 +09:00
"100%|██████████████████████████████████████████████████████████████████████████| 4250/4250 [00:00<00:00, 236146.99it/s]"
2022-02-13 17:34:03 +09:00
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"151572/190488 = 0.7957036663726849\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"total_l = 0\n",
"total_o = 0\n",
"\n",
"for item in tqdm(datasetTrain):\n",
" entities = item[\"entity\"]\n",
" l = len(entities)\n",
" o = sum(map(lambda x: 1 if x == \"O\" else 0,entities))\n",
" total_l += l\n",
" total_o += o\n",
"\n",
"print(f\"{total_o}/{total_l} = {total_o/total_l}\")"
]
},
{
"cell_type": "markdown",
"id": "c0b67c41",
"metadata": {},
"source": [
"79% 보다 높아야 해요."
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "619b959f",
"metadata": {},
"outputs": [],
"source": [
2022-02-18 17:32:13 +09:00
"BATCH_SIZE = 4\n",
2022-02-13 17:34:03 +09:00
"train_loader = DataLoader(\n",
" DatasetArray(datasetTrain),\n",
" batch_size=BATCH_SIZE,\n",
" shuffle=True,\n",
" collate_fn=my_collate_fn\n",
")\n",
"dev_loader = DataLoader(\n",
" DatasetArray(datasetDev),\n",
" batch_size=BATCH_SIZE,\n",
" shuffle=True,\n",
" collate_fn=my_collate_fn\n",
")\n",
"test_loader = DataLoader(\n",
" DatasetArray(datasetTest),\n",
" batch_size=BATCH_SIZE,\n",
" shuffle=True,\n",
" collate_fn=my_collate_fn\n",
")"
]
},
{
"cell_type": "markdown",
"id": "7d45dd29",
"metadata": {},
"source": [
2022-02-18 17:32:13 +09:00
"BATCH_SIZE를 4로 잡는다.\n",
"bert parameter를 freeze 안 했을 땐 batch를 8 정도로 했어요. 그 이상은 메모리가 부족해서 돌아가지 않아요.\n"
2022-02-13 17:34:03 +09:00
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "efd1837a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"({'input_ids': tensor([[ 101, 39671, 8935, 73380, 30842, 9632, 125, 9998, 9251, 9559,\n",
" 9294, 8932, 28143, 9952, 8872, 127, 9489, 34907, 9952, 9279,\n",
" 12424, 102]]),\n",
" 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]),\n",
" 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])},\n",
" tensor([[21, 21, 7, 17, 17, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21,\n",
" 21, 21, 21, 21]]))"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"my_collate_fn(datasetTrain[0:1])"
]
},
{
"cell_type": "markdown",
"id": "c4a2e2e1",
"metadata": {},
"source": [
"데이터를 한번 더 확인"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "c575cb55",
"metadata": {},
"outputs": [],
"source": [
"from torch.optim import AdamW"
]
},
{
"cell_type": "markdown",
"id": "627cb2f8",
"metadata": {},
"source": [
"옵티마이저로 쓸 `AdamW` 임포트"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "56844773",
"metadata": {},
"outputs": [],
"source": [
2022-02-18 17:32:13 +09:00
"optimizer = AdamW(model.parameters(), lr=1.0e-5)\n",
2022-02-13 17:34:03 +09:00
"CELoss = nn.CrossEntropyLoss(ignore_index=tagIdConverter.pad_id)"
]
},
{
"cell_type": "markdown",
"id": "eaa08ab2",
"metadata": {},
"source": [
"옵티마이져 준비"
]
},
{
"cell_type": "code",
"execution_count": 27,
2022-02-18 17:32:13 +09:00
"id": "78e46670",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[1, 2, 3, 4]\n",
"[5, 6, 7, 8]\n"
]
}
],
"source": [
"from groupby_index import groupby_index\n",
"\n",
"for g in groupby_index([1,2,3,4,5,6,7,8],4):\n",
" print([*g])"
]
},
{
"cell_type": "markdown",
"id": "ed61ce06",
"metadata": {},
"source": [
"`groupby_index` 그룹으로 묶어서 실행"
]
},
{
"cell_type": "code",
"execution_count": 43,
2022-02-13 17:34:03 +09:00
"id": "109259b4",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 0 start:\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
2022-02-18 17:32:13 +09:00
"Epoch 0: 100%|███████████████████████████████████████| 1063/1063 [00:45<00:00, 23.15batch/s, accuracy=0.923, loss=2.24]\n"
2022-02-13 17:34:03 +09:00
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 1 start:\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
2022-02-18 17:32:13 +09:00
"Epoch 1: 100%|███████████████████████████████████████| 1063/1063 [00:46<00:00, 23.07batch/s, accuracy=0.961, loss=1.52]\n"
2022-02-13 17:34:03 +09:00
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 2 start:\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
2022-02-18 17:32:13 +09:00
"Epoch 2: 100%|██████████████████████████████████████| 1063/1063 [00:46<00:00, 23.06batch/s, accuracy=0.976, loss=0.793]\n"
2022-02-13 17:34:03 +09:00
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 3 start:\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
2022-02-18 17:32:13 +09:00
"Epoch 3: 100%|███████████████████████████████████████| 1063/1063 [00:46<00:00, 23.07batch/s, accuracy=0.935, loss=1.88]\n"
2022-02-13 17:34:03 +09:00
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 4 start:\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
2022-02-18 17:32:13 +09:00
"Epoch 4: 100%|███████████████████████████████████████| 1063/1063 [00:46<00:00, 23.09batch/s, accuracy=0.98, loss=0.524]\n"
2022-02-13 17:34:03 +09:00
]
}
],
"source": [
"TRAIN_EPOCH = 5\n",
"\n",
"result = []\n",
"iteration = 0\n",
"\n",
2022-02-18 17:32:13 +09:00
"t = []\n",
"\n",
2022-02-13 17:34:03 +09:00
"model.zero_grad()\n",
"\n",
"for epoch in range(TRAIN_EPOCH):\n",
" model.train()\n",
" print(f\"epoch {epoch} start:\")\n",
2022-02-18 17:32:13 +09:00
" with tqdm(train_loader, unit=\"minibatch\") as tepoch:\n",
" tepoch.set_description(f\"Epoch {epoch}\")\n",
" \n",
" for batch in groupby_index(tepoch,8):\n",
" corrects = 0\n",
" totals = 0\n",
" losses = 0\n",
2022-02-13 17:34:03 +09:00
" \n",
" optimizer.zero_grad()\n",
2022-02-18 17:32:13 +09:00
" for mini_i,mini_l in batch:\n",
" batch_inputs = {k: v.cuda(device) for k, v in list(mini_i.items())}\n",
" batch_labels = mini_l.cuda(device)\n",
" attention_mask = batch_inputs[\"attention_mask\"]\n",
" \n",
" output = model(**batch_inputs)\n",
" loss = CELoss(output.view(-1, output.size(-1)), batch_labels.view(-1))\n",
" \n",
" prediction = output.view(-1, output.size(-1)).argmax(dim=-1)\n",
" corrects += ((prediction == batch_labels.view(-1)) * attention_mask.view(-1)).sum().item()\n",
" totals += attention_mask.view(-1).sum().item()\n",
" losses += loss.item()\n",
" loss.backward()\n",
2022-02-13 17:34:03 +09:00
"\n",
" optimizer.step()\n",
2022-02-18 17:32:13 +09:00
" accuracy = corrects / totals\n",
" result.append({\"iter\":iteration,\"loss\":losses,\"accuracy\":accuracy})\n",
" tepoch.set_postfix(loss=losses, accuracy= accuracy)\n",
2022-02-13 17:34:03 +09:00
" iteration += 1"
]
},
{
"cell_type": "code",
2022-02-18 17:32:13 +09:00
"execution_count": 44,
2022-02-13 17:34:03 +09:00
"id": "f0d9b2d7",
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline"
]
},
{
"cell_type": "code",
2022-02-18 17:32:13 +09:00
"execution_count": 45,
2022-02-13 17:34:03 +09:00
"id": "19ca6da1",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
2022-02-18 17:32:13 +09:00
"execution_count": 46,
2022-02-13 17:34:03 +09:00
"id": "0bee685c",
"metadata": {},
"outputs": [
{
"data": {
2022-02-18 17:32:13 +09:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYcAAAD7CAYAAACBiVhwAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8/fFQqAAAACXBIWXMAAAsTAAALEwEAmpwYAABMJElEQVR4nO2dd5gURfPHv3WJOw4k53ggKCAKimAEBUURxYygL4qgiIqv4VVB/b0ChlcwooIYUEGCKIiAgAIqqCBZMkeOR07HHQfc3e7W74/e3p3ZnY2XdqA+zzPPzvT09NTc7XZNVVdXEzNDEARBEIzElbQAgiAIQuwhykEQBEHwQ5SDIAiC4IcoB0EQBMEPUQ6CIAiCH6IcBEEQBD9CKgciqkNE84gonYjWE9HT7vJBRLSXiFa5t1uKXlxBEAShOKBQ8xyIqAaAGsz8DxGVBbACwB0AugI4yczvFrmUgiAIQrGSEKoCM+8HsN+9n01E6QBqRXOzuLg4TklJieZSQRCEc5ZTp04xMxfrMEBIy8FUmag+gD8BXATgOQA9AWQBWA7gP8x8PNj1qampnJOTE62sgiAI5yREdIqZU4vznmFrIiIqA+AHAM8wcxaAkQAaAmgBZVm8F+C6PkS0nIiWOxyOgkssCIIgFDlhWQ5ElAhgBoDZzPy+xfn6AGYw80XB2hHLQRAEIXJi0nIgIgLwJYB0o2JwD1Rr7gSwrvDFEwRBEEqCcNxKVwPoAaC9T9jq20S0lojWALgewLNFKaggCIJgDRF9RUSHiMjyJZ0UHxHRViJaQ0SXhmoznGilBQDI4tSs0CILgiAIxcBoAMMBfBPgfCcAjdxbG6gx4zbBGpQZ0oIgCDaHmf8EcCxIldsBfMOKxQDK+wwN+CHKQRAE4eynFoA9huMMhJivZgvlMGPzDAxZMKSkxRCEs5MzZwCJIgwPlwsIJyT/11+BTZsK884JekqAe+sT4fVWQwNBQ1VtoRx+2foL3v1bsnQIQpFw4YVAmTLRXbt0KXDypH95fn5whcMMbN4MHDkCLFgQ3b0BwOkEHn4Y+PZbYNUqICPD+l7vvQcQAVlZwOnTqiwSsrOBkSOBfv2AxETv9bm5wF9/+de/8Ub1dy08HMzcyrB9HuH1GQDqGI5rA9gX7AJbKIeEuAQ4XDKBToiA7Gz1RmwnJk1SHVyojuvAAWDo0Mg7uEDs2hX8fIUKwGOP+ZefPAm0aQN06uR/7pZbzArH5QI6dwZGjQI6dADuuw+44AKgShXg2mv9FclrrwF16nifccMGoFcv4K67gN27vfWmTwdGjwbuvx9o2VJds3SpOnf8uLp+9Gjg+edV2VtvAaVLA999p46ZgcaNgRdfBObP97bLDPToododMwa49FLgiSeUggCUvHv2AM89B7RtC6Snq3t06WJ+FiJg/Pjgf9/iYTqAB91RS1cAOOFOjRQYZi62rXTp0hwN/5n9H059MzWqawUbk5/PfO+9zEuXessmTmTeuDH0tQBzixZFJ1ukbN7M/Pff5jKXi/ngQeZTp5jHjFEyA8xvvsl85ox/GydOMB8+zNyxo6r3ww/Mp097z8+dy3zBBao9TU4Oc26utUzDhjGvWuW9ry+ZmcyPPeZ//tgx5uefV/8Xq2sPHvSWr1un5D5xwltmtS1YwPzGG8zvvMP8zz/e8u3bmT/7zFz3nntUuxs3Mvfr59/WhAnM+/er/fbtzedq1PC28f776m9gPD9vnpJ19Ojg8t5+e+Bz6en+ZY88EuzbERIAORykbwXwLVSminwoK6E3gL4A+rrPE4ARALYBWAugVbD22P1fjXnl0H9ufy71eqmorhVszLZt6itap463LFBH5ks49Q4cYH7ySea9e5n/+kt9Gunbl/nTTwNff+yY/zXMSqmNHGnulK3kefllVVahgnUn8/DDSgEwMzscqmOrVo25aVNvnSeeYL74Yubrr/eWzZ9v7vRbtVJtLFnC/O676njHDv/7MTN/+C
Fzo0bMb71l3emdOaP+ZgBz587ecy+9pBRXXp71s7zxRvDO1thJt2rl3a9b17/uPfeoz+Rk5nbt/M+XLcvcqVPw+9WrF905vdWpoz6bNfM/d+GF1tccPBj8+xiEUMqhKLZivVm0yuGV317huMFxUV0r2IDPPlOdmS/r1nl/WOvWqbJwOv2TJ731jG/gx48zz5yp3grvvJO5dWtV5+671WfDhuZ2dBvGTj49XXXaeXnMNWuypxN//nl13vjW/Prr6u174UJrubUFEGpjZl6/Pry6APPgwf5lRmUBMPfu7V9nwYLQbffp47UmdAept549mZcvD1/OYFv79qqTD1WvfPng5594wro8NbXgMt5/v/rfWP29rbapU4N/b4MgyiEAA+cNZAwCO13OqK4XfJg+PbCrIRCHD6u3YZdLHZ88yfzf/zKvWFEwWfSbJpHqWLdtU+VPP8383HPmH5fT6d1nVrI4HP5taosDYF60yN91EGzbulW9kX7yibfM+Man3RS//+5/rcvF/Ntv3uMePZivuspcx+jyMb7tB9uYzQom1FapUug6l18efnuRbFrhhto+/jjwuTp11P91wgRz+YAB1vUnTmR+8cXAf7v9+5k3bPCWNWgQ2TNdeCHz2LH+5f/5j2p/3jz/cxMmMGdleY/HjSvQz0SUQwDe+OMNxiBwniMvquvPaebNUx0Ws+qIZ8xQ//Znnw3veodD+bK1+b55syp/6SXvF3/dOtUxfvYZc3a2Ov/tt8wtW6oOnVm5Wnbv9rZ79Kj6tDLLjx+3/pG+8or5R3/ttcy1a6s3uE8+UfdwOpVvX9cbOtT/Dde49erl3S9bVv3gfev8+adq29jB+HZcgHq+JUu8x0a3i94GDWK+6CKlbC++OHTHVLeuslZmzoysQwu1JSeH/nsU1mb1P7ayUsqVU2MIO3Z4vyfG587NZa5YUe3fcou3PDNT1f3zT+YbbjC3afwe67IePcx1brhBvf0nJlrL36SJ9TjCu++qto0WLsCckKC+w8zm5y0AohwCMOSvIYxB4FN5p0JXPts5fly9CQXim2/Uv/XYMXVs/JF07eo9bthQvf2H4u23zV/8RYtUed++3rK5c5l//lntP/WU+b7aJ3/FFep43jzmF15Q+6H8wsG2UaP8yypWVAOyRh/+rbda+631pv9egHqjbNTIup7vW+v77/vXOf98s0Whn9lqe+AB6/Jy5azLfQdlly0zH1uNEQTq6Fq0CHze5WLu3z/6/4vVZuU+27rVv+zbb/2/f/v2mWXbtEl91/TLyauvmuuvXWtu04gu+/RT737lyl4r2ugeqlLFu9+0KfPOnf7yjh2rrjtwwFzesqX/Pf/5J/RvLQgloRxsE8oKQMJZ33lHhRXWCDLr/aOP1Ofmzf7nvv/eu79tmwo1XBcime7PP5uPMzJUqJ5xIlBcHLDfHRX355/me2/bpj4XL1af11+vnsOq7Uh45BH/smPH1MSj48e995oxwxz66MvFF3v38/OBLVus6w3xmYRp1ebWrcCUKd5j/cxWBApvTEuzLt+503zcqpX5uEMH6+v++1/zcYUKKiwzEERA5crmsqws7/cqGlq29C+rVs18/McfQNeu/vVSDVmqiVTY6c03A6VKqbIGDcz1zzsvsByLFwMTJwKNGnnLPv8cSEpS+8bQ22XLVOitpnRp//YqVlSflSqZy5OTgz+HXShOTRSt5TBs0TDGIPCxU8eiuv6s4Ngx/zeirCzmtm2ZJ09W4Y+XXcbcuLE6//vvqo6uv2uX9Vvdl19a32/oUOZ//Uv5W+++2xuhYrX9+ivzBx8EPp+ZWbA3zy1bwqtnfNsDzG6oQNvBg4HdCeFut90W/bVt25qPje4S41a1qv//33hsfMPW29tvq3pvvsn844+qbMoU6/BPQFlOzP7jAfn5qnzJEv/Q0H//23x8551mK+bdd1W4bXy8uZ7L5d3fsCHw9z4/3/zMmpwcJafvmJPRJdm2rXWbxt+CHuNiNlsUzM
xz5qj9Zs2Uu9T372UMTc7MVL9DQI0laXTdPXsCP2MYQNxK1gxfMpwxCHzwZPShYDGPHujNz1c/Zt/oHV+z1uViHj7
2022-02-13 17:34:03 +09:00
"text/plain": [
"<Figure size 432x288 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"iters = [item[\"iter\"] for item in result]\n",
"fig, ax1 = plt.subplots()\n",
"ax1.plot(iters,[item[\"loss\"] for item in result],'g')\n",
"ax2 = ax1.twinx()\n",
"ax2.plot(iters,[item[\"accuracy\"] for item in result],'r')\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "2bb67740",
"metadata": {},
"source": [
"학습 그래프입니다."
]
},
{
"cell_type": "code",
2022-02-18 17:32:13 +09:00
"execution_count": 47,
2022-02-13 17:34:03 +09:00
"id": "0defca72",
"metadata": {},
"outputs": [],
"source": [
"torch.cuda.empty_cache()"
]
},
{
"cell_type": "code",
2022-02-18 17:32:13 +09:00
"execution_count": 48,
2022-02-13 17:34:03 +09:00
"id": "2f45cae0",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2022-02-18 17:32:13 +09:00
"gpu allocated : 2739 MB\n",
"gpu reserved : 2910MB\n"
2022-02-13 17:34:03 +09:00
]
}
],
"source": [
"print(f\"gpu allocated : {torch.cuda.memory_allocated() // 1024**2} MB\")\n",
"print(f\"gpu reserved : {torch.cuda.memory_reserved() // 1024 ** 2}MB\")"
]
},
{
"cell_type": "markdown",
"id": "68f18f4d",
"metadata": {},
"source": [
"gpu 메모리 사용량을 보는 코드"
]
},
{
"cell_type": "code",
2022-02-18 17:32:13 +09:00
"execution_count": 49,
2022-02-13 17:34:03 +09:00
"id": "73b01630",
"metadata": {},
"outputs": [],
"source": [
"del batch_inputs\n",
"del batch_labels\n",
"del loss\n",
"del optimizer"
]
},
{
"cell_type": "markdown",
"id": "02a3f367",
"metadata": {},
"source": [
"한번 테스트 해보자"
]
},
{
"cell_type": "code",
2022-02-18 17:32:13 +09:00
"execution_count": 50,
2022-02-13 17:34:03 +09:00
"id": "af93a3ec",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"김광현 을 둘러싼 주변 의 우려 는 너무 많이 던진다는 것 .\n",
"두산 은 주포 인 김동주 와 안경현 이 빠진 상황 이 라 타력 에 적 지 않 은 문제점 을 안 고 있 지만 29 일 잠실 롯데전 이 우천 으로 취소 되 는 바람 에 시간 을 벌 었 고 더구나 삼성 이 4 연패 로 2 위 로 추락 해 어부지리 로 1 위 로 올라서 는 행운 을 잡 았 다 .\n"
]
}
],
"source": [
"for data in datasetTrain[100:102]:\n",
" print(tokenizer.convert_tokens_to_string(data[\"tokens\"]))"
]
},
{
"cell_type": "code",
2022-02-18 17:32:13 +09:00
"execution_count": 51,
2022-02-13 17:34:03 +09:00
"id": "9b2cd5c4",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'input_ids': tensor([[ 101, 8935, 118649, 30842, 9633, 9105, 30873, 119091, 9689,\n",
" 118985, 9637, 9604, 26737, 9043, 9004, 32537, 47058, 9076,\n",
" 65096, 11018, 8870, 119, 102, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0],\n",
" [ 101, 9102, 21386, 9632, 9689, 55530, 9640, 8935, 18778,\n",
" 16323, 9590, 9521, 31720, 30842, 9638, 9388, 18623, 9414,\n",
" 65649, 9638, 9157, 9845, 28143, 9559, 9664, 9706, 9523,\n",
" 9632, 9297, 17730, 34907, 9633, 9521, 8888, 9647, 9706,\n",
" 19105, 10386, 9641, 9655, 31503, 9208, 28911, 16617, 9638,\n",
" 9604, 38631, 29805, 9773, 22333, 9098, 9043, 9318, 61250,\n",
" 9559, 9485, 18784, 9633, 9339, 9557, 8888, 9074, 17196,\n",
" 16439, 9410, 17138, 9638, 125, 9568, 119383, 9202, 123,\n",
" 9619, 9202, 9765, 107693, 9960, 9546, 14646, 12508, 12692,\n",
" 9202, 122, 9619, 9202, 9583, 17342, 12424, 9043, 9966,\n",
" 21614, 9633, 9656, 9529, 9056, 119, 102]],\n",
" device='cuda:0'),\n",
" 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0],\n",
" [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
" 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
" 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
" 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
" 1]], device='cuda:0'),\n",
" 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0]], device='cuda:0')}"
]
},
2022-02-18 17:32:13 +09:00
"execution_count": 51,
2022-02-13 17:34:03 +09:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"inputs, labels = my_collate_fn(datasetTrain[100:102])\n",
"inputs = {k:v.to(device) for k,v in inputs.items()}\n",
"inputs"
]
},
{
"cell_type": "code",
2022-02-18 17:32:13 +09:00
"execution_count": 52,
2022-02-13 17:34:03 +09:00
"id": "6b31782c",
"metadata": {},
"outputs": [],
"source": [
"model.eval()\n",
"with torch.no_grad():\n",
" predict_label = model(**inputs)\n",
" sp = model.softmax(predict_label)\n",
" p = sp.argmax(dim=-1,keepdim=True).squeeze().cpu()"
]
},
{
"cell_type": "code",
2022-02-18 17:32:13 +09:00
"execution_count": 53,
2022-02-13 17:34:03 +09:00
"id": "7f4d43ce",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2022-02-18 17:32:13 +09:00
"['O', 'B-PS', 'I-PS', 'I-PS', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'I-PS', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']\n",
"['O', 'B-OG', 'I-OG', 'O', 'O', 'O', 'O', 'B-PS', 'I-PS', 'I-PS', 'O', 'B-PS', 'I-PS', 'I-PS', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-DT', 'I-DT', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-OG', 'I-OG', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']\n"
2022-02-13 17:34:03 +09:00
]
}
],
"source": [
"for data in p.numpy():\n",
" print(tagIdConverter.convert_ids_to_tokens(data))"
]
},
{
"cell_type": "code",
2022-02-18 17:32:13 +09:00
"execution_count": 54,
2022-02-13 17:34:03 +09:00
"id": "5ade3317",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['O', 'B-PS', 'I-PS', 'I-PS', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]']\n",
"['O', 'B-OG', 'I-OG', 'O', 'O', 'O', 'O', 'B-PS', 'I-PS', 'I-PS', 'O', 'B-PS', 'I-PS', 'I-PS', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-DT', 'I-DT', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-OG', 'I-OG', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']\n"
]
}
],
"source": [
"for data in labels.squeeze().numpy():\n",
" print(tagIdConverter.convert_ids_to_tokens(data))"
]
},
{
"cell_type": "markdown",
"id": "b48e22ff",
"metadata": {},
"source": [
"It works!"
]
},
{
"cell_type": "code",
2022-02-18 17:32:13 +09:00
"execution_count": 55,
2022-02-13 17:34:03 +09:00
"id": "383dd24a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([194, 1])"
]
},
2022-02-18 17:32:13 +09:00
"execution_count": 55,
2022-02-13 17:34:03 +09:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sp.cpu().view(-1,sp.size(-1)).argmax(dim=-1,keepdim=True).size()"
]
},
{
"cell_type": "code",
2022-02-18 17:32:13 +09:00
"execution_count": 56,
2022-02-13 17:34:03 +09:00
"id": "ff74fced",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
2022-02-18 17:32:13 +09:00
"tensor(120)"
2022-02-13 17:34:03 +09:00
]
},
2022-02-18 17:32:13 +09:00
"execution_count": 56,
2022-02-13 17:34:03 +09:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"correct = (sp.cpu().view(-1,sp.size(-1)).argmax(dim=-1) == labels.view(-1)).sum()\n",
"correct"
]
},
{
"cell_type": "code",
2022-02-18 17:32:13 +09:00
"execution_count": 57,
2022-02-13 17:34:03 +09:00
"id": "3f6ad5d8",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(120, device='cuda:0')"
]
},
2022-02-18 17:32:13 +09:00
"execution_count": 57,
2022-02-13 17:34:03 +09:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"inputs[\"attention_mask\"].view(-1).sum()"
]
},
{
"cell_type": "code",
2022-02-18 17:32:13 +09:00
"execution_count": 58,
2022-02-13 17:34:03 +09:00
"id": "986fd52b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
2022-02-18 17:32:13 +09:00
"tensor(1., device='cuda:0')"
2022-02-13 17:34:03 +09:00
]
},
2022-02-18 17:32:13 +09:00
"execution_count": 58,
2022-02-13 17:34:03 +09:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"accuracy = correct / inputs[\"attention_mask\"].view(-1).sum()\n",
"accuracy"
]
},
{
"cell_type": "markdown",
"id": "8cc47437",
"metadata": {},
"source": [
"accuracy는 위와 같이 구해져요."
]
},
{
"cell_type": "code",
2022-02-18 17:32:13 +09:00
"execution_count": 59,
2022-02-13 17:34:03 +09:00
"id": "1f3f8666",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
2022-02-18 17:32:13 +09:00
"100%|█████████████████████████████████████████████████████████████████████████████| 125/125 [00:01<00:00, 74.40batch/s]\n"
2022-02-13 17:34:03 +09:00
]
}
],
"source": [
"model.eval()\n",
"collect_list = []\n",
"# Evaluate on the test loader without tracking gradients; one summary dict is collected per batch.\n",
"with torch.no_grad():\n",
"    with tqdm(test_loader, unit=\"batch\") as tepoch:\n",
"        for batch_i,batch_l in tepoch:\n",
"            batch_inputs = {k: v.to(device) for k, v in batch_i.items()}\n",
"            batch_labels = batch_l.to(device)\n",
"            output = model(**batch_inputs)\n",
"            loss = CELoss(output.view(-1, output.size(-1)), batch_labels.view(-1))\n",
"\n",
"            # Only positions with attention_mask == 1 are real tokens. Padded positions must not\n",
"            # count as correct, otherwise the ratio below could exceed 1.\n",
"            token_mask = batch_inputs[\"attention_mask\"].view(-1).bool()\n",
"            prediction = output.view(-1, output.size(-1)).argmax(dim=-1)\n",
"            correct = ((prediction == batch_labels.view(-1)) & token_mask).sum().item()\n",
"            # .item() keeps the stored accuracy a plain Python float instead of a 0-dim tensor on `device`.\n",
"            accuracy = correct / token_mask.sum().item()\n",
"\n",
"            collect_list.append({\"loss\":loss.item(),\"accuracy\":accuracy, \"batch_size\":batch_labels.size(0),\n",
"                                 \"predict\":output.argmax(dim=-1).cpu(),\n",
"                                 \"actual\":batch_labels.cpu(),\n",
"                                 \"attention_mask\":batch_inputs[\"attention_mask\"].cpu()})"
]
},
{
"cell_type": "code",
2022-02-18 17:32:13 +09:00
"execution_count": 60,
2022-02-13 17:34:03 +09:00
"id": "b7567f48",
"metadata": {},
"outputs": [],
"source": [
"def getConfusionMatrix(predict,actual,attention_mask,num_classes=22):\n",
"    \"\"\"Count (predicted, actual) label pairs over the positions selected by attention_mask.\n",
"\n",
"    Args:\n",
"        predict: predicted class ids, one row per sequence.\n",
"        actual: gold class ids with the same number of elements as predict.\n",
"        attention_mask: per-position value added to the matching cell (0 skips the position).\n",
"        num_classes: size of the square matrix; defaults to 22, the value previously hard-coded.\n",
"\n",
"    Returns:\n",
"        A (num_classes, num_classes) torch.long tensor where ret[p, a] accumulates the mask\n",
"        values of positions predicted as p whose actual label is a.\n",
"    \"\"\"\n",
"    ret = torch.zeros((num_classes,num_classes),dtype=torch.long)\n",
"    rows = torch.as_tensor(predict).reshape(-1).to(device=ret.device,dtype=torch.long)\n",
"    cols = torch.as_tensor(actual).reshape(-1).to(device=ret.device,dtype=torch.long)\n",
"    weights = torch.as_tensor(attention_mask).reshape(-1).to(device=ret.device,dtype=ret.dtype)\n",
"    # accumulate=True sums repeated (row, col) pairs, replacing the per-token Python loop.\n",
"    ret.index_put_((rows,cols),weights,accumulate=True)\n",
"    return ret"
]
},
{
"cell_type": "markdown",
"id": "a898cd34",
"metadata": {},
"source": [
"단순하게 confusion matrix를 계산하는 함수. 클래스 22개 정해져 있음."
]
},
{
"cell_type": "code",
2022-02-18 17:32:13 +09:00
"execution_count": 61,
2022-02-13 17:34:03 +09:00
"id": "15cd73a5",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])"
]
},
2022-02-18 17:32:13 +09:00
"execution_count": 61,
2022-02-13 17:34:03 +09:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"getConfusionMatrix(torch.tensor([[0,1]]),torch.tensor([[1,1]]),torch.tensor([[1,1]]))"
]
},
{
"cell_type": "code",
2022-02-18 17:32:13 +09:00
"execution_count": 62,
2022-02-13 17:34:03 +09:00
"id": "de9c7932",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2022-02-18 17:32:13 +09:00
"average_loss : 0.16166621172241866, average_accuracy : 0.9605389833450317, size :500\n"
2022-02-13 17:34:03 +09:00
]
}
],
"source": [
"total_loss = 0\n",
"total_accuracy = 0\n",
"total_size = 0\n",
"confusion = torch.zeros((22,22),dtype=torch.long)\n",
"\n",
"for item in collect_list:\n",
" batch_size = item[\"batch_size\"]\n",
" total_loss += batch_size * item[\"loss\"]\n",
" total_accuracy += batch_size * item[\"accuracy\"]\n",
" total_size += batch_size\n",
" confusion += getConfusionMatrix(item[\"predict\"],item[\"actual\"],item[\"attention_mask\"])\n",
"print(f\"\"\"average_loss : {total_loss/total_size}, average_accuracy : {total_accuracy/total_size}, size :{total_size}\"\"\")"
]
},
{
"cell_type": "markdown",
"id": "24539b98",
"metadata": {},
"source": [
2022-02-18 17:32:13 +09:00
"test 데이터셋으로 평가한 결과, 정확도는 약 96%로 나왔어요. F1 스코어는 아직 계산하지 않았습니다."
2022-02-13 17:34:03 +09:00
]
},
{
"cell_type": "code",
2022-02-18 17:32:13 +09:00
"execution_count": 73,
2022-02-13 17:34:03 +09:00
"id": "f6047991",
"metadata": {},
2022-02-18 17:32:13 +09:00
"outputs": [],
"source": [
"# Drop class index 21 from both the predicted (row) and the actual (column) axis.\n",
"# confusion[0:21][0:21] would slice the rows twice and leave column 21 in place.\n",
"confusion = confusion[0:21, 0:21]"
]
},
{
"cell_type": "markdown",
"id": "4830938c",
"metadata": {},
2022-02-13 17:34:03 +09:00
"source": [
2022-02-18 17:32:13 +09:00
"Outside 토큰에 해당하는 곳을 자르겠습니다."
2022-02-13 17:34:03 +09:00
]
},
{
"cell_type": "code",
2022-02-18 17:32:13 +09:00
"execution_count": 74,
2022-02-13 17:34:03 +09:00
"id": "000d1e68",
"metadata": {},
"outputs": [
{
"data": {
2022-02-18 17:32:13 +09:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAUsAAAEICAYAAADWe9ZcAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8/fFQqAAAACXBIWXMAAAsTAAALEwEAmpwYAABNf0lEQVR4nO2dd3wVVfrGv28SivQOSSBACqRXerOggKhrRUAFV7Htqj91XV1su9ZVV9e1YFksawexIwLSBEGUKiC9Sa8JCRASUt/fH3cSLyFlktybO3eYh8/53CnvnOe8c29eZs57nnNEVXHgwIEDB5UjwNcNcODAgQN/gBMsHThw4MAEnGDpwIEDBybgBEsHDhw4MAEnWDpw4MCBCTjB0oEDBw5MwAmWZwjEhf+JSKaILK1FPQNFZJMn2+YriEiYiGSLSKCv2+LA+hBnnOWZAREZCEwCuqvqCV+3x9sQkR3ATao6x9dtcWAPOE+WZw46AzvOhEBpBiIS5Os2OPAvOMHSghCRTiLyhYgcFpEMEZlgHA8QkYdFZKeIHBKR90WkuXGui4ioiFwvIrtEJF1EHjLOjQPeAvoar52PicgfRWRRGV4VkUhje7iIrBeR4yKyV0T+ahw/R0T2uF0TIyLzRSRLRNaJyB/czr0rIq+KyLdGPUtEJKICn0vaf4OI7Da6C24TkZ4issaof4KbfYSIzDPuT7qIfCQiLYxzHwBhwDeGv/e71T9ORHYB89yOBYlIKxHZIyKXGHU0EZGtIjK2tt+nA5tAVZ1ioQIEAquB/wCNgYbAAOPcjcBWIBxoAnwBfGCc6wIo8CZwFpAE5AExxvk/AovceE7ZN44pEGls7wcGGtstgVRj+xxgj7Fdz2jPg0B94DzgOK5XfYB3gSNALyAI+AiYXIHfJe1/w/B5CHAS+ApoB4QCh4CzDftI4AKgAdAW+AF40a2+HcD55dT/vnFfz3I7FmTYDAEOGHxvAp/5+vfgFOsU58nSeugFhAD3qeoJVT2pqiVPgNcCL6jqdlXNBh4ARpV5pXxMVXNVdTWuoJtUw3YUALEi0kxVM1V1ZTk2fXAF7WdUNV9V5wHTgNFuNl+o6lJVLcQVLJOr4H3C8HkWcAKYpKqHVHUvsBBIAVDVrao6W1XzVPUw8AJwtgm/HjXua27ZEwbnp8Bc4CLgVhP1OThD4ARL66ETsNMILmURAux029+J64mtvduxA27bObiCWU1wJTAc2CkiC0SkbwXt2a2qxWXaFFqL9hx0284tZ78JgIi0E5HJRhfBMeBDoE0VdQPsruL8RCAe+J+qZpioz8EZAidYWg+7gbAKEhD7cCVqShAGFHJqQDGLE0Cjkh0R6eB+UlWXqeqluF5JvwKmVNCeTiLi/jsKA/bWoD3VxdO4XqETVbUZcB0gbucrGuZR4fAPYwjRf3G9qv+ppP/WgQNwgqUVsRRXf+EzItJYRBqKSH/j3CTgHhHpKiJNgH8Cn1TwFFoVVgNxIpIsIg2BR0tOiEh9EblWRJqragFwDCgqp44luILu/SJST0TOAS4BJtegPdVFUyAbyBKRUOC+MucP4urbrQ4eND5vBJ4H3nfGYDoogRMsLQZVLcIVcCKBXcAeYKRx+h3gA1zJjN9wJUDurCHPZuBxYA6wBVhUxmQMsMN4xb0N15Nb2TrygT8AFwLpwGvAWFXdWJM2VROPAanAUeBbXMkudzwNPGxk0f9aVWUikgb8BVf7i4BncT2Fjvdoqx34LZxB6Q4cOHBgAs6TpQMHDhyYgBMsHThw4MAELB8sRWSYiGwy1BSV9h+ZtfW0ncNtD267+eMN7jMavh4VX1nBpWbZhiurWR9XBje2NraetnO47cFtN3+8wX2mF0smeNq0aaOdO3chOzub/fv3ExUVBcCBA/sB6NAh+LRrzNp62s7htge33fzxFPfOnTtIT0+X0y6qBgKbdVYtPE0wVS409/B3qjqsNnxeg6+jdXklNTVNcwtUP5r8qf
7xhnGaW6CaW6D69v/e11v/dHvpvnsxa+tpO4fbHtx288dT3KmpaVrrJ7Kz2mnDlDtNFWC5r+NPRaVWfZZV9XOICy8b59eISGp1A3k5ddbK1tN2Drc9uL1Rp924awwBRMwVC6PGwdJQNryKa0ByLDBaRGLLmF0IRBnlFuD16nCEhnZkz57fpbx79+4hJCSkVraetnO47cFtN3+8wV0rSIC5YmXU9JEU6At857b/APBAGZv/AqPd9jcBwWZfw4/nFmiXrl11w+btevREniYkJOqKVWvLfZUwa+tpO4fbHtx288dT3B55DW/UThv2uMdUwcKv4bWZLTqUU2dw2QP0NmETikv7XCWCgoL4z0sTuOSioRQVFXH9H28kNi6uVraetnO47cFtN3+8wV1zCAT4v8S+xtlwERkBDFXVm4z9MUAvVb3TzeZb4Gk15mMUkbnA/aq6opz6bsH1qk6nsLC0zdt2ljVx4MBBHaN/7x6sWLG8Vp2JAU06aIP4603ZnlzyrxWq2qM2fN5CbToJ9uCae7EEHXFN2VVdGwBUdaKq9lDVHm3btK1Fsxw4cGAtmEzu2DXBAywDosQ1XVh9YBQwtYzNVGCskRXvAxxVVVOv4A4cOLARbJDgqXHr1DWH4h3Ad8AGYIqqrhPXIlO3GWbTge241ml5E/hzdXlmfTeTxLjuxEVH8ty/nvGIraftHG57cNvNH29w1xg2eLL0eYapsmx49slC7Roerus3bSvN0q1cva7cjJ5ZW0/bOdz24LabP57i9kg2vHGwNuz/sKmChbPhln7uXbZ0KRERkXQND6d+/fqMGDmKad98XStbT9s53Pbgtps/3uCuMQRXNtxMsTAsHSz37dtLx46/54dCQzuyd2/5y7uYtfW0ncNtD267+eMN7ppDzuw+y7pAecOa/Fku5nBbl9sbddqNu1YIEHOlCojIOyJySETWuh1rJSKzRWSL8dnS7dwDhtx6k4gMdTueJiK/GudeFhMOWzpY2k0u5nBbl9tu/lhK7ih48snyXaDsrETjgbmqGoVrzffxAIb8ehQQZ1zzmvy+AN3ruMZ1l8ixq57pyNedppUleOwgF3O4/YPbbv5YSu7YNEQbnveUqYKJBA/QBVjrtl8qowaCgU3G9ikSbFwjd/oaNhvdjo8G/lsVb23kjl6H3eRiDrd1ue3mjx/LHduIyHK3/YmqOrGKa9qrMX5bVfeLSDvjeCjws5tdidy6wNgue7xSWHLy37S0HvrjkuVVGzpw4MCr8IjcsVlHbdDnLlO2J2ffX6XcUUS6ANNUNd7Yz1LVFm7nM1W1pYi8Cvykqh8ax9/GNfZ7Fy4Z9vnG8YG4ZNiXVOqHKQ8cOHDgoKYwOyC95kmlgyIS7KKSYOCQcbwiufUeY7vs8Uph+WBpNwWEw21dbrv5Yy0Fj1eHDk0FSmbquB742u34KBFpICJdcSVylhqv7MdFpI+RBR/rdk3F8HUyp7IEjx0UEA63f3DbzR9LKXiaddSGw14wVagiwQNMwjXFY0m/4zigNa4s+Bbjs5Wb/UO4FmPbBFzodrwHsNY4NwGjS7KyYuknS7spIBxu63LbzR9LKXg8OChdVUerarCq1lPVjqr6tqpmqOpgVY0yPo+42T+lqhGq2l1VZ7gdX66q8ca5O9RE8sbSwdJuCgiH27rcdvPHUgoem8gdLT10qLxg788KCIfbutzeqNNu3DWHWF7KaAaWDpZ2U0A43Nbltps/llLwgPWnXzMDXydzKkvw2EEB4XD7B7fd/LGUgqd5mDb8wxumChaeos3ST5Z2U0A43Nbltps/1lLwYIsnS0fB48CBgwrhEQVPyy7a4JyHTdme/Opm+y1YJiKdROR7EdkgIutE5DQ9k4icIyJHRWSVUf5eu+Y6cODAHyEBAaaKlVGb1hUC96pqDNAHuN2YEqksFqpqslEery6J3RQQDrd1ue3mj1UUPIIru26mWBqe6vzEJRe6oMyxc3AJ3muU4LGDAsLh9g9uu/ljJQ
VPQMvO2uiqd0wVLJzg8chzrzELSAqwpJzTfUVktYjMEJEKe41F5BYRWS4iyw+nHwbsp4BwuK3LbTd/rKbgscOTZa2
2022-02-13 17:34:03 +09:00
"text/plain": [
"<Figure size 432x288 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"import itertools\n",
"\n",
"plt.title(\"confusion matrix\")\n",
2022-02-18 17:32:13 +09:00
"plt.imshow(confusion,cmap='Blues')\n",
2022-02-13 17:34:03 +09:00
"\n",
"plt.colorbar()\n",
"# Write every cell's count at its (column, row) position on the heat map.\n",
"for i in range(confusion.shape[0]):\n",
"    for j in range(confusion.shape[1]):\n",
"        plt.text(j,i,\"{:}\".format(confusion[i,j]),horizontalalignment=\"center\",color=\"black\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "543ede4b",
"metadata": {},
"source": [
"혼동행렬을 보면 별로인 것 같습니다.\n",
"왼쪽(세로축)이 predict이고 아래(가로축)가 actual입니다."
]
},
{
"cell_type": "markdown",
"id": "9e85493d",
"metadata": {},
"source": [
"$$Precision(class = a) = \\frac{TP(class = a)}{TP(class = a)+FP(class = a)}$$\n",
"\n",
"$$Recall(class = a) = \\frac{TP(class = a)}{TP(class = a)+FN(class = a)}$$\n",
"\n",
"$$F1Score(class = a) = \\frac{2}{\\frac{1}{Precision(class = a)}+\\frac{1}{Recall(class = a)} }$$"
]
},
{
"cell_type": "markdown",
"id": "200c95b9",
"metadata": {},
"source": [
"F1Score는 위와 같이 주어집니다."
]
},
{
"cell_type": "markdown",
"id": "0b23e7d5",
"metadata": {},
"source": [
"다른 클래스에 대해서도 모두 해보자"
]
},
{
"cell_type": "code",
2022-02-18 17:32:13 +09:00
"execution_count": 75,
2022-02-13 17:34:03 +09:00
"id": "38b3eee6",
"metadata": {},
"outputs": [],
"source": [
"def getF1Score(confusion,c):\n",
"    \"\"\"Return the F1 score of class c from a square confusion matrix of counts.\n",
"\n",
"    Row c is summed as everything predicted as c and column c as everything actually\n",
"    labelled c, so precision and recall are the diagonal entry divided by the row and\n",
"    column totals. For a tensor input the result is nan when a denominator is zero.\n",
"    \"\"\"\n",
"    true_pos = confusion[c,c]\n",
"    predicted_total = confusion[c].sum()    # TP + FP\n",
"    actual_total = confusion[:,c].sum()     # TP + FN\n",
"    precision = true_pos / predicted_total\n",
"    recall = true_pos / actual_total\n",
"    return (2*precision*recall)/(precision + recall)"
]
},
{
"cell_type": "code",
2022-02-18 17:32:13 +09:00
"execution_count": 77,
2022-02-13 17:34:03 +09:00
"id": "61fe2d6c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"class 0 f1 score : nan\n",
"class 1 f1 score : nan\n",
"class 2 f1 score : nan\n",
"class 3 f1 score : nan\n",
2022-02-18 17:32:13 +09:00
"class 4 f1 score : 0.9583332538604736\n",
"class 5 f1 score : 0.9216590523719788\n",
"class 6 f1 score : 0.9232480525970459\n",
"class 7 f1 score : 0.9203747510910034\n",
"class 8 f1 score : 0.8780487179756165\n",
2022-02-13 17:34:03 +09:00
"class 9 f1 score : nan\n",
"class 10 f1 score : nan\n",
"class 11 f1 score : nan\n",
"class 12 f1 score : nan\n",
"class 13 f1 score : nan\n",
2022-02-18 17:32:13 +09:00
"class 14 f1 score : 0.9240506887435913\n",
"class 15 f1 score : 0.7439999580383301\n",
"class 16 f1 score : 0.885114848613739\n",
"class 17 f1 score : 0.9293712377548218\n",
"class 18 f1 score : 0.932692289352417\n",
2022-02-13 17:34:03 +09:00
"class 19 f1 score : nan\n",
2022-02-18 17:32:13 +09:00
"class 20 f1 score : nan\n"
2022-02-13 17:34:03 +09:00
]
}
],
"source": [
2022-02-18 17:32:13 +09:00
"for i in range(21):\n",
2022-02-13 17:34:03 +09:00
" f1 = getF1Score(confusion,i)\n",
" print(f\"class {i} f1 score : {f1}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0b9b55e7",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.11"
}
},
"nbformat": 4,
"nbformat_minor": 5
}