ner-study/Training.ipynb

1914 lines
113 KiB
Plaintext
Raw Normal View History

2022-02-13 17:34:03 +09:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "1cc9cec9",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'C:\\\\Users\\\\Monoid\\\\anaconda3\\\\envs\\\\nn\\\\python.exe'"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import sys\n",
"sys.executable"
]
},
{
"cell_type": "markdown",
"id": "f9c786f1",
"metadata": {},
"source": [
"파이썬 환경 확인.\n",
"envs\\\\nn\\\\python.exe 으로 끝나기를 기대합니다"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "4c0a08d7",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 9.21it/s]\n"
]
}
],
"source": [
"from time import sleep\n",
"from tqdm import tqdm, trange\n",
"\n",
"lst = [i for i in range(5)]\n",
"\n",
"for element in tqdm(lst):\n",
" sleep(0.1)"
]
},
{
"cell_type": "markdown",
"id": "56a66a44",
"metadata": {},
"source": [
"### tqdm 소개\n",
"tqdm은 다음과 같이 Progress bar 그려주는 라이브러리이에요. 와! 편하다.\n",
"```\n",
"from tqdm.auto import tqdm\n",
"from tqdm.notebook import tqdm\n",
"```\n",
"이건 0에서 멈춰이있고 작동하지 않더라고요. 왜인진 몰라요."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "d00b25e6",
"metadata": {},
"outputs": [],
"source": [
"from preprocessing import readPreporcssedDataAll\n",
"import torch\n",
"from torch.utils.data import Dataset, DataLoader\n",
"from dataset import make_collate_fn, DatasetArray\n",
"from transformers import BertTokenizer\n",
"import torch.nn as nn\n",
"from read_data import TagIdConverter"
]
},
{
"cell_type": "markdown",
"id": "a1d7bb15",
"metadata": {},
"source": [
"대충 필요한 것 임포트"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "433419f4",
"metadata": {},
"outputs": [],
"source": [
"tagIdConverter = TagIdConverter()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "79fb54df",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"cuda available : True\n",
"available device count : 1\n",
"device name: NVIDIA GeForce RTX 3070\n"
]
}
],
"source": [
"print(\"cuda available :\",torch.cuda.is_available())\n",
"print(\"available device count :\",torch.cuda.device_count())\n",
"\n",
"if torch.cuda.is_available():\n",
" device_index = torch.cuda.current_device()\n",
" print(\"device name:\",torch.cuda.get_device_name(device_index))"
]
},
{
"cell_type": "markdown",
"id": "dd22dd5e",
"metadata": {},
"source": [
"cuda가 가능한지 먼저 확인해보아요."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "2cd9fe37",
"metadata": {},
"outputs": [],
"source": [
"PRETAINED_MODEL_NAME = 'bert-base-multilingual-cased'\n",
"tokenizer = BertTokenizer.from_pretrained(PRETAINED_MODEL_NAME)\n",
"\n",
"my_collate_fn = make_collate_fn(tokenizer)"
]
},
{
"cell_type": "markdown",
"id": "2177c793",
"metadata": {},
"source": [
"데이터 로딩 준비"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "e738062d",
"metadata": {},
"outputs": [],
"source": [
"from transformers import BertModel"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "70296ee3",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Some weights of the model checkpoint at bert-base-multilingual-cased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']\n",
"- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
"- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
]
}
],
"source": [
"PRETAINED_MODEL_NAME = 'bert-base-multilingual-cased'\n",
"bert = BertModel.from_pretrained(PRETAINED_MODEL_NAME)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "6cbc236d",
"metadata": {},
"outputs": [],
"source": [
"class MyModel(nn.Module):\n",
" def __init__(self,output_feat: int,bert):\n",
" super().__init__()\n",
" self.bert = bert\n",
" self.dropout = nn.Dropout(p=0.1)\n",
" self.lin = nn.Linear(768,output_feat) #[batch_size,word_size,768] -> [batch_size,word_size,output_feat]\n",
" self.softmax = nn.Softmax(2) #[batch_size,word_size,output_feat] -> [batch_size,word_size,output_feat]\n",
" #0부터 시작해서 2 번째 차원에 softmax.\n",
"\n",
" def forward(self,**kargs):\n",
" emb = self.bert(**kargs)\n",
" e = self.dropout(emb['last_hidden_state'])\n",
" w = self.lin(e)\n",
" return w"
]
},
{
"cell_type": "markdown",
"id": "7dec337a",
"metadata": {},
"source": [
"모델 선언\n",
"`nn.CrossEntropy`는 소프트맥스를 겸함."
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "e5214de8",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"MyModel(\n",
" (bert): BertModel(\n",
" (embeddings): BertEmbeddings(\n",
" (word_embeddings): Embedding(119547, 768, padding_idx=0)\n",
" (position_embeddings): Embedding(512, 768)\n",
" (token_type_embeddings): Embedding(2, 768)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (encoder): BertEncoder(\n",
" (layer): ModuleList(\n",
" (0): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (1): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (2): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (3): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (4): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (5): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (6): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (7): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (8): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (9): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (10): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (11): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" )\n",
" )\n",
" (pooler): BertPooler(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (activation): Tanh()\n",
" )\n",
" )\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (lin): Linear(in_features=768, out_features=22, bias=True)\n",
" (softmax): Softmax(dim=2)\n",
")\n"
]
}
],
"source": [
"model = MyModel(22,bert)\n",
"model.cuda()\n",
"bert.cuda()\n",
"print(model)"
]
},
{
"cell_type": "markdown",
"id": "72df7ac3",
"metadata": {},
"source": [
"Tag의 종류가 22가지 입니다. 그래서 22을 넣었어요."
]
},
{
"cell_type": "markdown",
"id": "0bba477c",
"metadata": {},
"source": [
"생성과 동시에 gpu로 옮기자.\n",
"`cuda` 저거 호출하면 됨."
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "5aa129e3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"bert current device : cuda:0\n"
]
}
],
"source": [
"print(\"bert current device :\",bert.device)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "61d356b6",
"metadata": {},
"outputs": [],
"source": [
"for param in bert.parameters():\n",
" param.requires_grad = False"
]
},
{
"cell_type": "markdown",
"id": "cf6690b2",
"metadata": {},
"source": [
"bert 는 업데이트 하지 않는다. 메모리를 아낄 수 있다."
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "f28a1a61",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"device(type='cuda')"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"device = torch.device(\"cuda\")\n",
"device"
]
},
{
"cell_type": "code",
"execution_count": 61,
"id": "8844beef",
"metadata": {},
"outputs": [],
"source": [
"inputs = {'input_ids': torch.tensor([[ 101, 39671, 8935, 73380, 30842, 9632, 125, 9998, 9251, 9559,\n",
" 9294, 8932, 28143, 9952, 8872, 127, 9489, 34907, 9952, 9279,\n",
" 12424, 102]],device=device), \n",
" 'token_type_ids': torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],device=device), \n",
" 'attention_mask': torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],device=device)}"
]
},
{
"cell_type": "markdown",
"id": "4046945a",
"metadata": {},
"source": [
"적당한 인풋을 정의한다"
]
},
{
"cell_type": "code",
"execution_count": 62,
"id": "c37c3c1b",
"metadata": {},
"outputs": [],
"source": [
"emb = model(**inputs)"
]
},
{
"cell_type": "code",
"execution_count": 63,
"id": "261d4cc7",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([1, 22, 22])"
]
},
"execution_count": 63,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"emb.size()"
]
},
{
"cell_type": "markdown",
"id": "d492e37b",
"metadata": {},
"source": [
"결과가 잘 나왔어요."
]
},
{
"cell_type": "code",
"execution_count": 66,
"id": "c773fdba",
"metadata": {},
"outputs": [],
"source": [
"entity_ids = torch.tensor([21,21,7,17,17,21,21,21,21,21,21,21,21,21,21,21,21,21,21,21,21,21]\n",
" ,dtype=torch.int64,device=device)"
]
},
{
"cell_type": "markdown",
"id": "b0b69822",
"metadata": {},
"source": [
"잘 이해하지는 못하겠는데, int64면 실행이 되고 int32이면 실행이 안된다.\n",
"```\n",
"RuntimeError: \"nll_loss_forward_reduce_cuda_kernel_2d_index\" not implemented for 'Int'\n",
"```\n",
"이런 오류를 내면서 죽음."
]
},
{
"cell_type": "code",
"execution_count": 67,
"id": "0f97b71a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([22, 22])"
]
},
"execution_count": 67,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"emb.view(-1,emb.size(-1)).size()"
]
},
{
"cell_type": "code",
"execution_count": 68,
"id": "d7d0164a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(0.4138, device='cuda:0', grad_fn=<NllLossBackward0>)"
]
},
"execution_count": 68,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"nn.CrossEntropyLoss(ignore_index=tagIdConverter.pad_id)(emb.view(-1,emb.size(-1)),entity_ids)"
]
},
{
"cell_type": "markdown",
"id": "9fee44d6",
"metadata": {},
"source": [
"크로스 엔트로피를 계산하는 데에 성공.\n",
"ignore_index는 padding class를 넣어요."
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "197cfa62",
"metadata": {},
"outputs": [],
"source": [
"del inputs\n",
"del entity_ids\n",
"del emb"
]
},
{
"cell_type": "markdown",
"id": "733cb548",
"metadata": {},
"source": [
"본격적으로 학습시켜봅시다."
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "40a3e52c",
"metadata": {},
"outputs": [],
"source": [
"datasetTrain, datasetDev, datasetTest = readPreporcssedDataAll()"
]
},
{
"cell_type": "markdown",
"id": "66d41fee",
"metadata": {},
"source": [
"데이터 셋이 적어도 어느정도 성능이 나와하야 할지 생각해봅시다.\n",
"`O` 토큰으로 범벅이 되있으니 전부 `O`로 찍는 것 보다 좋은 성능이 나와야 하지 않겠어요?\n",
"한번 시도해봅시다."
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "80c37a04",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████████████████████████████████████████████████████████████████████| 4250/4250 [00:00<00:00, 265652.17it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"151572/190488 = 0.7957036663726849\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"total_l = 0\n",
"total_o = 0\n",
"\n",
"for item in tqdm(datasetTrain):\n",
" entities = item[\"entity\"]\n",
" l = len(entities)\n",
" o = sum(map(lambda x: 1 if x == \"O\" else 0,entities))\n",
" total_l += l\n",
" total_o += o\n",
"\n",
"print(f\"{total_o}/{total_l} = {total_o/total_l}\")"
]
},
{
"cell_type": "markdown",
"id": "c0b67c41",
"metadata": {},
"source": [
"79% 보다 높아야 해요."
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "619b959f",
"metadata": {},
"outputs": [],
"source": [
"BATCH_SIZE = 32\n",
"train_loader = DataLoader(\n",
" DatasetArray(datasetTrain),\n",
" batch_size=BATCH_SIZE,\n",
" shuffle=True,\n",
" collate_fn=my_collate_fn\n",
")\n",
"dev_loader = DataLoader(\n",
" DatasetArray(datasetDev),\n",
" batch_size=BATCH_SIZE,\n",
" shuffle=True,\n",
" collate_fn=my_collate_fn\n",
")\n",
"test_loader = DataLoader(\n",
" DatasetArray(datasetTest),\n",
" batch_size=BATCH_SIZE,\n",
" shuffle=True,\n",
" collate_fn=my_collate_fn\n",
")"
]
},
{
"cell_type": "markdown",
"id": "7d45dd29",
"metadata": {},
"source": [
"bert paramter를 freeze 안했을땐 batch를 8 정도로 했어요. 그 이상은 메모리가 부족해서 돌아가지 않아요."
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "efd1837a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"({'input_ids': tensor([[ 101, 39671, 8935, 73380, 30842, 9632, 125, 9998, 9251, 9559,\n",
" 9294, 8932, 28143, 9952, 8872, 127, 9489, 34907, 9952, 9279,\n",
" 12424, 102]]),\n",
" 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]),\n",
" 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])},\n",
" tensor([[21, 21, 7, 17, 17, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21,\n",
" 21, 21, 21, 21]]))"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"my_collate_fn(datasetTrain[0:1])"
]
},
{
"cell_type": "markdown",
"id": "c4a2e2e1",
"metadata": {},
"source": [
"데이터를 한번 더 확인"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "c575cb55",
"metadata": {},
"outputs": [],
"source": [
"from torch.optim import AdamW"
]
},
{
"cell_type": "markdown",
"id": "627cb2f8",
"metadata": {},
"source": [
"tqdm 확인"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "56844773",
"metadata": {},
"outputs": [],
"source": [
"optimizer = AdamW(model.parameters(), lr=5.0e-5)\n",
"CELoss = nn.CrossEntropyLoss(ignore_index=tagIdConverter.pad_id)"
]
},
{
"cell_type": "markdown",
"id": "eaa08ab2",
"metadata": {},
"source": [
"옵티마이져 준비"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "109259b4",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 0 start:\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 0: 100%|█████████████████████████████████████████| 133/133 [00:26<00:00, 4.98batch/s, accuracy=0.746, loss=1.88]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 1 start:\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 1: 100%|█████████████████████████████████████████| 133/133 [00:26<00:00, 5.04batch/s, accuracy=0.814, loss=1.17]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 2 start:\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 2: 100%|████████████████████████████████████████| 133/133 [00:26<00:00, 5.10batch/s, accuracy=0.821, loss=0.928]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 3 start:\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 3: 100%|████████████████████████████████████████| 133/133 [00:26<00:00, 5.05batch/s, accuracy=0.821, loss=0.795]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 4 start:\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 4: 5%|██▏ | 7/133 [00:01<00:30, 4.10batch/s, accuracy=0.853, loss=0.724]\n"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"\u001b[1;32m~\\AppData\\Local\\Temp/ipykernel_28932/1927930699.py\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m 21\u001b[0m \u001b[0mprediction\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0moutput\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mview\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m-\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0moutput\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msize\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m-\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0margmax\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdim\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;33m-\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 22\u001b[0m \u001b[1;31m#부정확 할 수 있지만 대충 맞음.[PAD]기호를 예측할 일은 없어야 함.\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 23\u001b[1;33m \u001b[0mcorrect\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m(\u001b[0m\u001b[0mprediction\u001b[0m \u001b[1;33m==\u001b[0m \u001b[0mbatch_labels\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mview\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m-\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msum\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mitem\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 24\u001b[0m \u001b[0maccuracy\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mcorrect\u001b[0m \u001b[1;33m/\u001b[0m 
\u001b[0mbatch_inputs\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;34m\"attention_mask\"\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mview\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m-\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msum\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 25\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;31mKeyboardInterrupt\u001b[0m: "
]
}
],
"source": [
"TRAIN_EPOCH = 5\n",
"\n",
"result = []\n",
"iteration = 0\n",
"\n",
"model.zero_grad()\n",
"\n",
"for epoch in range(TRAIN_EPOCH):\n",
" model.train()\n",
" print(f\"epoch {epoch} start:\")\n",
" with tqdm(train_loader, unit=\"batch\") as tepoch:\n",
" for batch_i,batch_l in tepoch:\n",
" tepoch.set_description(f\"Epoch {epoch}\")\n",
" \n",
" batch_inputs = {k: v.cuda(device) for k, v in list(batch_i.items())}\n",
" batch_labels = batch_l.cuda(device)\n",
"\n",
" output = model(**batch_inputs)\n",
" loss = CELoss(output.view(-1, output.size(-1)), batch_labels.view(-1))\n",
" \n",
" prediction = output.view(-1, output.size(-1)).argmax(dim=-1)\n",
" #부정확 할 수 있지만 대충 맞음.[PAD]기호를 예측할 일은 없어야 함.\n",
" correct = (prediction == batch_labels.view(-1)).sum().item()\n",
" accuracy = correct / batch_inputs[\"attention_mask\"].view(-1).sum()\n",
" \n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
"\n",
" optimizer.step()\n",
" \n",
" result.append({\"iter\":iteration,\"loss\":loss.item(),\"accuracy\":accuracy})\n",
" tepoch.set_postfix(loss=loss.item(), accuracy= accuracy.item())\n",
" iteration += 1"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "f0d9b2d7",
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "19ca6da1",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "0bee685c",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYsAAAD4CAYAAAAdIcpQAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8/fFQqAAAACXBIWXMAAAsTAAALEwEAmpwYAABZ9klEQVR4nO2dd5gUxdaH39rIBliWJefgigQJiojpmuNVMYtiuHpNKOacs2L8FBMXRb1GxCsqKmZFQAVEQSSDCLLkvDnX90dNTYfpCcvO7uzs1vs888x0T3V39YT61Tl16pSQUmIwGAwGQygSYl0Bg8FgMDR8jFgYDAaDISxGLAwGg8EQFiMWBoPBYAiLEQuDwWAwhCUpVhdOSEiQaWlpsbq8wWAwxCXFxcVSSlnvHf2YiUVaWhpFRUWxurzBYDDEJUKIklhc17ihDAaDwRAWIxYGg8FgCIsRC4PBYDCExYiFwWAwGMJixMJgMBgMYTFiYTAYDIawGLEwGAwGQ1jiUix+Xvszv6z7JdbVaNoUF0N1daxrEZ/MnAl//BHrWhhqyvr1UFkZ61rEjLgUiwNfPZChrwyNdTWaLuXlkJEBN90U65rEJ4ccAgMGxLoWhpqwfj106gQPPhjrmsSMuBOL0spS/+vC8sIY1qQJo2fejx8f23oY4pOFCyE3F7Zti945hYAbboje+dz8+qt6/vnnurtGAyfuxGLR5kX+13PXz41hTeKYd96BNWt2//gSX7YBs8qiYt06WLUq1rWIHx59FFauhC++iM75ysvV8//9X3TO58WKFeo5N7furtHAiTuxWLJ1if/1t6u+xSwLW0OqqmDkSPjHP3b/HMXF6jlexywmT4bt26N3vs6doVevyMpWVETvuvFKs2bqubQ0dLlI2bUrOucJxdKl6jkrq+6v1UCJO7E4b8B5bL5pMwAPzXiI0yedzmvzXmPKsikxrlmcoBv6vLzIypeUwFFHwfz51j7thmoIYvH663DssZGXX7oUTj8drriizqoUkvpo2BoCn30Gr7zi/V60xWLnTvWcVId5UbduVc9lZXV3jQZO3IkFQJuMNv7XHy79kIunXMzwicNjWKM4QotFQoRf/apV8O23MHt24DnKy+HJJ3evHlJGx4110UXw1VeRR6lod0JBQe2vvTvEs1gsWQLXXx9ZJ+HEE+HSS73HJbRYlEQpeaoWi91d8uDSS+Hoo0OX0b+XmoiFlPCf/6jBcTtt28K999asjg2AuBQLgHdPf5fL972c1MTUWFclvtBWQWJiZOV37FDP9j+2FguAm292NggVFfDEE+H/VMcfH1ywsrMD/7ynnaaOsaN7e2DdVzj0WE3XrsHLbNoUHSErL1fnycuDNm1g0SKnWIwfr+oRCwvtq69gyJCaucVOPhmeeQZWrw58r6BADTK/8Yba1g33558Hlo22ZaF/o/q827aputx1V2THv/IKfPMNbN4MEyZ4l9FiYa9zYWHo38nixcqCvfRSa19pKWzZAikpkdWtARG3YjGi/wjGnTiOs/uf7d9XVV0VwxrFCTW1LMKJBTgHd198EW65BZ59NvR5v/wy+Hs7d6o/r50PPwwcEL3nHut1YYSRcX/9pZ7btPF+f9kyaN9e3Ue/frsfHvzii5CaCm+9BR9/rIStf3/Yd1+rzG23wdq18P33u3eN2nDxxSrCZ+NGtb1iRWiBz8tTg9LgLW5r16rnhx9Wz9nZ6ln3+u3oRj1aLh19jWbNlJC1bq22H3+8Zuc55RS45BLLRTt/Plx3HTz3nGVZa7HYsAGaNw/9O//FNxfM/t/RHZxgv78GTNyKheauQ6zew/qC9SFKGoDdFwt7j8otFvn5geWLi2HqVGVyh+r1S7n7PUzd0EHklsW6deo5WI9ai8nHH6ue4VNPRV6fww+3Xs+apZ5/+w3S073LH3OMer7sstoPfP/9N5x/vrJczj9fXTcU+vuvqlLH7LmnqocXFRXQpYu1bW/8qqvVZ6Z72M
uXw9dfq/NC4G9FH+M+T6TcfDMMHuzcZxeLTZus/cnJNTv38uXO+h19tBKDa66xyujJqBs2qG1tiRQWOv8HAHPmqOfmza19Wiy0oMURcS8WuTm5fHme6qU+/fPTMa5NPXPddZCZGbrMzp3w++/WdjQsC3fDPHas9b5u9JKSVK98yxb488/g53/zTeWy2J3QU7dLIBL0H1qHW86erVwlF12kBmW1eyBUr3fcOOWT37rVOg/AtGnWa73/mWfgu++8zyOEel61qvbRWXfeqayYxx5Tz2edFbq8/v4rKqzvzstlBFYPWWP//l95BXr2dN77McdYvwMvsdCfze6MGz35pBVssXmz+tx0A9ysmbLmNDUVC/0b0t+91zjYhx/ChRdaZQoKVLlu3QIjpXTHxD5msWWLejaWRWw4sseRHN79cF6a+xIFZTEauKwv7rvPmhj07LPhe9Rjx8JBB1k9v7pwQ02ZApdfrl7rP1hiotW7DGU5TJ6snnUDYHdxLF1qnUPz2GPKVeC2SCK1LLRY6D/7o4+qnuPrrysB0Pvt97jPPpbrBOCFF5SwTJwY2ODp+7eLyFtvedfFfuzuWldbt6pJbtrKmjFDPYfzievvv6TE+oy9GsepU5Ubxo79s17iC2UfPdpZRje8Xt+L/mw+/dQa46gpUkK7dpCTo4QS1D3ZLbRwYrF5s9NVpb97/V1oMXfz1lvW2NOaNeo6XmKvx/K0aICxLGJNYkIiDxz+AGVVZUxeMjnW1ak7pFTpBt59N/JjduxQf1jd2EcqFrNnq16QlxvKbqlo3nwTBg2y/nxJSVYjpM/hhR4IXbRI+dDtjXSfPoF+59tuU6Z/QYGqkzbx3ZZFdTV88onq5b73nrVfN9C6wSoqsvzt9vPY//zz5jktDV3Hq68OHNNwnz8UCxdar0tL1fd79dXqepFQVqZ6qHvvbTXaM2eqZy0WGzbAGWfAggXq/DNnqme7WGjXm/6+fvnF6lz8859KFO3YBaB7d++66fv3siz0Z7l+veqle1FUBA88EPxz9BKhsjJn+a1bnSHfbs4/H269NXB/OLEA+OAD7/32AW/9G9q4UXVS3nwTrr1W7TOWRew4sMuB9G/bnyd+eiLWVak7KitVI6hNWU1Fhdo/fnygH1j/MXWPWv/JwonFsGFqtqrbsli+PHj8vF1EZsyw3E+hXCzan3/PPSo658cfne9r37+b/Hz1p87JUdvuxuPVV1X0ztdfw4gRai5GVVWgG6q42CkE+jxeddbRL3a/+OuvB9bLfv5Q2GfRl5aqhv3555WFY2f6dOVesYsaWAPO+j7spKSoQeuOHVXD9umnatD9kEOUu0l//zNmqH2gfl/TpsHQoSqK5+WXvetdWKjeKy4OHzVWVKTcVBddZO2L5LN54AEVXvrOO97vuwfO27dXlugzzzj3Dx6sBGH16sC62r9HO1osQv1Hgv0H7B2jbdtUkISU6jO/4AL13xXCCgCIIxqNWCSIBM7tfy6LtixixbYV7CqN43h2L6qqrF7vli0w3DavpKRENQaXXw533+08zi0WulFJTFQCM29eYHSL/jPv2mX9ofQfyO46sTdWbj75xHq9fbtyH9kjgTR29w4o946djAzvBsktFvqz+ftv5XqzhyuCChXVPTyweqFucdXn8Yri2bRJfSahXF72ePwDD1QTGiOhtNQ6Ni1NHS+l+m7efFPV9cUXrfLFxc5oMLdY/PIL7L+/tV1drUQHVIOmG0K7FVNZad3/+PHBB7y/+069d/nl4d1nxcXKcrGLqlssXnjBGgwG9VpblPbevf13oAeYNZ07q+f//S+wDo8/Dj16BFoRwYIKIhGLYOh6Sal+9yeeqMYV7Z2gFi0iD11vQDQasQA12A2w5/N7MmBcI8vqOXQotGqlXm/ZosYJNKWlViNo949CaLG44ALlj9chqVLCpEnOeRPaWtCNqv2PHmmKi+3blfvot98ChcndyLnrn57u3RMNZlkceqhl6rvZtMlqkCdNUr
117b7RbN4c/D42bgw/891uWWRkQN++zvd79HD2Slu2VM+lpZZA/fmnEtGrr1ZBDLr8ggXWcffea433gPqee/d2Xsv
"text/plain": [
"<Figure size 432x288 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Training curves: loss (green, left axis) and accuracy (red, right axis) for each logged entry in `result`.\n",
"iters = [item[\"iter\"] for item in result]\n",
"fig, ax1 = plt.subplots()\n",
"ax1.plot(iters, [item[\"loss\"] for item in result], 'g')\n",
"ax1.set_xlabel(\"iter\")\n",
"ax1.set_ylabel(\"loss\", color='g')\n",
"ax2 = ax1.twinx()\n",
"ax2.plot(iters, [item[\"accuracy\"] for item in result], 'r')\n",
"ax2.set_ylabel(\"accuracy\", color='r')\n",
"ax1.set_title(\"training loss / accuracy\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "2bb67740",
"metadata": {},
"source": [
"학습 그래프입니다."
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "0defca72",
"metadata": {},
"outputs": [],
"source": [
"# Release unoccupied blocks held by PyTorch's CUDA caching allocator (tensors still referenced are not freed).\n",
"torch.cuda.empty_cache()"
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "2f45cae0",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"gpu allocated : 693 MB\n",
"gpu reserved : 756MB\n"
]
}
],
"source": [
"# Report GPU memory in MiB: memory_allocated = bytes occupied by live tensors,\n",
"# memory_reserved = bytes held by PyTorch's caching allocator (allocated + cached).\n",
"MB = 1024 ** 2\n",
"print(f\"gpu allocated : {torch.cuda.memory_allocated() // MB} MB\")\n",
"print(f\"gpu reserved : {torch.cuda.memory_reserved() // MB} MB\")"
]
},
{
"cell_type": "markdown",
"id": "68f18f4d",
"metadata": {},
"source": [
"gpu 메모리 사용량을 보는 코드"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "73b01630",
"metadata": {},
"outputs": [],
"source": [
"# Drop references left over from training so their GPU memory can be reclaimed.\n",
"del batch_inputs, batch_labels, loss, optimizer"
]
},
{
"cell_type": "markdown",
"id": "02a3f367",
"metadata": {},
"source": [
"한번 테스트해 보자"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "af93a3ec",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"김광현 을 둘러싼 주변 의 우려 는 너무 많이 던진다는 것 .\n",
"두산 은 주포 인 김동주 와 안경현 이 빠진 상황 이 라 타력 에 적 지 않 은 문제점 을 안 고 있 지만 29 일 잠실 롯데전 이 우천 으로 취소 되 는 바람 에 시간 을 벌 었 고 더구나 삼성 이 4 연패 로 2 위 로 추락 해 어부지리 로 1 위 로 올라서 는 행운 을 잡 았 다 .\n"
]
}
],
"source": [
"# Show two training sentences as plain text for a quick sanity check.\n",
"samples = datasetTrain[100:102]\n",
"for sample in samples:\n",
"    print(tokenizer.convert_tokens_to_string(sample[\"tokens\"]))"
]
},
{
"cell_type": "code",
"execution_count": 35,
"id": "9b2cd5c4",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'input_ids': tensor([[ 101, 8935, 118649, 30842, 9633, 9105, 30873, 119091, 9689,\n",
" 118985, 9637, 9604, 26737, 9043, 9004, 32537, 47058, 9076,\n",
" 65096, 11018, 8870, 119, 102, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0],\n",
" [ 101, 9102, 21386, 9632, 9689, 55530, 9640, 8935, 18778,\n",
" 16323, 9590, 9521, 31720, 30842, 9638, 9388, 18623, 9414,\n",
" 65649, 9638, 9157, 9845, 28143, 9559, 9664, 9706, 9523,\n",
" 9632, 9297, 17730, 34907, 9633, 9521, 8888, 9647, 9706,\n",
" 19105, 10386, 9641, 9655, 31503, 9208, 28911, 16617, 9638,\n",
" 9604, 38631, 29805, 9773, 22333, 9098, 9043, 9318, 61250,\n",
" 9559, 9485, 18784, 9633, 9339, 9557, 8888, 9074, 17196,\n",
" 16439, 9410, 17138, 9638, 125, 9568, 119383, 9202, 123,\n",
" 9619, 9202, 9765, 107693, 9960, 9546, 14646, 12508, 12692,\n",
" 9202, 122, 9619, 9202, 9583, 17342, 12424, 9043, 9966,\n",
" 21614, 9633, 9656, 9529, 9056, 119, 102]],\n",
" device='cuda:0'),\n",
" 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0],\n",
" [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
" 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
" 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
" 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
" 1]], device='cuda:0'),\n",
" 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0]], device='cuda:0')}"
]
},
"execution_count": 35,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Collate the same two samples into one padded batch and move every input tensor to `device`.\n",
"inputs, labels = my_collate_fn(datasetTrain[100:102])\n",
"inputs = dict((name, tensor.to(device)) for name, tensor in inputs.items())\n",
"inputs"
]
},
{
"cell_type": "code",
"execution_count": 36,
"id": "6b31782c",
"metadata": {},
"outputs": [],
"source": [
"# Inference on the two-sentence batch: class probabilities `sp` and per-token argmax ids `p` (on CPU).\n",
"model.eval()\n",
"with torch.no_grad():\n",
"    predict_label = model(**inputs)\n",
"    sp = model.softmax(predict_label)\n",
"    # squeeze() drops every size-1 dim, so keepdim=True followed by squeeze() is unnecessary\n",
"    p = sp.argmax(dim=-1).squeeze().cpu()"
]
},
{
"cell_type": "code",
"execution_count": 37,
"id": "7f4d43ce",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['O', 'O', 'I-PS', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']\n",
"['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']\n"
]
}
],
"source": [
"# Predicted tag names, one line per sentence.\n",
"for predicted_ids in p.numpy():\n",
"    print(tagIdConverter.convert_ids_to_tokens(predicted_ids))"
]
},
{
"cell_type": "code",
"execution_count": 38,
"id": "5ade3317",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['O', 'B-PS', 'I-PS', 'I-PS', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]']\n",
"['O', 'B-OG', 'I-OG', 'O', 'O', 'O', 'O', 'B-PS', 'I-PS', 'I-PS', 'O', 'B-PS', 'I-PS', 'I-PS', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-DT', 'I-DT', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-OG', 'I-OG', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']\n"
]
}
],
"source": [
"# Gold tag names for the same sentences ('[PAD]' marks padding positions).\n",
"for gold_ids in labels.squeeze().numpy():\n",
"    print(tagIdConverter.convert_ids_to_tokens(gold_ids))"
]
},
{
"cell_type": "markdown",
"id": "b48e22ff",
"metadata": {},
"source": [
"It works!"
]
},
{
"cell_type": "code",
"execution_count": 39,
"id": "383dd24a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([194, 1])"
]
},
"execution_count": 39,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Sanity check: flatten all leading dims of `sp` and take the argmax -> one class id per token position.\n",
"flat_probs = sp.cpu().view(-1, sp.size(-1))\n",
"flat_probs.argmax(dim=-1, keepdim=True).shape"
]
},
{
"cell_type": "code",
"execution_count": 40,
"id": "ff74fced",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(106)"
]
},
"execution_count": 40,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Count correct predictions on real tokens only: padding positions (attention_mask == 0) are excluded,\n",
"# so the numerator covers the same positions as the attention_mask denominator used for accuracy.\n",
"token_mask = inputs[\"attention_mask\"].cpu().view(-1).bool()\n",
"token_pred = sp.cpu().view(-1, sp.size(-1)).argmax(dim=-1)\n",
"correct = ((token_pred == labels.view(-1)) & token_mask).sum()\n",
"correct"
]
},
{
"cell_type": "code",
"execution_count": 41,
"id": "3f6ad5d8",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(120, device='cuda:0')"
]
},
"execution_count": 41,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Number of real (non-padding) tokens in the batch.\n",
"inputs[\"attention_mask\"].sum()"
]
},
{
"cell_type": "code",
"execution_count": 42,
"id": "986fd52b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(0.8833, device='cuda:0')"
]
},
"execution_count": 42,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Token accuracy = correct predictions / number of non-padding tokens.\n",
"n_tokens = inputs[\"attention_mask\"].sum()\n",
"accuracy = correct / n_tokens\n",
"accuracy"
]
},
{
"cell_type": "markdown",
"id": "8cc47437",
"metadata": {},
"source": [
"accuracy는 다음과 같이 구해져요."
]
},
{
"cell_type": "code",
"execution_count": 43,
"id": "1f3f8666",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|███████████████████████████████████████████████████████████████████████████████| 16/16 [00:02<00:00, 5.66batch/s]\n"
]
}
],
"source": [
"# Evaluate on the test set; keep per-batch loss/accuracy plus predictions for the confusion matrix.\n",
"model.eval()\n",
"collect_list = []\n",
"with torch.no_grad():\n",
"    with tqdm(test_loader, unit=\"batch\") as tepoch:\n",
"        for batch_i, batch_l in tepoch:\n",
"            batch_inputs = {k: v.to(device) for k, v in batch_i.items()}\n",
"            batch_labels = batch_l.to(device)\n",
"            output = model(**batch_inputs)\n",
"            logits = output.view(-1, output.size(-1))\n",
"            loss = CELoss(logits, batch_labels.view(-1))\n",
"\n",
"            # Only real tokens count toward accuracy: padding (attention_mask == 0) is masked out,\n",
"            # so numerator and denominator cover the same positions.\n",
"            token_mask = batch_inputs[\"attention_mask\"].view(-1).bool()\n",
"            prediction = logits.argmax(dim=-1)\n",
"            correct = ((prediction == batch_labels.view(-1)) & token_mask).sum().item()\n",
"            n_tokens = token_mask.sum().item()\n",
"            accuracy = correct / n_tokens if n_tokens else 0.0  # plain Python float, not a CUDA tensor\n",
"\n",
"            collect_list.append({\"loss\": loss.item(), \"accuracy\": accuracy, \"batch_size\": batch_labels.size(0),\n",
"                                 \"predict\": output.argmax(dim=-1).cpu(),\n",
"                                 \"actual\": batch_labels.cpu(),\n",
"                                 \"attention_mask\": batch_inputs[\"attention_mask\"].cpu()})"
]
},
{
"cell_type": "code",
"execution_count": 44,
"id": "b7567f48",
"metadata": {},
"outputs": [],
"source": [
"def getConfusionMatrix(predict, actual, attention_mask, num_classes=22):\n",
"    \"\"\"Build a (num_classes, num_classes) confusion matrix weighted by attention_mask.\n",
"\n",
"    Rows index the predicted class and columns the actual class: entry [p, a] is the\n",
"    sum of attention_mask over positions where prediction == p and label == a, so\n",
"    positions with mask 0 (padding) contribute nothing.\n",
"\n",
"    predict, actual, attention_mask: integer tensors of the same shape (e.g. batch x seq).\n",
"    num_classes: number of tag ids (22 by default, matching this notebook's tag set).\n",
"    Returns a CPU torch.long tensor.\n",
"    \"\"\"\n",
"    ret = torch.zeros((num_classes, num_classes), dtype=torch.long)\n",
"    rows = predict.reshape(-1).long().to(ret.device)\n",
"    cols = actual.reshape(-1).long().to(ret.device)\n",
"    weights = attention_mask.reshape(-1).long().to(ret.device)\n",
"    # Vectorized scatter-add replaces the per-token Python loop; accumulate=True sums repeated (p, a) pairs.\n",
"    ret.index_put_((rows, cols), weights, accumulate=True)\n",
"    return ret"
]
},
{
"cell_type": "markdown",
"id": "a898cd34",
"metadata": {},
"source": [
"단순하게 confusion matrix를 계산하는 함수. 클래스는 22개로 정해져 있음. 행(row)은 predict, 열(column)은 actual이에요."
]
},
{
"cell_type": "code",
"execution_count": 45,
"id": "15cd73a5",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])"
]
},
"execution_count": 45,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Tiny check: predictions [0, 1] vs labels [1, 1] with both tokens unmasked -> counts at [0,1] and [1,1].\n",
"demo_predict = torch.tensor([[0, 1]])\n",
"demo_actual = torch.tensor([[1, 1]])\n",
"demo_mask = torch.ones_like(demo_actual)\n",
"getConfusionMatrix(demo_predict, demo_actual, demo_mask)"
]
},
{
"cell_type": "code",
"execution_count": 46,
"id": "de9c7932",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"average_loss : 0.7903580265045166, average_accuracy : 0.8021594285964966, size :500\n"
]
}
],
"source": [
"# Aggregate test metrics.\n",
"# - loss: per-batch loss weighted by batch size (number of sentences).\n",
"# - accuracy: taken from the masked confusion matrix, i.e. correct non-padding tokens / all\n",
"#   non-padding tokens. Averaging per-batch ratios weighted by sentence count is not token accuracy.\n",
"total_loss = 0\n",
"total_size = 0\n",
"confusion = torch.zeros((22, 22), dtype=torch.long)\n",
"\n",
"for item in collect_list:\n",
"    batch_size = item[\"batch_size\"]\n",
"    total_loss += batch_size * item[\"loss\"]\n",
"    total_size += batch_size\n",
"    confusion += getConfusionMatrix(item[\"predict\"], item[\"actual\"], item[\"attention_mask\"])\n",
"\n",
"average_loss = total_loss / total_size if total_size else float(\"nan\")\n",
"total_tokens = confusion.sum().item()\n",
"token_accuracy = confusion.diagonal().sum().item() / total_tokens if total_tokens else float(\"nan\")\n",
"print(f\"average_loss : {average_loss}, average_accuracy : {token_accuracy}, size :{total_size}\")"
]
},
{
"cell_type": "markdown",
"id": "24539b98",
"metadata": {},
"source": [
"test 셋으로 평가한 결과가 나왔어요. 정확도는 약 80%(0.802) 나와요. F1 스코어는 아직입니다."
]
},
{
"cell_type": "code",
"execution_count": 58,
"id": "f6047991",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0],\n",
" [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0],\n",
" [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0],\n",
" [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0],\n",
" [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0],\n",
" [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0],\n",
" [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0],\n",
" [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0],\n",
" [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0],\n",
" [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0],\n",
" [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0],\n",
" [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0],\n",
" [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0],\n",
" [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0],\n",
" [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0],\n",
" [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0],\n",
" [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 1, 0, 0, 0,\n",
" 0, 0],\n",
" [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 2, 2, 27, 0, 0,\n",
" 0, 0],\n",
" [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0],\n",
" [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0],\n",
" [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0],\n",
" [ 0, 0, 1, 0, 167, 126, 466, 421, 40, 0,\n",
" 0, 0, 5, 0, 375, 166, 1005, 1154, 101, 0,\n",
" 0, 16409]])"
]
},
"execution_count": 58,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Raw confusion matrix from the test set: rows = predicted class id, columns = actual class id.\n",
"confusion"
]
},
{
"cell_type": "code",
"execution_count": 57,
"id": "000d1e68",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAATQAAAEICAYAAADROQhJAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8/fFQqAAAACXBIWXMAAAsTAAALEwEAmpwYAAAxhElEQVR4nO2deZwVxbXHv2dYlV2GfUdGlmEZZhD0YeIWNzQmGo0SY8xTH2LEbBqfxCQueURjfMl7ionB6DMxRKNxiQtRiSYuMVFBATGKqAEFkUU2kQEHOO+Prhmay517+259q5r68anPvd19un59+vYUVXXOr0tUFQ8PD48koKLcF+Dh4eFRLPgGzcPDIzHwDZqHh0di4Bs0Dw+PxMA3aB4eHomBb9A8PDwSA9+gJRAS4P9EZIOIvFhAPZ8SkSXFvLZyQUT6i8gWEWlR7mvxKB3E56ElDyLyKeAuYKiqflzu6yk1RGQZcL6q/rnc1+JRXvgeWjIxAFi2LzRmUSAiLct9DR7xwDdoZYaI9BOR+0VkrYh8KCIzzf4KEfmeiCwXkTUi8hsR6WSODRQRFZFzRORdEVknIleYY+cBvwIONUOsq0XkqyLyXAqvisgQ832SiPxTRD4SkZUicqnZf4SIrAidM1xE/ioiG0XkNRE5OXTsDhG5WUQeNfW8ICIHNuNz4/X/u4i8Z4bGU0XkYBFZZOqfGbI/UESeMvdnnYjMFpHO5tidQH/gYePvZaH6zxORd4GnQvtaisgBIrJCRD5r6mgvIm+JyFcK/T09ygxV9aVMBWgBLAR+BrQD2gKHmWPnAm8Bg4H2wP3AnebYQECBW4H9gDHAdmC4Of5V4LkQzx7bZp8CQ8z3VcCnzPcuQK35fgSwwnxvZa7nu0Br4CjgI4JhLcAdwHpgPNASmA3c3Yzfjdd/i/H5WGAb8CDQHegDrAEON/ZDgGOANkA34Bngf0L1LQM+k6b+35j7ul9oX0tjcyzwgeG7FfhDuZ8HXwovvodWXowHegPfUdWPVXWbqjb2pM4Cfqqq76jqFmA6cGbK8OlqVa1X1YUEDeOYPK+jARghIh1VdYOqvpzG5hCChvU6Vf1EVZ8CHgEmh2zuV9UXVXUHQYNWk4X3h8bnJ4CPgbtUdY2qrgSeBcYCqOpbqjpXVber6lrgp8DhEfy6ytzX+tQDhvNe4EngROCCCPV5WA7foJUX/YDlpgFIRW9geWh7OUHPp0do3weh71sJGpx88AVgErBcRJ4WkUObuZ73VHVXyjX1KeB6Voe+16fZbg8gIt1F5G4zHN4M/BaozFI3wHtZjs8CRgL/p6ofRqjPw3L4Bq28eA/o38yk9fsEk/uN6A/sYM8/+qj4GNi/cUNEeoYPqupLqvo5guHXg8A9zVxPPxEJPzP9gZV5XE+uuJZguDhaVTsCXwYkdLy5UH2zIXyTvvFLgmHphY3ziR5uwzdo5cWLBPNX14lIOxFpKyITzbG7gG+JyCARaQ/8CPh9M725bFgIVItIjYi0Ba5qPCAirUXkLBHppKoNwGZgZ5o6XiBoGC8TkVYicgTwWeDuPK4nV3QAtgAbRaQP8J2U46sJ5hpzwXfN57nADcBvfI6a+/ANWhmhqjsJGoUhwLvACuAMc/h24E6CCfB/EUyaX5wnz5vANcCfgaXAcykmZwPLzHBuKkEPKLWOT4CTgROAdcDPga+o6hv5XFOOuBqoBTYBjxIESMK4FvieiY5emq0yEakDvk1w/TuBHxP05i4v6lV7xA6fWOvh4ZEY+B6ah4dHYuAbNA8Pj8TA+gZNRI4XkSUmkzvjHEdU22Lbee5kcCfNn1JwW49yZ/ZmKgSZ9G8TRLBaE0TrRhRiW2w7z50M7qT5UwpuF4qVQYHKykodMGAgW7ZsYdWqVVRVVQHwwQerAOjZs9de50S1Lbad504Gd9L8yWa3fPky1q1bJwAmkfoqVT3ObE8HUNVr97rQEFp0HKC6Yy8RRlpo/drHVfX4SM
aFoNwtarpSW1un9Q2qs+++V7/67+dpfYNqfYPqbf/3G73gwouatsMlqm2x7Tx3MriT5k82u9raOg310E4DfhXaPhuYmbU3tF93bTv24kgFmBdH21HQHFq2cbcEuNEcXyQitbk2tmnqLMi22HaeOxncpajTFW72VF00VdGc8R5niUQrMSHvBs1kVd9MkGg5ApgsIiNSzE4AqkyZAvwiF44+ffqyYsVuOd7KlSvo3bt3QbbFtvPcyeBOmj+5cBMkdPcLbfclkLplh1REK3Eh364dcCjweGh7OjA9xeaXwOTQ9hKgV9Qh50f1DTpw0CB9/c13dNPH23XUqNE6f8HitF32qLbFtvPcyeBOmj/Z7FKGnC2Bd4BB7A4KVGcdcu7fXduO+1akQkxDzkLe5NmHPd9msAKYEMGmD4F+MStatmzJz/53Jp898Th27tzJOV89lxHV1QXZFtvOcyeDO2n+5MKtqjtEZBrwOEHE83ZVfS2t8R4QqLBL/pp3lFNETgeOU9XzzfbZwHhVvThk8yhwrZp3fInIk8Blqjo/TX1TCIal9Ovfv+7Nt5enmnh4eBQJEyeMY/78eQVNblW076ltRp4TyXbbC9fPV9VxhfBFQSGD2yjj7shjc1WdparjVHVct8puBVyWh4dHPIgYEHAhKAC8BFRJ8Hqb1sCZwEMpNg8BXzHRzkOATaoaabjp4eHhACwLCuTNpMF7uRrH3a8D96jqaxIsdjHVmM0hmGx8i+C97V/LleeJxx9jdPVQqocN4SfXX1cU22Lbee5kcCfNn1y485Y+WdZDK3nUIZ/SGOXcsm2HDho8WP+55O2mSM3LC19LGyWKaltsO8+dDO6k+ZPNLiXKmZf0Sdr10rYTvxep4EJibanx0osvcuCBQxg0eDCtW7fm9DPO5JGH/1iQbbHtPHcyuJPmTy7cBIv1vKXBgjyfELyF+HPNGTdBCKKcUUpMsLpBe//9lfTtuzum0KdPX1auTP8K+6i2xbbz3MngTpo/uXDTfHpVFkhy5tDiQLqUEhulI57bfe5S1OkKN/lKnwAqJFrJAAkW2/6LiLwuwQLW3zD7r5Jgpa8FpkzKdjmFJNaWHK5IRzy3+9xJ8ycW6ZNQrN7XDuASVX1ZRDoA80Vkrjn2M1W9IXJN5Q4AZAoK2Cwd8dzJ4k6aP7FInzr01rZHzYhUyCEoAPwROIZgdbJLc2k7rO6huSId8dzucyfNHwulT5UiMi+0PUtVZ+1Vo8hAYCzBsokTgWki8hVgHkEvbkPGK0o31i436urG6d9emJfd0MPDIy8URfrUsa+2OeQbkWy3zb0sq/RJgvVnnwZmqOr9ItKDYMlEBX5I8GKLczPVYXUPzcPDw2IUMWlWRFoB9wGzVfV+AFVdHTp+K/BItnqsjnKCO5nWntt97qT5E49SoPC0DQnCr7cBr6vqT0P7w+8VPwVYnPV6yh0AyBQUsDnT2nMniztp/sSiFOjYV9se/9NIhQxBAeAwgmHlImCBKZOAO4FXzf6HiPAuRat7aK5kWntu97mT5k8sSoEiJdaq6nOqKqo6WlVrTJmjqmer6iiz/2SN8GILqxs0VzKtPbf73EnzJxalgIXSJ6uDAukisDZmWntu97lLUacr3OStFJBYZU1RYHWD5kqmted2nztp/sS3SEqMrwaKgnIHADIFBWzOtPbcyeJOmj+xKAU69de2J98SqeDAIiklhyuZ1p7bfe6k+ROPUgDremheKeDhsQ+iKEqBLgO1zRHfi2S77cH/sHuRlOZe+ZFic4SIbAq9/uMHhV2uh4eHTZCKikglLhTC1PjKj+HAIcBFaVZOB3g2lFtyTa4krmRae273uZPmT6mVAhKcF6nEhmJNxmFe+ZGy7wjgkXyDAjZnWnvuZHEnzZ84lAIVXQbo/qfdHqng0poCKa/8SMWhIrJQRP4kIulnJYM6pojIPBGZt3bdWsCdTGvP7T530vyJSylgWw+t4AbNvPLjPuCbqro55fDLwABVHQPcBD
zYXD2aZqFhVzKtPbf73EnzJ541BewbchbUoKV75UcYqrpZVbeY73OAViJSGbX+dBFYGzOtPbf73KWo0xVuClhToKK
"text/plain": [
"<Figure size 432x288 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"import itertools\n",
"\n",
"# Heatmap of the confusion matrix without class 21 ('O'), whose huge count would wash out\n",
"# the colour scale. Rows = predicted class, columns = actual class.\n",
"shown = confusion[:21, :21]\n",
"fig, ax = plt.subplots()\n",
"ax.set_title(\"confusion matrix\")\n",
"im = ax.imshow(shown, cmap='Blues')\n",
"fig.colorbar(im, ax=ax)\n",
"ax.set_xlabel(\"actual\")\n",
"ax.set_ylabel(\"predict\")\n",
"# Annotate exactly the cells that are drawn, so no text lands outside the image.\n",
"for i, j in itertools.product(range(shown.shape[0]), range(shown.shape[1])):\n",
"    ax.text(j, i, str(shown[i, j].item()), horizontalalignment=\"center\", color=\"black\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "543ede4b",
"metadata": {},
"source": [
"혼동행렬을 보면 별로인 것 같습니다.\n",
"세로축(행)이 predict이고 가로축(열)이 actual입니다."
]
},
{
"cell_type": "markdown",
"id": "9e85493d",
"metadata": {},
"source": [
"$$Precision(class = a) = \\frac{TP(class = a)}{TP(class = a)+FP(class = a)}$$\n",
"\n",
"$$Recall(class = a) = \\frac{TP(class = a)}{TP(class = a)+FN(class = a)}$$\n",
"\n",
"$$F1Score(class = a) = \\frac{2}{\\frac{1}{Precision(class = a)}+\\frac{1}{Recall(class = a)} }$$"
]
},
{
"cell_type": "markdown",
"id": "200c95b9",
"metadata": {},
"source": [
"F1Score는 다음과 같이 주어집니다."
]
},
{
"cell_type": "markdown",
"id": "db75fb9f",
"metadata": {},
"source": [
"O 클래스에 대해서 계산해보면"
]
},
{
"cell_type": "code",
"execution_count": 49,
"id": "86f318c4",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(16409)"
]
},
"execution_count": 49,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# True positives for class 21 ('O'): the diagonal entry of the confusion matrix.\n",
"TP = confusion[21][21]\n",
"TP"
]
},
{
"cell_type": "code",
"execution_count": 50,
"id": "60f8af59",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(4027)"
]
},
"execution_count": 50,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# False positives: predicted 'O' (row 21) while the actual class was different.\n",
"FP = confusion[21, :].sum() - TP\n",
"FP"
]
},
{
"cell_type": "code",
"execution_count": 51,
"id": "0b5d4cd9",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(0)"
]
},
"execution_count": 51,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# False negatives: actually 'O' (column 21) but predicted as a different class.\n",
"FN = confusion.sum(dim=0)[21] - TP\n",
"FN"
]
},
{
"cell_type": "code",
"execution_count": 52,
"id": "5d88f758",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"precision : 0.8029457926750183\n",
"recall : 1.0\n",
"F1Score : 0.8907042741775513\n"
]
}
],
"source": [
"# Precision, recall and their harmonic mean (F1) for the 'O' class.\n",
"precision = TP / (TP + FP)\n",
"recall = TP / (TP + FN)\n",
"f1Score = (2*precision*recall)/(precision + recall)\n",
"\n",
"for metric_name, metric_value in ((\"precision\", precision), (\"recall\", recall), (\"F1Score\", f1Score)):\n",
"    print(f\"{metric_name} : {metric_value}\")"
]
},
{
"cell_type": "markdown",
"id": "0b23e7d5",
"metadata": {},
"source": [
"다른 클래스에 대해서도 모두 해보자"
]
},
{
"cell_type": "code",
"execution_count": 53,
"id": "38b3eee6",
"metadata": {},
"outputs": [],
"source": [
"def getF1Score(confusion, c, zero_division=float(\"nan\")):\n",
"    \"\"\"Return the F1 score of class `c` as a Python float.\n",
"\n",
"    `confusion` is indexed [predicted, actual] (as built by getConfusionMatrix), so\n",
"    FP = row sum - TP and FN = column sum - TP.\n",
"\n",
"    Uses F1 = 2TP / (2TP + FP + FN). This equals the harmonic mean of precision and recall,\n",
"    but it stays defined when one of them is 0/0: a class that occurs but is never predicted\n",
"    gets F1 = 0 instead of nan. Only when TP = FP = FN = 0 (class absent from both\n",
"    predictions and labels) is F1 undefined; `zero_division` is returned in that case.\n",
"    \"\"\"\n",
"    TP = confusion[c, c].item()\n",
"    FP = confusion[c].sum().item() - TP\n",
"    FN = confusion[:, c].sum().item() - TP\n",
"    denominator = 2 * TP + FP + FN\n",
"    if denominator == 0:\n",
"        return zero_division\n",
"    return 2 * TP / denominator"
]
},
{
"cell_type": "code",
"execution_count": 54,
"id": "61fe2d6c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"class 0 f1 score : nan\n",
"class 1 f1 score : nan\n",
"class 2 f1 score : nan\n",
"class 3 f1 score : nan\n",
"class 4 f1 score : nan\n",
"class 5 f1 score : nan\n",
"class 6 f1 score : nan\n",
"class 7 f1 score : nan\n",
"class 8 f1 score : nan\n",
"class 9 f1 score : nan\n",
"class 10 f1 score : nan\n",
"class 11 f1 score : nan\n",
"class 12 f1 score : nan\n",
"class 13 f1 score : nan\n",
"class 14 f1 score : nan\n",
"class 15 f1 score : nan\n",
"class 16 f1 score : 0.001982160611078143\n",
"class 17 f1 score : 0.0445544570684433\n",
"class 18 f1 score : nan\n",
"class 19 f1 score : nan\n",
"class 20 f1 score : nan\n",
"class 21 f1 score : 0.8907042741775513\n"
]
}
],
"source": [
"# Per-class F1 score for every one of the 22 tag ids.\n",
"for cls in range(22):\n",
"    print(f\"class {cls} f1 score : {getF1Score(confusion, cls)}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0b9b55e7",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.11"
}
},
"nbformat": 4,
"nbformat_minor": 5
}