1975 lines
65 KiB
Plaintext
1975 lines
65 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "1cc9cec9",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"'C:\\\\Users\\\\Monoid\\\\anaconda3\\\\envs\\\\nn\\\\python.exe'"
|
|
]
|
|
},
|
|
"execution_count": 1,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"import sys\n",
|
|
"sys.executable"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "f9c786f1",
|
|
"metadata": {},
|
|
"source": [
|
|
"파이썬 환경 확인.\n",
|
|
"envs\\\\nn\\\\python.exe 으로 끝나기를 기대합니다"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"id": "4c0a08d7",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 9.06it/s]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"from time import sleep\n",
|
|
"from tqdm import tqdm, trange\n",
|
|
"\n",
|
|
"lst = [i for i in range(5)]\n",
|
|
"\n",
|
|
"for element in tqdm(lst):\n",
|
|
" sleep(0.1)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "56a66a44",
|
|
"metadata": {},
|
|
"source": [
|
|
"### tqdm 소개\n",
|
|
"tqdm은 다음과 같이 Progress bar 그려주는 라이브러리이에요. 와! 편하다.\n",
|
|
"```\n",
|
|
"from tqdm.auto import tqdm\n",
|
|
"from tqdm.notebook import tqdm\n",
|
|
"```\n",
|
|
"이건 0에서 멈춰이있고 작동하지 않더라고요. 왜인진 몰라요."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"id": "d00b25e6",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from preprocessing import readPreporcssedDataAll\n",
|
|
"import torch\n",
|
|
"from torch.utils.data import Dataset, DataLoader\n",
|
|
"from dataset import make_collate_fn, DatasetArray\n",
|
|
"from transformers import BertTokenizer\n",
|
|
"import torch.nn as nn\n",
|
|
"from read_data import TagIdConverter"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "a1d7bb15",
|
|
"metadata": {},
|
|
"source": [
|
|
"대충 필요한 것 임포트"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"id": "433419f4",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"tagIdConverter = TagIdConverter()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"id": "79fb54df",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"cuda available : True\n",
|
|
"available device count : 1\n",
|
|
"device name: NVIDIA GeForce RTX 3070\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"print(\"cuda available :\",torch.cuda.is_available())\n",
|
|
"print(\"available device count :\",torch.cuda.device_count())\n",
|
|
"\n",
|
|
"if torch.cuda.is_available():\n",
|
|
" device_index = torch.cuda.current_device()\n",
|
|
" print(\"device name:\",torch.cuda.get_device_name(device_index))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "dd22dd5e",
|
|
"metadata": {},
|
|
"source": [
|
|
"cuda가 가능한지 먼저 확인해보아요."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"id": "2cd9fe37",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"PRETAINED_MODEL_NAME = 'bert-base-multilingual-cased'\n",
|
|
"tokenizer = BertTokenizer.from_pretrained(PRETAINED_MODEL_NAME)\n",
|
|
"\n",
|
|
"my_collate_fn = make_collate_fn(tokenizer, tagIdConverter)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "2177c793",
|
|
"metadata": {},
|
|
"source": [
|
|
"데이터 로딩 준비"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"id": "e738062d",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from transformers import BertModel"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"id": "70296ee3",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Some weights of the model checkpoint at bert-base-multilingual-cased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias']\n",
|
|
"- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
|
|
"- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"PRETAINED_MODEL_NAME = 'bert-base-multilingual-cased'\n",
|
|
"bert = BertModel.from_pretrained(PRETAINED_MODEL_NAME)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"id": "6cbc236d",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class MyModel(nn.Module):\n",
|
|
" def __init__(self,output_feat: int,bert):\n",
|
|
" super().__init__()\n",
|
|
" self.bert = bert\n",
|
|
" self.dropout = nn.Dropout(p=0.1)\n",
|
|
" self.lin = nn.Linear(768,output_feat) #[batch_size,word_size,768] -> [batch_size,word_size,output_feat]\n",
|
|
" self.softmax = nn.Softmax(2) #[batch_size,word_size,output_feat] -> [batch_size,word_size,output_feat]\n",
|
|
" #0부터 시작해서 2 번째 차원에 softmax.\n",
|
|
"\n",
|
|
" def forward(self,**kargs):\n",
|
|
" emb = self.bert(**kargs)\n",
|
|
" e = self.dropout(emb['last_hidden_state'])\n",
|
|
" w = self.lin(e)\n",
|
|
" return w"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "7dec337a",
|
|
"metadata": {},
|
|
"source": [
|
|
"모델 선언\n",
|
|
"`nn.CrossEntropy`는 소프트맥스를 겸함."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 10,
|
|
"id": "e5214de8",
|
|
"metadata": {
|
|
"scrolled": true
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"MyModel(\n",
|
|
" (bert): BertModel(\n",
|
|
" (embeddings): BertEmbeddings(\n",
|
|
" (word_embeddings): Embedding(119547, 768, padding_idx=0)\n",
|
|
" (position_embeddings): Embedding(512, 768)\n",
|
|
" (token_type_embeddings): Embedding(2, 768)\n",
|
|
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
|
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
|
" )\n",
|
|
" (encoder): BertEncoder(\n",
|
|
" (layer): ModuleList(\n",
|
|
" (0): BertLayer(\n",
|
|
" (attention): BertAttention(\n",
|
|
" (self): BertSelfAttention(\n",
|
|
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
|
|
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
|
|
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
|
|
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
|
" )\n",
|
|
" (output): BertSelfOutput(\n",
|
|
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
|
|
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
|
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
|
" )\n",
|
|
" )\n",
|
|
" (intermediate): BertIntermediate(\n",
|
|
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
|
|
" )\n",
|
|
" (output): BertOutput(\n",
|
|
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
|
|
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
|
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
|
" )\n",
|
|
" )\n",
|
|
" (1): BertLayer(\n",
|
|
" (attention): BertAttention(\n",
|
|
" (self): BertSelfAttention(\n",
|
|
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
|
|
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
|
|
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
|
|
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
|
" )\n",
|
|
" (output): BertSelfOutput(\n",
|
|
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
|
|
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
|
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
|
" )\n",
|
|
" )\n",
|
|
" (intermediate): BertIntermediate(\n",
|
|
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
|
|
" )\n",
|
|
" (output): BertOutput(\n",
|
|
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
|
|
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
|
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
|
" )\n",
|
|
" )\n",
|
|
" (2): BertLayer(\n",
|
|
" (attention): BertAttention(\n",
|
|
" (self): BertSelfAttention(\n",
|
|
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
|
|
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
|
|
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
|
|
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
|
" )\n",
|
|
" (output): BertSelfOutput(\n",
|
|
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
|
|
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
|
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
|
" )\n",
|
|
" )\n",
|
|
" (intermediate): BertIntermediate(\n",
|
|
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
|
|
" )\n",
|
|
" (output): BertOutput(\n",
|
|
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
|
|
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
|
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
|
" )\n",
|
|
" )\n",
|
|
" (3): BertLayer(\n",
|
|
" (attention): BertAttention(\n",
|
|
" (self): BertSelfAttention(\n",
|
|
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
|
|
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
|
|
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
|
|
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
|
" )\n",
|
|
" (output): BertSelfOutput(\n",
|
|
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
|
|
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
|
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
|
" )\n",
|
|
" )\n",
|
|
" (intermediate): BertIntermediate(\n",
|
|
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
|
|
" )\n",
|
|
" (output): BertOutput(\n",
|
|
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
|
|
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
|
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
|
" )\n",
|
|
" )\n",
|
|
" (4): BertLayer(\n",
|
|
" (attention): BertAttention(\n",
|
|
" (self): BertSelfAttention(\n",
|
|
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
|
|
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
|
|
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
|
|
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
|
" )\n",
|
|
" (output): BertSelfOutput(\n",
|
|
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
|
|
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
|
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
|
" )\n",
|
|
" )\n",
|
|
" (intermediate): BertIntermediate(\n",
|
|
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
|
|
" )\n",
|
|
" (output): BertOutput(\n",
|
|
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
|
|
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
|
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
|
" )\n",
|
|
" )\n",
|
|
" (5): BertLayer(\n",
|
|
" (attention): BertAttention(\n",
|
|
" (self): BertSelfAttention(\n",
|
|
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
|
|
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
|
|
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
|
|
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
|
" )\n",
|
|
" (output): BertSelfOutput(\n",
|
|
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
|
|
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
|
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
|
" )\n",
|
|
" )\n",
|
|
" (intermediate): BertIntermediate(\n",
|
|
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
|
|
" )\n",
|
|
" (output): BertOutput(\n",
|
|
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
|
|
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
|
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
|
" )\n",
|
|
" )\n",
|
|
" (6): BertLayer(\n",
|
|
" (attention): BertAttention(\n",
|
|
" (self): BertSelfAttention(\n",
|
|
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
|
|
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
|
|
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
|
|
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
|
" )\n",
|
|
" (output): BertSelfOutput(\n",
|
|
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
|
|
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
|
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
|
" )\n",
|
|
" )\n",
|
|
" (intermediate): BertIntermediate(\n",
|
|
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
|
|
" )\n",
|
|
" (output): BertOutput(\n",
|
|
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
|
|
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
|
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
|
" )\n",
|
|
" )\n",
|
|
" (7): BertLayer(\n",
|
|
" (attention): BertAttention(\n",
|
|
" (self): BertSelfAttention(\n",
|
|
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
|
|
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
|
|
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
|
|
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
|
" )\n",
|
|
" (output): BertSelfOutput(\n",
|
|
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
|
|
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
|
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
|
" )\n",
|
|
" )\n",
|
|
" (intermediate): BertIntermediate(\n",
|
|
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
|
|
" )\n",
|
|
" (output): BertOutput(\n",
|
|
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
|
|
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
|
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
|
" )\n",
|
|
" )\n",
|
|
" (8): BertLayer(\n",
|
|
" (attention): BertAttention(\n",
|
|
" (self): BertSelfAttention(\n",
|
|
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
|
|
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
|
|
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
|
|
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
|
" )\n",
|
|
" (output): BertSelfOutput(\n",
|
|
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
|
|
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
|
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
|
" )\n",
|
|
" )\n",
|
|
" (intermediate): BertIntermediate(\n",
|
|
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
|
|
" )\n",
|
|
" (output): BertOutput(\n",
|
|
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
|
|
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
|
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
|
" )\n",
|
|
" )\n",
|
|
" (9): BertLayer(\n",
|
|
" (attention): BertAttention(\n",
|
|
" (self): BertSelfAttention(\n",
|
|
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
|
|
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
|
|
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
|
|
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
|
" )\n",
|
|
" (output): BertSelfOutput(\n",
|
|
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
|
|
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
|
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
|
" )\n",
|
|
" )\n",
|
|
" (intermediate): BertIntermediate(\n",
|
|
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
|
|
" )\n",
|
|
" (output): BertOutput(\n",
|
|
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
|
|
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
|
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
|
" )\n",
|
|
" )\n",
|
|
" (10): BertLayer(\n",
|
|
" (attention): BertAttention(\n",
|
|
" (self): BertSelfAttention(\n",
|
|
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
|
|
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
|
|
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
|
|
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
|
" )\n",
|
|
" (output): BertSelfOutput(\n",
|
|
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
|
|
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
|
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
|
" )\n",
|
|
" )\n",
|
|
" (intermediate): BertIntermediate(\n",
|
|
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
|
|
" )\n",
|
|
" (output): BertOutput(\n",
|
|
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
|
|
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
|
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
|
" )\n",
|
|
" )\n",
|
|
" (11): BertLayer(\n",
|
|
" (attention): BertAttention(\n",
|
|
" (self): BertSelfAttention(\n",
|
|
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
|
|
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
|
|
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
|
|
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
|
" )\n",
|
|
" (output): BertSelfOutput(\n",
|
|
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
|
|
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
|
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
|
" )\n",
|
|
" )\n",
|
|
" (intermediate): BertIntermediate(\n",
|
|
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
|
|
" )\n",
|
|
" (output): BertOutput(\n",
|
|
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
|
|
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
|
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
|
" )\n",
|
|
" )\n",
|
|
" )\n",
|
|
" )\n",
|
|
" (pooler): BertPooler(\n",
|
|
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
|
|
" (activation): Tanh()\n",
|
|
" )\n",
|
|
" )\n",
|
|
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
|
" (lin): Linear(in_features=768, out_features=22, bias=True)\n",
|
|
" (softmax): Softmax(dim=2)\n",
|
|
")\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"model = MyModel(22,bert)\n",
|
|
"model.cuda()\n",
|
|
"bert.cuda()\n",
|
|
"print(model)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "72df7ac3",
|
|
"metadata": {},
|
|
"source": [
|
|
"Tag의 종류가 22가지 입니다. 그래서 22을 넣었어요."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "0bba477c",
|
|
"metadata": {},
|
|
"source": [
|
|
"생성과 동시에 gpu로 옮기자.\n",
|
|
"`cuda` 저거 호출하면 됨."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 11,
|
|
"id": "5aa129e3",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"bert current device : cuda:0\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"print(\"bert current device :\",bert.device)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 12,
|
|
"id": "61d356b6",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"#for param in bert.parameters():\n",
|
|
"# param.requires_grad = False"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "cf6690b2",
|
|
"metadata": {},
|
|
"source": [
|
|
"bert 는 업데이트 하지 않는다. 메모리를 아낄 수 있다."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 26,
|
|
"id": "f28a1a61",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"device(type='cuda')"
|
|
]
|
|
},
|
|
"execution_count": 26,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"device = torch.device(\"cuda\")\n",
|
|
"device"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 28,
|
|
"id": "8844beef",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"inputs = {'input_ids': torch.tensor([[ 101, 39671, 8935, 73380, 30842, 9632, 125, 9998, 9251, 9559,\n",
|
|
" 9294, 8932, 28143, 9952, 8872, 127, 9489, 34907, 9952, 9279,\n",
|
|
" 12424, 102]],device=device), \n",
|
|
" 'token_type_ids': torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],device=device), \n",
|
|
" 'attention_mask': torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],device=device)}"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "4046945a",
|
|
"metadata": {},
|
|
"source": [
|
|
"적당한 인풋을 정의한다"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 29,
|
|
"id": "c37c3c1b",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"emb = model(**inputs)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 30,
|
|
"id": "261d4cc7",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"torch.Size([1, 22, 22])"
|
|
]
|
|
},
|
|
"execution_count": 30,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"emb.size()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "d492e37b",
|
|
"metadata": {},
|
|
"source": [
|
|
"결과가 잘 나왔어요."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 31,
|
|
"id": "c773fdba",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"entity_ids = torch.tensor([21,21,7,17,17,21,21,21,21,21,21,21,21,21,21,21,21,21,21,21,21,21]\n",
|
|
" ,dtype=torch.int64,device=device)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "b0b69822",
|
|
"metadata": {},
|
|
"source": [
|
|
"잘 이해하지는 못하겠는데, int64면 실행이 되고 int32이면 실행이 안된다.\n",
|
|
"```\n",
|
|
"RuntimeError: \"nll_loss_forward_reduce_cuda_kernel_2d_index\" not implemented for 'Int'\n",
|
|
"```\n",
|
|
"이런 오류를 내면서 죽음."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 32,
|
|
"id": "0f97b71a",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"torch.Size([22, 22])"
|
|
]
|
|
},
|
|
"execution_count": 32,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"emb.view(-1,emb.size(-1)).size()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 36,
|
|
"id": "e18b33a5",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"tensor([19, 13, 21, 21, 19, 21, 13, 13, 16, 16, 21, 16, 19, 13, 2, 13, 13, 16,\n",
|
|
" 13, 20, 19, 19], device='cuda:0')"
|
|
]
|
|
},
|
|
"execution_count": 36,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"predict = emb.view(-1,emb.size(-1)).argmax(dim=-1)\n",
|
|
"predict"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 39,
|
|
"id": "3312e199",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"tensor([False, False, False, False, False, True, False, False, False, False,\n",
|
|
" True, False, False, False, False, False, False, False, False, False,\n",
|
|
" False, False], device='cuda:0')"
|
|
]
|
|
},
|
|
"execution_count": 39,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"(predict == entity_ids)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 38,
|
|
"id": "97270036",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"tensor([[0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],\n",
|
|
" device='cuda:0')"
|
|
]
|
|
},
|
|
"execution_count": 38,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"(predict == entity_ids) * inputs[\"attention_mask\"]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 19,
|
|
"id": "d7d0164a",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"tensor(2.9564, device='cuda:0', grad_fn=<NllLossBackward0>)"
|
|
]
|
|
},
|
|
"execution_count": 19,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"nn.CrossEntropyLoss(ignore_index=tagIdConverter.pad_id)(emb.view(-1,emb.size(-1)),entity_ids)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "9fee44d6",
|
|
"metadata": {},
|
|
"source": [
|
|
"크로스 엔트로피를 계산하는 데에 성공.\n",
|
|
"ignore_index는 padding class를 넣어요."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 40,
|
|
"id": "197cfa62",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"del inputs\n",
|
|
"del entity_ids\n",
|
|
"del emb"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "733cb548",
|
|
"metadata": {},
|
|
"source": [
|
|
"본격적으로 학습시켜봅시다."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 16,
|
|
"id": "40a3e52c",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"datasetTrain, datasetDev, datasetTest = readPreporcssedDataAll()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "66d41fee",
|
|
"metadata": {},
|
|
"source": [
|
|
"데이터 셋이 적어도 어느정도 성능이 나와하야 할지 생각해봅시다.\n",
|
|
"`O` 토큰으로 범벅이 되있으니 전부 `O`로 찍는 것 보다 좋은 성능이 나와야 하지 않겠어요?\n",
|
|
"한번 시도해봅시다."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 17,
|
|
"id": "80c37a04",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"100%|██████████████████████████████████████████████████████████████████████████| 4250/4250 [00:00<00:00, 283367.38it/s]"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"151572/190488 = 0.7957036663726849\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"total_l = 0\n",
|
|
"total_o = 0\n",
|
|
"\n",
|
|
"for item in tqdm(datasetTrain):\n",
|
|
" entities = item[\"entity\"]\n",
|
|
" l = len(entities)\n",
|
|
" o = sum(map(lambda x: 1 if x == \"O\" else 0,entities))\n",
|
|
" total_l += l\n",
|
|
" total_o += o\n",
|
|
"\n",
|
|
"print(f\"{total_o}/{total_l} = {total_o/total_l}\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "c0b67c41",
|
|
"metadata": {},
|
|
"source": [
|
|
"79% 보다 높아야 해요."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 18,
|
|
"id": "619b959f",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"BATCH_SIZE = 4\n",
|
|
"train_loader = DataLoader(\n",
|
|
" DatasetArray(datasetTrain),\n",
|
|
" batch_size=BATCH_SIZE,\n",
|
|
" shuffle=True,\n",
|
|
" collate_fn=my_collate_fn\n",
|
|
")\n",
|
|
"dev_loader = DataLoader(\n",
|
|
" DatasetArray(datasetDev),\n",
|
|
" batch_size=BATCH_SIZE,\n",
|
|
" shuffle=True,\n",
|
|
" collate_fn=my_collate_fn\n",
|
|
")\n",
|
|
"test_loader = DataLoader(\n",
|
|
" DatasetArray(datasetTest),\n",
|
|
" batch_size=BATCH_SIZE,\n",
|
|
" shuffle=True,\n",
|
|
" collate_fn=my_collate_fn\n",
|
|
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "7d45dd29",
|
|
"metadata": {},
|
|
"source": [
|
|
"BATCH_SIZE 를 4로 잡는다.\n",
|
|
"bert paramter를 freeze 안했을땐 batch를 8 정도로 했어요. 그 이상은 메모리가 부족해서 돌아가지 않아요.\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 19,
|
|
"id": "efd1837a",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"({'input_ids': tensor([[ 101, 39671, 8935, 73380, 30842, 9632, 125, 9998, 9251, 9559,\n",
|
|
" 9294, 8932, 28143, 9952, 8872, 127, 9489, 34907, 9952, 9279,\n",
|
|
" 12424, 102]]),\n",
|
|
" 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]),\n",
|
|
" 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])},\n",
|
|
" tensor([[21, 21, 7, 17, 17, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21,\n",
|
|
" 21, 21, 21, 21]]))"
|
|
]
|
|
},
|
|
"execution_count": 19,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"my_collate_fn(datasetTrain[0:1])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "c4a2e2e1",
|
|
"metadata": {},
|
|
"source": [
|
|
"데이터를 한번 더 확인"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 20,
|
|
"id": "c575cb55",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from torch.optim import AdamW"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "627cb2f8",
|
|
"metadata": {},
|
|
"source": [
|
|
"tqdm 확인"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 21,
|
|
"id": "56844773",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"optimizer = AdamW(model.parameters(), lr=1.0e-5)\n",
|
|
"CELoss = nn.CrossEntropyLoss(ignore_index=tagIdConverter.pad_id)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "eaa08ab2",
|
|
"metadata": {},
|
|
"source": [
|
|
"옵티마이져 준비"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 22,
|
|
"id": "41b62321",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"[1, 2, 3, 4]\n",
|
|
"[5, 6, 7, 8]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"from groupby_index import groupby_index\n",
|
|
"\n",
|
|
"for g in groupby_index([1,2,3,4,5,6,7,8],4):\n",
|
|
" print([*g])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "faf180cb",
|
|
"metadata": {},
|
|
"source": [
|
|
"`groupby_index` 그룹으로 묶어서 실행"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 43,
|
|
"id": "109259b4",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"epoch 0 start:\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Epoch 0: 100%|███████████████████████████████████████| 1063/1063 [00:45<00:00, 23.15batch/s, accuracy=0.923, loss=2.24]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"epoch 1 start:\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Epoch 1: 100%|███████████████████████████████████████| 1063/1063 [00:46<00:00, 23.07batch/s, accuracy=0.961, loss=1.52]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"epoch 2 start:\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Epoch 2: 100%|██████████████████████████████████████| 1063/1063 [00:46<00:00, 23.06batch/s, accuracy=0.976, loss=0.793]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"epoch 3 start:\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Epoch 3: 100%|███████████████████████████████████████| 1063/1063 [00:46<00:00, 23.07batch/s, accuracy=0.935, loss=1.88]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"epoch 4 start:\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Epoch 4: 100%|███████████████████████████████████████| 1063/1063 [00:46<00:00, 23.09batch/s, accuracy=0.98, loss=0.524]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"TRAIN_EPOCH = 5\n",
|
|
"\n",
|
|
"result = []\n",
|
|
"iteration = 0\n",
|
|
"\n",
|
|
"t = []\n",
|
|
"\n",
|
|
"model.zero_grad()\n",
|
|
"\n",
|
|
"for epoch in range(TRAIN_EPOCH):\n",
|
|
" model.train()\n",
|
|
" print(f\"epoch {epoch} start:\")\n",
|
|
" with tqdm(train_loader, unit=\"minibatch\") as tepoch:\n",
|
|
" tepoch.set_description(f\"Epoch {epoch}\")\n",
|
|
" \n",
|
|
" for batch in groupby_index(tepoch,8):\n",
|
|
" corrects = 0\n",
|
|
" totals = 0\n",
|
|
" losses = 0\n",
|
|
" \n",
|
|
" optimizer.zero_grad()\n",
|
|
" for mini_i,mini_l in batch:\n",
|
|
" batch_inputs = {k: v.cuda(device) for k, v in list(mini_i.items())}\n",
|
|
" batch_labels = mini_l.cuda(device)\n",
|
|
" attention_mask = batch_inputs[\"attention_mask\"]\n",
|
|
" \n",
|
|
" output = model(**batch_inputs)\n",
|
|
" loss = CELoss(output.view(-1, output.size(-1)), batch_labels.view(-1))\n",
|
|
" \n",
|
|
" prediction = output.view(-1, output.size(-1)).argmax(dim=-1)\n",
|
|
" corrects += ((prediction == batch_labels.view(-1)) * attention_mask.view(-1)).sum().item()\n",
|
|
" totals += attention_mask.view(-1).sum().item()\n",
|
|
" losses += loss.item()\n",
|
|
" loss.backward()\n",
|
|
"\n",
|
|
" optimizer.step()\n",
|
|
" accuracy = corrects / totals\n",
|
|
" result.append({\"iter\":iteration,\"loss\":losses,\"accuracy\":accuracy})\n",
|
|
" tepoch.set_postfix(loss=losses, accuracy= accuracy)\n",
|
|
" iteration += 1"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 44,
|
|
"id": "f0d9b2d7",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"%matplotlib inline"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 45,
|
|
"id": "19ca6da1",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import numpy as np\n",
|
|
"import matplotlib.pyplot as plt"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "0bee685c",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"iters = [item[\"iter\"] for item in result]\n",
|
|
"fig, ax1 = plt.subplots()\n",
|
|
"ax1.plot(iters,[item[\"loss\"] for item in result],'g')\n",
|
|
"ax2 = ax1.twinx()\n",
|
|
"ax2.plot(iters,[item[\"accuracy\"] for item in result],'r')\n",
|
|
"plt.xlabel(\"iter\")\n",
|
|
"plt.show()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "2bb67740",
|
|
"metadata": {},
|
|
"source": [
|
|
"학습 그래프입니다."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 47,
|
|
"id": "0defca72",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"torch.cuda.empty_cache()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 23,
|
|
"id": "2f45cae0",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"gpu allocated : 678 MB\n",
|
|
"gpu reserved : 1446MB\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"print(f\"gpu allocated : {torch.cuda.memory_allocated() // 1024**2} MB\")\n",
|
|
"print(f\"gpu reserved : {torch.cuda.memory_reserved() // 1024 ** 2}MB\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "68f18f4d",
|
|
"metadata": {},
|
|
"source": [
|
|
"gpu 메모리 사용량을 보는 코드"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 49,
|
|
"id": "73b01630",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"del batch_inputs\n",
|
|
"del batch_labels\n",
|
|
"del loss\n",
|
|
"del optimizer"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "02a3f367",
|
|
"metadata": {},
|
|
"source": [
|
|
"한번 테스트 해보자"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 24,
|
|
"id": "af93a3ec",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"김광현 을 둘러싼 주변 의 우려 는 너무 많이 던진다는 것 .\n",
|
|
"두산 은 주포 인 김동주 와 안경현 이 빠진 상황 이 라 타력 에 적 지 않 은 문제점 을 안 고 있 지만 29 일 잠실 롯데전 이 우천 으로 취소 되 는 바람 에 시간 을 벌 었 고 더구나 삼성 이 4 연패 로 2 위 로 추락 해 어부지리 로 1 위 로 올라서 는 행운 을 잡 았 다 .\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"for data in datasetTrain[100:102]:\n",
|
|
" print(tokenizer.convert_tokens_to_string(data[\"tokens\"]))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 29,
|
|
"id": "9b2cd5c4",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"{'input_ids': tensor([[ 101, 8935, 118649, 30842, 9633, 9105, 30873, 119091, 9689,\n",
|
|
" 118985, 9637, 9604, 26737, 9043, 9004, 32537, 47058, 9076,\n",
|
|
" 65096, 11018, 8870, 119, 102, 0, 0, 0, 0,\n",
|
|
" 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
|
" 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
|
" 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
|
" 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
|
" 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
|
" 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
|
" 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
|
" 0, 0, 0, 0, 0, 0, 0],\n",
|
|
" [ 101, 9102, 21386, 9632, 9689, 55530, 9640, 8935, 18778,\n",
|
|
" 16323, 9590, 9521, 31720, 30842, 9638, 9388, 18623, 9414,\n",
|
|
" 65649, 9638, 9157, 9845, 28143, 9559, 9664, 9706, 9523,\n",
|
|
" 9632, 9297, 17730, 34907, 9633, 9521, 8888, 9647, 9706,\n",
|
|
" 19105, 10386, 9641, 9655, 31503, 9208, 28911, 16617, 9638,\n",
|
|
" 9604, 38631, 29805, 9773, 22333, 9098, 9043, 9318, 61250,\n",
|
|
" 9559, 9485, 18784, 9633, 9339, 9557, 8888, 9074, 17196,\n",
|
|
" 16439, 9410, 17138, 9638, 125, 9568, 119383, 9202, 123,\n",
|
|
" 9619, 9202, 9765, 107693, 9960, 9546, 14646, 12508, 12692,\n",
|
|
" 9202, 122, 9619, 9202, 9583, 17342, 12424, 9043, 9966,\n",
|
|
" 21614, 9633, 9656, 9529, 9056, 119, 102]],\n",
|
|
" device='cuda:0'),\n",
|
|
" 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,\n",
|
|
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
|
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
|
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
|
" 0],\n",
|
|
" [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
|
|
" 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
|
|
" 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
|
|
" 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
|
|
" 1]], device='cuda:0'),\n",
|
|
" 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
|
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
|
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
|
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
|
" 0],\n",
|
|
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
|
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
|
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
|
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
|
" 0]], device='cuda:0')}"
|
|
]
|
|
},
|
|
"execution_count": 29,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"inputs, labels = my_collate_fn(datasetTrain[100:102])\n",
|
|
"inputs = {k:v.to(device) for k,v in inputs.items()}\n",
|
|
"inputs"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 30,
|
|
"id": "6b31782c",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"model.eval()\n",
|
|
"with torch.no_grad():\n",
|
|
" predict_label = model(**inputs)\n",
|
|
" sp = model.softmax(predict_label)\n",
|
|
" p = sp.argmax(dim=-1,keepdim=True).squeeze().cpu()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 31,
|
|
"id": "7f4d43ce",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"['O', 'B-PS', 'I-PS', 'I-PS', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'I-PS', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']\n",
|
|
"['O', 'B-OG', 'I-OG', 'O', 'O', 'O', 'O', 'B-PS', 'I-PS', 'I-PS', 'O', 'B-PS', 'I-PS', 'I-PS', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-DT', 'I-DT', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-OG', 'I-OG', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"for data in p.numpy():\n",
|
|
" print(tagIdConverter.convert_ids_to_tokens(data))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 32,
|
|
"id": "5ade3317",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"['O', 'B-PS', 'I-PS', 'I-PS', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]']\n",
|
|
"['O', 'B-OG', 'I-OG', 'O', 'O', 'O', 'O', 'B-PS', 'I-PS', 'I-PS', 'O', 'B-PS', 'I-PS', 'I-PS', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-DT', 'I-DT', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-OG', 'I-OG', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"for data in labels.squeeze().numpy():\n",
|
|
" print(tagIdConverter.convert_ids_to_tokens(data))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "b48e22ff",
|
|
"metadata": {},
|
|
"source": [
|
|
"It work!"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 33,
|
|
"id": "383dd24a",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"torch.Size([194, 1])"
|
|
]
|
|
},
|
|
"execution_count": 33,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"sp.cpu().view(-1,sp.size(-1)).argmax(dim=-1,keepdim=True).size()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 34,
|
|
"id": "ff74fced",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"tensor(120)"
|
|
]
|
|
},
|
|
"execution_count": 34,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"correct = (sp.cpu().view(-1,sp.size(-1)).argmax(dim=-1) == labels.view(-1)).sum()\n",
|
|
"correct"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 35,
|
|
"id": "3f6ad5d8",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"tensor(120, device='cuda:0')"
|
|
]
|
|
},
|
|
"execution_count": 35,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"inputs[\"attention_mask\"].view(-1).sum()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 36,
|
|
"id": "986fd52b",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"tensor(1., device='cuda:0')"
|
|
]
|
|
},
|
|
"execution_count": 36,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"accuracy = correct / inputs[\"attention_mask\"].view(-1).sum()\n",
|
|
"accuracy"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "8cc47437",
|
|
"metadata": {},
|
|
"source": [
|
|
"accuracy는 다음과 같이 구해져요."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 37,
|
|
"id": "1f3f8666",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"100%|█████████████████████████████████████████████████████████████████████████████| 125/125 [00:01<00:00, 72.97batch/s]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"model.eval()\n",
|
|
"collect_list = []\n",
|
|
"with torch.no_grad():\n",
|
|
" with tqdm(test_loader, unit=\"batch\") as tepoch:\n",
|
|
" for batch_i,batch_l in tepoch:\n",
|
|
" batch_inputs = {k: v.cuda(device) for k, v in list(batch_i.items())}\n",
|
|
" batch_labels = batch_l.cuda(device)\n",
|
|
" output = model(**batch_inputs)\n",
|
|
" loss = CELoss(output.view(-1, output.size(-1)), batch_labels.view(-1))\n",
|
|
" \n",
|
|
" prediction = output.view(-1, output.size(-1)).argmax(dim=-1)\n",
|
|
" correct = (prediction == batch_labels.view(-1)).sum().item()\n",
|
|
" accuracy = correct / batch_inputs[\"attention_mask\"].view(-1).sum()\n",
|
|
" \n",
|
|
" collect_list.append({\"loss\":loss.item(),\"accuracy\":accuracy, \"batch_size\":batch_labels.size(0),\n",
|
|
" \"predict\":output.argmax(dim=-1).cpu(),\n",
|
|
" \"actual\":batch_labels.cpu(),\n",
|
|
" \"attention_mask\":batch_inputs[\"attention_mask\"].cpu()})"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 38,
|
|
"id": "b7567f48",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def getConfusionMatrix(predict,actual,attention_mask):\n",
|
|
" ret = torch.zeros((22,22),dtype=torch.long)\n",
|
|
" for i,(p_s,a_s) in enumerate(zip(predict,actual)):\n",
|
|
" for j,(p,a) in enumerate(zip(p_s,a_s)):\n",
|
|
" ret[p,a] += attention_mask[i,j]\n",
|
|
" return ret"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "a898cd34",
|
|
"metadata": {},
|
|
"source": [
|
|
"단순하게 confusion matrix를 계산하는 함수. 클래스 22개 정해져 있음."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 39,
|
|
"id": "15cd73a5",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"tensor([[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
|
|
" [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
|
|
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
|
|
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
|
|
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
|
|
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
|
|
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
|
|
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
|
|
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
|
|
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
|
|
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
|
|
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
|
|
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
|
|
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
|
|
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
|
|
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
|
|
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
|
|
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
|
|
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
|
|
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
|
|
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
|
|
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])"
|
|
]
|
|
},
|
|
"execution_count": 39,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"getConfusionMatrix(torch.tensor([[0,1]]),torch.tensor([[1,1]]),torch.tensor([[1,1]]))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 40,
|
|
"id": "de9c7932",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"average_loss : 0.16819471586681903, average_accuracy : 0.9595034718513489, size :500\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"total_loss = 0\n",
|
|
"total_accuracy = 0\n",
|
|
"total_size = 0\n",
|
|
"confusion = torch.zeros((22,22),dtype=torch.long)\n",
|
|
"\n",
|
|
"for item in collect_list:\n",
|
|
" batch_size = item[\"batch_size\"]\n",
|
|
" total_loss += batch_size * item[\"loss\"]\n",
|
|
" total_accuracy += batch_size * item[\"accuracy\"]\n",
|
|
" total_size += batch_size\n",
|
|
" confusion += getConfusionMatrix(item[\"predict\"],item[\"actual\"],item[\"attention_mask\"])\n",
|
|
"print(f\"\"\"average_loss : {total_loss/total_size}, average_accuracy : {total_accuracy/total_size}, size :{total_size}\"\"\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "24539b98",
|
|
"metadata": {},
|
|
"source": [
|
|
"test로 보면 결과가 나왔어요. 96% 나와요. F1 스코어는 아직입니다."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 41,
|
|
"id": "f6047991",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"confusion = confusion[0:21][0:21]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "aeed90e3",
|
|
"metadata": {},
|
|
"source": [
|
|
"Outside 토큰에 해당하는 곳을 짜르겠습니다."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 42,
|
|
"id": "000d1e68",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"ename": "NameError",
|
|
"evalue": "name 'plt' is not defined",
|
|
"output_type": "error",
|
|
"traceback": [
|
|
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
|
|
"\u001b[1;31mNameError\u001b[0m Traceback (most recent call last)",
|
|
"\u001b[1;32m~\\AppData\\Local\\Temp/ipykernel_13736/3551859736.py\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m 1\u001b[0m \u001b[1;32mimport\u001b[0m \u001b[0mitertools\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 2\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 3\u001b[1;33m \u001b[0mplt\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtitle\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"confusion matrix\"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 4\u001b[0m \u001b[0mplt\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mimshow\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mconfusion\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mcmap\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;34m'Blues'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 5\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
|
|
"\u001b[1;31mNameError\u001b[0m: name 'plt' is not defined"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"import itertools\n",
|
|
"\n",
|
|
"plt.title(\"confusion matrix\")\n",
|
|
"plt.imshow(confusion,cmap='Blues')\n",
|
|
"\n",
|
|
"plt.colorbar()\n",
|
|
"for i,j in itertools.product(range(confusion.shape[0]),range(confusion.shape[1])):\n",
|
|
" plt.text(j,i,\"{:}\".format(confusion[i,j]),horizontalalignment=\"center\",color=\"black\" if i == j else \"black\")\n",
|
|
"plt.show()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "543ede4b",
|
|
"metadata": {},
|
|
"source": [
|
|
"혼동행렬을 보면 별로인것 같습니다.\n",
|
|
"왼쪽이 predict이고 밑이 actual입니다."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "9e85493d",
|
|
"metadata": {},
|
|
"source": [
|
|
"$$Precision(class = a) = \\frac{TP(class = a)}{TP(class = a)+FP(class = a)}$$\n",
|
|
"\n",
|
|
"$$Recall(class = a) = \\frac{TP(class = a)}{TP(class = a)+FN(class = a)}$$\n",
|
|
"\n",
|
|
"$$F1Score(class = a) = \\frac{2}{\\frac{1}{Precision(class = a)}+\\frac{1}{Recall(class = a)} }$$"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "200c95b9",
|
|
"metadata": {},
|
|
"source": [
|
|
"F1Score는 다음과 같이 주어집니다."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "0b23e7d5",
|
|
"metadata": {},
|
|
"source": [
|
|
"다른 클래스에 대해서도 모두 해보자"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 75,
|
|
"id": "38b3eee6",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def getF1Score(confusion,c):\n",
|
|
" TP = confusion[c,c]\n",
|
|
" FP = confusion[c].sum() - TP\n",
|
|
" FN = confusion[:,c].sum() - TP\n",
|
|
" precision = TP / (TP + FP)\n",
|
|
" recall = TP / (TP + FN)\n",
|
|
"\n",
|
|
" f1Score = (2*precision*recall)/(precision + recall)\n",
|
|
" return f1Score"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 77,
|
|
"id": "61fe2d6c",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"class 0 f1 score : nan\n",
|
|
"class 1 f1 score : nan\n",
|
|
"class 2 f1 score : nan\n",
|
|
"class 3 f1 score : nan\n",
|
|
"class 4 f1 score : 0.9583332538604736\n",
|
|
"class 5 f1 score : 0.9216590523719788\n",
|
|
"class 6 f1 score : 0.9232480525970459\n",
|
|
"class 7 f1 score : 0.9203747510910034\n",
|
|
"class 8 f1 score : 0.8780487179756165\n",
|
|
"class 9 f1 score : nan\n",
|
|
"class 10 f1 score : nan\n",
|
|
"class 11 f1 score : nan\n",
|
|
"class 12 f1 score : nan\n",
|
|
"class 13 f1 score : nan\n",
|
|
"class 14 f1 score : 0.9240506887435913\n",
|
|
"class 15 f1 score : 0.7439999580383301\n",
|
|
"class 16 f1 score : 0.885114848613739\n",
|
|
"class 17 f1 score : 0.9293712377548218\n",
|
|
"class 18 f1 score : 0.932692289352417\n",
|
|
"class 19 f1 score : nan\n",
|
|
"class 20 f1 score : nan\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"for i in range(21):\n",
|
|
" f1 = getF1Score(confusion,i)\n",
|
|
" print(f\"class {i} f1 score : {f1}\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "7b59fc83",
|
|
"metadata": {},
|
|
"source": [
|
|
"nan 나온 것에 대해서 생각해보자."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 79,
|
|
"id": "ba58322f",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import itertools\n",
|
|
"import collections"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 120,
|
|
"id": "a5762fd2",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"5000it [00:00, 90912.13it/s]"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"token \t count frequency%\n",
|
|
"O(21) \t 174832 79.596%\n",
|
|
"I-PS(17) \t 12555 5.716%\n",
|
|
"I-OG(16) \t 10927 4.975%\n",
|
|
"B-PS(7) \t 4726 2.152%\n",
|
|
"I-DT(14) \t 4407 2.006%\n",
|
|
"B-OG(6) \t 3782 1.722%\n",
|
|
"I-LC(15) \t 2365 1.077%\n",
|
|
"B-DT(4) \t 2338 1.064%\n",
|
|
"B-LC(5) \t 2217 1.009%\n",
|
|
"I-TI(18) \t 1030 0.469%\n",
|
|
"B-TI(8) \t 397 0.181%\n",
|
|
"I-목소(19) \t 32 0.015%\n",
|
|
"I-(11) \t 15 0.007%\n",
|
|
"I-조선(20) \t 8 0.004%\n",
|
|
"I-1(12) \t 5 0.002%\n",
|
|
"B-(1) \t 4 0.002%\n",
|
|
"I-<휠(13) \t 4 0.002%\n",
|
|
"B-조선(10) \t 1 0.000%\n",
|
|
"B-목소(9) \t 1 0.000%\n",
|
|
"B-<휠(3) \t 1 0.000%\n",
|
|
"B-1(2) \t 1 0.000%\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"tagIdConverter = TagIdConverter()\n",
|
|
"counter = collections.Counter()\n",
|
|
"total_l = 0\n",
|
|
"\n",
|
|
"for item in tqdm(itertools.chain(datasetTrain,datasetDev,datasetTest)):\n",
|
|
" entities = item[\"entity\"]\n",
|
|
" for entity in entities:\n",
|
|
" counter[entity] += 1\n",
|
|
" total_l += len(entities)\n",
|
|
"print(f\"{'token':<12}\\t{'count':>12} {'frequency%':>12}\")\n",
|
|
"for token,count in counter.most_common():\n",
|
|
" tid = tagIdConverter.convert_tokens_to_ids([token])[0]\n",
|
|
" print(f\"{f'{token}({tid})':<12}\\t{count:>12}{count*100/total_l:>12.3f}%\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "9f068a82",
|
|
"metadata": {},
|
|
"source": [
|
|
"19, 11, 20, 12, 1, 13, 10, 9, 3, 2 번은 데이터 규모에 비해서 유의미한 데이터가 아니다. 샘플이 너무 적어서 학습하기에 부적절하다.\n",
|
|
"그래서 한번도 출현을 안해서 0이 나왔다."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "38fc05df",
|
|
"metadata": {},
|
|
"source": [
|
|
"모델을 저장해보자"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 121,
|
|
"id": "54950483",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"torch.save(model.state_dict(), \"model.zip\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "b8e9c3a4",
|
|
"metadata": {},
|
|
"source": [
|
|
"다음과 같이 하면 저장됨."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 11,
|
|
"id": "ee39908f",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"<All keys matched successfully>"
|
|
]
|
|
},
|
|
"execution_count": 11,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"model.load_state_dict(torch.load(\"model.zip\"))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "37a2ba2d",
|
|
"metadata": {},
|
|
"source": [
|
|
"로딩은 다음과 같이"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "16bd6dff",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.7.11"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|