{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "1cc9cec9", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'C:\\\\Users\\\\Monoid\\\\anaconda3\\\\envs\\\\nn\\\\python.exe'" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import sys\n", "sys.executable" ] }, { "cell_type": "markdown", "id": "f9c786f1", "metadata": {}, "source": [ "Check the Python environment.\n", "The path should end with envs\\\\nn\\\\python.exe." ] }, { "cell_type": "code", "execution_count": 2, "id": "4c0a08d7", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 9.06it/s]\n" ] } ], "source": [ "from time import sleep\n", "from tqdm import tqdm, trange\n", "\n", "lst = [i for i in range(5)]\n", "\n", "for element in tqdm(lst):\n", "    sleep(0.1)" ] }, { "cell_type": "markdown", "id": "56a66a44", "metadata": {}, "source": [ "### Introducing tqdm\n", "tqdm is a library that draws a progress bar like the one above. Wow, so convenient!\n", "```\n", "from tqdm.auto import tqdm\n", "from tqdm.notebook import tqdm\n", "```\n", "These two imports stayed stuck at 0 and did not work for me. I don't know why; one possible cause, not verified here, is that `ipywidgets` (which `tqdm.notebook` relies on) was missing or disabled." 
] }, { "cell_type": "code", "execution_count": 3, "id": "d00b25e6", "metadata": {}, "outputs": [], "source": [ "from preprocessing import readPreporcssedDataAll\n", "import torch\n", "from torch.utils.data import Dataset, DataLoader\n", "from dataset import make_collate_fn, DatasetArray\n", "from transformers import BertTokenizer\n", "import torch.nn as nn\n", "from read_data import TagIdConverter" ] }, { "cell_type": "markdown", "id": "a1d7bb15", "metadata": {}, "source": [ "Import roughly everything we need." ] }, { "cell_type": "code", "execution_count": 4, "id": "433419f4", "metadata": {}, "outputs": [], "source": [ "tagIdConverter = TagIdConverter()" ] }, { "cell_type": "code", "execution_count": 5, "id": "79fb54df", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "cuda available : True\n", "available device count : 1\n", "device name: NVIDIA GeForce RTX 3070\n" ] } ], "source": [ "print(\"cuda available :\",torch.cuda.is_available())\n", "print(\"available device count :\",torch.cuda.device_count())\n", "\n", "if torch.cuda.is_available():\n", "    device_index = torch.cuda.current_device()\n", "    print(\"device name:\",torch.cuda.get_device_name(device_index))" ] }, { "cell_type": "markdown", "id": "dd22dd5e", "metadata": {}, "source": [ "First, let's check whether CUDA is available." 
] }, { "cell_type": "code", "execution_count": 7, "id": "2cd9fe37", "metadata": {}, "outputs": [], "source": [ "PRETAINED_MODEL_NAME = 'bert-base-multilingual-cased'\n", "tokenizer = BertTokenizer.from_pretrained(PRETAINED_MODEL_NAME)\n", "\n", "my_collate_fn = make_collate_fn(tokenizer, tagIdConverter)" ] }, { "cell_type": "markdown", "id": "2177c793", "metadata": {}, "source": [ "Prepare for data loading." ] }, { "cell_type": "code", "execution_count": 7, "id": "e738062d", "metadata": {}, "outputs": [], "source": [ "from transformers import BertModel" ] }, { "cell_type": "code", "execution_count": 8, "id": "70296ee3", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Some weights of the model checkpoint at bert-base-multilingual-cased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias']\n", "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. 
initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n" ] } ], "source": [ "PRETAINED_MODEL_NAME = 'bert-base-multilingual-cased'\n", "bert = BertModel.from_pretrained(PRETAINED_MODEL_NAME)" ] }, { "cell_type": "code", "execution_count": 9, "id": "6cbc236d", "metadata": {}, "outputs": [], "source": [ "class MyModel(nn.Module):\n", "    def __init__(self,output_feat: int,bert):\n", "        super().__init__()\n", "        self.bert = bert\n", "        self.dropout = nn.Dropout(p=0.1)\n", "        self.lin = nn.Linear(768,output_feat) #[batch_size,word_size,768] -> [batch_size,word_size,output_feat]\n", "        self.softmax = nn.Softmax(2) #[batch_size,word_size,output_feat] -> [batch_size,word_size,output_feat]\n", "        # softmax over dimension 2, counting from 0 (the tag dimension).\n", "\n", "    def forward(self,**kargs):\n", "        emb = self.bert(**kargs)\n", "        e = self.dropout(emb['last_hidden_state'])\n", "        w = self.lin(e)\n", "        return w" ] }, { "cell_type": "markdown", "id": "7dec337a", "metadata": {}, "source": [ "Model declaration.\n", "`nn.CrossEntropyLoss` applies log-softmax itself, so `forward` returns raw logits and `self.softmax` is only needed at prediction time." 
] }, { "cell_type": "code", "execution_count": 10, "id": "e5214de8", "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "MyModel(\n", " (bert): BertModel(\n", " (embeddings): BertEmbeddings(\n", " (word_embeddings): Embedding(119547, 768, padding_idx=0)\n", " (position_embeddings): Embedding(512, 768)\n", " (token_type_embeddings): Embedding(2, 768)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (encoder): BertEncoder(\n", " (layer): ModuleList(\n", " (0): BertLayer(\n", " (attention): BertAttention(\n", " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " )\n", " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (1): BertLayer(\n", " (attention): BertAttention(\n", " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", " (LayerNorm): 
LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " )\n", " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (2): BertLayer(\n", " (attention): BertAttention(\n", " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " )\n", " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (3): BertLayer(\n", " (attention): BertAttention(\n", " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " 
(intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " )\n", " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (4): BertLayer(\n", " (attention): BertAttention(\n", " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " )\n", " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (5): BertLayer(\n", " (attention): BertAttention(\n", " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " )\n", " (output): 
BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (6): BertLayer(\n", " (attention): BertAttention(\n", " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " )\n", " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (7): BertLayer(\n", " (attention): BertAttention(\n", " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " )\n", " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, 
elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (8): BertLayer(\n", " (attention): BertAttention(\n", " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " )\n", " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (9): BertLayer(\n", " (attention): BertAttention(\n", " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " )\n", " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (10): BertLayer(\n", " (attention): 
BertAttention(\n", " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " )\n", " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (11): BertLayer(\n", " (attention): BertAttention(\n", " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " )\n", " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " )\n", " )\n", " (pooler): BertPooler(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", " (activation): Tanh()\n", " )\n", " )\n", " (dropout): 
Dropout(p=0.1, inplace=False)\n", " (lin): Linear(in_features=768, out_features=22, bias=True)\n", " (softmax): Softmax(dim=2)\n", ")\n" ] } ], "source": [ "model = MyModel(22,bert)\n", "model.cuda()\n", "bert.cuda()\n", "print(model)" ] }, { "cell_type": "markdown", "id": "72df7ac3", "metadata": {}, "source": [ "There are 22 kinds of tags, so I passed 22." ] }, { "cell_type": "markdown", "id": "0bba477c", "metadata": {}, "source": [ "Move the model to the GPU right after creating it.\n", "Calling `cuda()` is all it takes." ] }, { "cell_type": "code", "execution_count": 11, "id": "5aa129e3", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "bert current device : cuda:0\n" ] } ], "source": [ "print(\"bert current device :\",bert.device)" ] }, { "cell_type": "code", "execution_count": 12, "id": "61d356b6", "metadata": {}, "outputs": [], "source": [ "#for param in bert.parameters():\n", "#    param.requires_grad = False" ] }, { "cell_type": "markdown", "id": "cf6690b2", "metadata": {}, "source": [ "Freezing BERT so that it is not updated saves memory. As saved, the two lines above are commented out, so BERT is updated as well." 
] }, { "cell_type": "code", "execution_count": 26, "id": "f28a1a61", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "device(type='cuda')" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "device = torch.device(\"cuda\")\n", "device" ] }, { "cell_type": "code", "execution_count": 28, "id": "8844beef", "metadata": {}, "outputs": [], "source": [ "inputs = {'input_ids': torch.tensor([[ 101, 39671, 8935, 73380, 30842, 9632, 125, 9998, 9251, 9559,\n", " 9294, 8932, 28143, 9952, 8872, 127, 9489, 34907, 9952, 9279,\n", " 12424, 102]],device=device), \n", " 'token_type_ids': torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],device=device), \n", " 'attention_mask': torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],device=device)}" ] }, { "cell_type": "markdown", "id": "4046945a", "metadata": {}, "source": [ "Define a suitable sample input." ] }, { "cell_type": "code", "execution_count": 29, "id": "c37c3c1b", "metadata": {}, "outputs": [], "source": [ "emb = model(**inputs)" ] }, { "cell_type": "code", "execution_count": 30, "id": "261d4cc7", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([1, 22, 22])" ] }, "execution_count": 30, "metadata": {}, "output_type": "execute_result" } ], "source": [ "emb.size()" ] }, { "cell_type": "markdown", "id": "d492e37b", "metadata": {}, "source": [ "The result came out fine." ] }, { "cell_type": "code", "execution_count": 31, "id": "c773fdba", "metadata": {}, "outputs": [], "source": [ "entity_ids = torch.tensor([21,21,7,17,17,21,21,21,21,21,21,21,21,21,21,21,21,21,21,21,21,21]\n", " ,dtype=torch.int64,device=device)" ] }, { "cell_type": "markdown", "id": "b0b69822", "metadata": {}, "source": [ "I don't fully understand why, but this runs with int64 and fails with int32.\n", "```\n", "RuntimeError: \"nll_loss_forward_reduce_cuda_kernel_2d_index\" not implemented for 'Int'\n", "```\n", "It dies with the error above. This is consistent with `nn.CrossEntropyLoss` expecting class-index targets of dtype `torch.long` (int64)." 
] }, { "cell_type": "code", "execution_count": 32, "id": "0f97b71a", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([22, 22])" ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "source": [ "emb.view(-1,emb.size(-1)).size()" ] }, { "cell_type": "code", "execution_count": 36, "id": "e18b33a5", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([19, 13, 21, 21, 19, 21, 13, 13, 16, 16, 21, 16, 19, 13, 2, 13, 13, 16,\n", " 13, 20, 19, 19], device='cuda:0')" ] }, "execution_count": 36, "metadata": {}, "output_type": "execute_result" } ], "source": [ "predict = emb.view(-1,emb.size(-1)).argmax(dim=-1)\n", "predict" ] }, { "cell_type": "code", "execution_count": 39, "id": "3312e199", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([False, False, False, False, False, True, False, False, False, False,\n", " True, False, False, False, False, False, False, False, False, False,\n", " False, False], device='cuda:0')" ] }, "execution_count": 39, "metadata": {}, "output_type": "execute_result" } ], "source": [ "(predict == entity_ids)" ] }, { "cell_type": "code", "execution_count": 38, "id": "97270036", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],\n", " device='cuda:0')" ] }, "execution_count": 38, "metadata": {}, "output_type": "execute_result" } ], "source": [ "(predict == entity_ids) * inputs[\"attention_mask\"]" ] }, { "cell_type": "code", "execution_count": 19, "id": "d7d0164a", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(2.9564, device='cuda:0', grad_fn=)" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "nn.CrossEntropyLoss(ignore_index=tagIdConverter.pad_id)(emb.view(-1,emb.size(-1)),entity_ids)" ] }, { "cell_type": "markdown", "id": "9fee44d6", "metadata": {}, "source": [ "Computing the cross entropy worked.\n", "`ignore_index` is set to the id of the padding class, so padded positions do not contribute to the loss." 
] }, { "cell_type": "code", "execution_count": 40, "id": "197cfa62", "metadata": {}, "outputs": [], "source": [ "del inputs\n", "del entity_ids\n", "del emb" ] }, { "cell_type": "markdown", "id": "733cb548", "metadata": {}, "source": [ "Now let's train in earnest." ] }, { "cell_type": "code", "execution_count": 16, "id": "40a3e52c", "metadata": {}, "outputs": [], "source": [ "datasetTrain, datasetDev, datasetTest = readPreporcssedDataAll()" ] }, { "cell_type": "markdown", "id": "66d41fee", "metadata": {}, "source": [ "Let's think about the minimum performance we should expect on this dataset.\n", "It is dominated by `O` tags, so the model should at least beat labelling everything as `O`, right?\n", "Let's measure that baseline." ] }, { "cell_type": "code", "execution_count": 17, "id": "80c37a04", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████████████████████████████████████████████████████████████████████| 4250/4250 [00:00<00:00, 283367.38it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "151572/190488 = 0.7957036663726849\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "total_l = 0\n", "total_o = 0\n", "\n", "for item in tqdm(datasetTrain):\n", "    entities = item[\"entity\"]\n", "    l = len(entities)\n", "    o = sum(map(lambda x: 1 if x == \"O\" else 0,entities))\n", "    total_l += l\n", "    total_o += o\n", "\n", "print(f\"{total_o}/{total_l} = {total_o/total_l}\")" ] }, { "cell_type": "markdown", "id": "c0b67c41", "metadata": {}, "source": [ "So accuracy has to be higher than about 79.6%." 
] }, { "cell_type": "code", "execution_count": 18, "id": "619b959f", "metadata": {}, "outputs": [], "source": [ "BATCH_SIZE = 4\n", "train_loader = DataLoader(\n", "    DatasetArray(datasetTrain),\n", "    batch_size=BATCH_SIZE,\n", "    shuffle=True,\n", "    collate_fn=my_collate_fn\n", ")\n", "dev_loader = DataLoader(\n", "    DatasetArray(datasetDev),\n", "    batch_size=BATCH_SIZE,\n", "    shuffle=True,\n", "    collate_fn=my_collate_fn\n", ")\n", "test_loader = DataLoader(\n", "    DatasetArray(datasetTest),\n", "    batch_size=BATCH_SIZE,\n", "    shuffle=True,\n", "    collate_fn=my_collate_fn\n", ")" ] }, { "cell_type": "markdown", "id": "7d45dd29", "metadata": {}, "source": [ "Set BATCH_SIZE to 4.\n", "When the BERT parameters were not frozen, I used a batch size of about 8. Anything larger ran out of memory.\n" ] }, { "cell_type": "code", "execution_count": 19, "id": "efd1837a", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "({'input_ids': tensor([[ 101, 39671, 8935, 73380, 30842, 9632, 125, 9998, 9251, 9559,\n", " 9294, 8932, 28143, 9952, 8872, 127, 9489, 34907, 9952, 9279,\n", " 12424, 102]]),\n", " 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]),\n", " 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])},\n", " tensor([[21, 21, 7, 17, 17, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21,\n", " 21, 21, 21, 21]]))" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "my_collate_fn(datasetTrain[0:1])" ] }, { "cell_type": "markdown", "id": "c4a2e2e1", "metadata": {}, "source": [ "Check the data once more." ] }, { "cell_type": "code", "execution_count": 20, "id": "c575cb55", "metadata": {}, "outputs": [], "source": [ "from torch.optim import AdamW" ] }, { "cell_type": "markdown", "id": "627cb2f8", "metadata": {}, "source": [ "Import the AdamW optimizer." ] }, { "cell_type": "code", "execution_count": 21, "id": "56844773", "metadata": {}, "outputs": [], "source": [ "optimizer = AdamW(model.parameters(), lr=1.0e-5)\n", 
"CELoss = nn.CrossEntropyLoss(ignore_index=tagIdConverter.pad_id)" ] }, { "cell_type": "markdown", "id": "eaa08ab2", "metadata": {}, "source": [ "Prepare the optimizer and the loss function." ] }, { "cell_type": "code", "execution_count": 22, "id": "41b62321", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[1, 2, 3, 4]\n", "[5, 6, 7, 8]\n" ] } ], "source": [ "from groupby_index import groupby_index\n", "\n", "for g in groupby_index([1,2,3,4,5,6,7,8],4):\n", "    print([*g])" ] }, { "cell_type": "markdown", "id": "faf180cb", "metadata": {}, "source": [ "`groupby_index` bundles consecutive items into groups of the given size, so they can be processed one group at a time." ] }, { "cell_type": "code", "execution_count": 43, "id": "109259b4", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch 0 start:\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Epoch 0: 100%|███████████████████████████████████████| 1063/1063 [00:45<00:00, 23.15batch/s, accuracy=0.923, loss=2.24]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "epoch 1 start:\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Epoch 1: 100%|███████████████████████████████████████| 1063/1063 [00:46<00:00, 23.07batch/s, accuracy=0.961, loss=1.52]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "epoch 2 start:\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Epoch 2: 100%|██████████████████████████████████████| 1063/1063 [00:46<00:00, 23.06batch/s, accuracy=0.976, loss=0.793]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "epoch 3 start:\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Epoch 3: 100%|███████████████████████████████████████| 1063/1063 [00:46<00:00, 23.07batch/s, accuracy=0.935, loss=1.88]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "epoch 4 start:\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Epoch 4: 100%|███████████████████████████████████████| 1063/1063 [00:46<00:00, 23.09batch/s, accuracy=0.98, loss=0.524]\n" ] } 
], "source": [ "TRAIN_EPOCH = 5\n", "\n", "result = []\n", "iteration = 0\n", "\n", "t = []\n", "\n", "model.zero_grad()\n", "\n", "for epoch in range(TRAIN_EPOCH):\n", "    model.train()\n", "    print(f\"epoch {epoch} start:\")\n", "    with tqdm(train_loader, unit=\"minibatch\") as tepoch:\n", "        tepoch.set_description(f\"Epoch {epoch}\")\n", "\n", "        # Accumulate gradients over groups of 8 minibatches before each optimizer step.\n", "        for batch in groupby_index(tepoch,8):\n", "            corrects = 0\n", "            totals = 0\n", "            losses = 0\n", "\n", "            optimizer.zero_grad()\n", "            for mini_i,mini_l in batch:\n", "                batch_inputs = {k: v.cuda(device) for k, v in list(mini_i.items())}\n", "                batch_labels = mini_l.cuda(device)\n", "                attention_mask = batch_inputs[\"attention_mask\"]\n", "\n", "                output = model(**batch_inputs)\n", "                loss = CELoss(output.view(-1, output.size(-1)), batch_labels.view(-1))\n", "\n", "                # Count correct predictions on non-padding positions only.\n", "                prediction = output.view(-1, output.size(-1)).argmax(dim=-1)\n", "                corrects += ((prediction == batch_labels.view(-1)) * attention_mask.view(-1)).sum().item()\n", "                totals += attention_mask.view(-1).sum().item()\n", "                losses += loss.item()\n", "                loss.backward()\n", "\n", "            optimizer.step()\n", "            accuracy = corrects / totals\n", "            result.append({\"iter\":iteration,\"loss\":losses,\"accuracy\":accuracy})\n", "            tepoch.set_postfix(loss=losses, accuracy= accuracy)\n", "            iteration += 1" ] }, { "cell_type": "code", "execution_count": 44, "id": "f0d9b2d7", "metadata": {}, "outputs": [], "source": [ "%matplotlib inline" ] }, { "cell_type": "code", "execution_count": 45, "id": "19ca6da1", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt" ] }, { "cell_type": "code", "execution_count": null, "id": "0bee685c", "metadata": {}, "outputs": [], "source": [ "iters = [item[\"iter\"] for item in result]\n", "fig, ax1 = plt.subplots()\n", "ax1.plot(iters,[item[\"loss\"] for item in result],'g')\n", "ax2 = ax1.twinx()\n", "ax2.plot(iters,[item[\"accuracy\"] for item in result],'r')\n", "plt.xlabel(\"iter\")\n", "plt.show()" ] }, { 
"cell_type": "markdown", "id": "2bb67740", "metadata": {}, "source": [ "This is the training curve: loss in green, accuracy in red." ] }, { "cell_type": "code", "execution_count": 47, "id": "0defca72", "metadata": {}, "outputs": [], "source": [ "torch.cuda.empty_cache()" ] }, { "cell_type": "code", "execution_count": 23, "id": "2f45cae0", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "gpu allocated : 678 MB\n", "gpu reserved : 1446MB\n" ] } ], "source": [ "print(f\"gpu allocated : {torch.cuda.memory_allocated() // 1024**2} MB\")\n", "print(f\"gpu reserved : {torch.cuda.memory_reserved() // 1024 ** 2}MB\")" ] }, { "cell_type": "markdown", "id": "68f18f4d", "metadata": {}, "source": [ "Code for checking GPU memory usage." ] }, { "cell_type": "code", "execution_count": 49, "id": "73b01630", "metadata": {}, "outputs": [], "source": [ "del batch_inputs\n", "del batch_labels\n", "del loss\n", "del optimizer" ] }, { "cell_type": "markdown", "id": "02a3f367", "metadata": {}, "source": [ "Let's test it." ] }, { "cell_type": "code", "execution_count": 24, "id": "af93a3ec", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "김광현 을 둘러싼 주변 의 우려 는 너무 많이 던진다는 것 .\n", "두산 은 주포 인 김동주 와 안경현 이 빠진 상황 이 라 타력 에 적 지 않 은 문제점 을 안 고 있 지만 29 일 잠실 롯데전 이 우천 으로 취소 되 는 바람 에 시간 을 벌 었 고 더구나 삼성 이 4 연패 로 2 위 로 추락 해 어부지리 로 1 위 로 올라서 는 행운 을 잡 았 다 .\n" ] } ], "source": [ "for data in datasetTrain[100:102]:\n", "    print(tokenizer.convert_tokens_to_string(data[\"tokens\"]))" ] }, { "cell_type": "code", "execution_count": 29, "id": "9b2cd5c4", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'input_ids': tensor([[ 101, 8935, 118649, 30842, 9633, 9105, 30873, 119091, 9689,\n", " 118985, 9637, 9604, 26737, 9043, 9004, 32537, 47058, 9076,\n", " 65096, 11018, 8870, 119, 102, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 
0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0],\n", " [ 101, 9102, 21386, 9632, 9689, 55530, 9640, 8935, 18778,\n", " 16323, 9590, 9521, 31720, 30842, 9638, 9388, 18623, 9414,\n", " 65649, 9638, 9157, 9845, 28143, 9559, 9664, 9706, 9523,\n", " 9632, 9297, 17730, 34907, 9633, 9521, 8888, 9647, 9706,\n", " 19105, 10386, 9641, 9655, 31503, 9208, 28911, 16617, 9638,\n", " 9604, 38631, 29805, 9773, 22333, 9098, 9043, 9318, 61250,\n", " 9559, 9485, 18784, 9633, 9339, 9557, 8888, 9074, 17196,\n", " 16439, 9410, 17138, 9638, 125, 9568, 119383, 9202, 123,\n", " 9619, 9202, 9765, 107693, 9960, 9546, 14646, 12508, 12692,\n", " 9202, 122, 9619, 9202, 9583, 17342, 12424, 9043, 9966,\n", " 21614, 9633, 9656, 9529, 9056, 119, 102]],\n", " device='cuda:0'),\n", " 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0],\n", " [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1]], device='cuda:0'),\n", " 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0],\n", " [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0]], device='cuda:0')}" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "inputs, labels = my_collate_fn(datasetTrain[100:102])\n", "inputs = {k:v.to(device) for k,v in inputs.items()}\n", "inputs" ] }, { "cell_type": "code", "execution_count": 30, "id": "6b31782c", "metadata": {}, "outputs": [], "source": [ "model.eval()\n", "with torch.no_grad():\n", " predict_label = model(**inputs)\n", " sp = model.softmax(predict_label)\n", " p = sp.argmax(dim=-1,keepdim=True).squeeze().cpu()" ] }, { "cell_type": "code", "execution_count": 31, "id": "7f4d43ce", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['O', 'B-PS', 'I-PS', 'I-PS', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'I-PS', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']\n", "['O', 'B-OG', 'I-OG', 'O', 'O', 'O', 'O', 'B-PS', 'I-PS', 'I-PS', 'O', 'B-PS', 'I-PS', 'I-PS', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-DT', 'I-DT', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-OG', 'I-OG', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']\n" ] } ], "source": [ "for data in p.numpy():\n", " print(tagIdConverter.convert_ids_to_tokens(data))" ] }, { "cell_type": "code", "execution_count": 32, "id": "5ade3317", 
"metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['O', 'B-PS', 'I-PS', 'I-PS', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]']\n", "['O', 'B-OG', 'I-OG', 'O', 'O', 'O', 'O', 'B-PS', 'I-PS', 'I-PS', 'O', 'B-PS', 'I-PS', 'I-PS', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-DT', 'I-DT', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-OG', 'I-OG', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']\n" ] } ], "source": [ "for data in labels.squeeze().numpy():\n", " print(tagIdConverter.convert_ids_to_tokens(data))" ] }, { "cell_type": "markdown", "id": "b48e22ff", "metadata": {}, "source": [ "It work!" 
] }, { "cell_type": "code", "execution_count": 33, "id": "383dd24a", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([194, 1])" ] }, "execution_count": 33, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sp.cpu().view(-1,sp.size(-1)).argmax(dim=-1,keepdim=True).size()" ] }, { "cell_type": "code", "execution_count": 34, "id": "ff74fced", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(120)" ] }, "execution_count": 34, "metadata": {}, "output_type": "execute_result" } ], "source": [ "correct = (sp.cpu().view(-1,sp.size(-1)).argmax(dim=-1) == labels.view(-1)).sum()\n", "correct" ] }, { "cell_type": "code", "execution_count": 35, "id": "3f6ad5d8", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(120, device='cuda:0')" ] }, "execution_count": 35, "metadata": {}, "output_type": "execute_result" } ], "source": [ "inputs[\"attention_mask\"].view(-1).sum()" ] }, { "cell_type": "code", "execution_count": 36, "id": "986fd52b", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(1., device='cuda:0')" ] }, "execution_count": 36, "metadata": {}, "output_type": "execute_result" } ], "source": [ "accuracy = correct / inputs[\"attention_mask\"].view(-1).sum()\n", "accuracy" ] }, { "cell_type": "markdown", "id": "8cc47437", "metadata": {}, "source": [ "Accuracy is computed as shown above: the number of correctly predicted tokens divided by the number of non-padding tokens (the sum of `attention_mask`).\n", "Note that this only holds as long as the model never predicts `[PAD]` at a padded position, since such matches would also be counted as correct." 
] }, { "cell_type": "code", "execution_count": 37, "id": "1f3f8666", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|█████████████████████████████████████████████████████████████████████████████| 125/125 [00:01<00:00, 72.97batch/s]\n" ] } ], "source": [ "model.eval()\n", "collect_list = []\n", "with torch.no_grad():\n", "    with tqdm(test_loader, unit=\"batch\") as tepoch:\n", "        for batch_i,batch_l in tepoch:\n", "            # Move the batch to the GPU\n", "            batch_inputs = {k: v.cuda(device) for k, v in list(batch_i.items())}\n", "            batch_labels = batch_l.cuda(device)\n", "            output = model(**batch_inputs)\n", "            loss = CELoss(output.view(-1, output.size(-1)), batch_labels.view(-1))\n", "            \n", "            # Token-level accuracy over the non-padding positions\n", "            prediction = output.view(-1, output.size(-1)).argmax(dim=-1)\n", "            correct = (prediction == batch_labels.view(-1)).sum().item()\n", "            accuracy = correct / batch_inputs[\"attention_mask\"].view(-1).sum()\n", "            \n", "            # Keep the per-batch results on the CPU for later aggregation\n", "            collect_list.append({\"loss\":loss.item(),\"accuracy\":accuracy, \"batch_size\":batch_labels.size(0),\n", "                                 \"predict\":output.argmax(dim=-1).cpu(),\n", "                                 \"actual\":batch_labels.cpu(),\n", "                                 \"attention_mask\":batch_inputs[\"attention_mask\"].cpu()})" ] }, { "cell_type": "code", "execution_count": 38, "id": "b7567f48", "metadata": {}, "outputs": [], "source": [ "def getConfusionMatrix(predict,actual,attention_mask):\n", "    # ret[p, a] counts tokens predicted as class p whose actual class is a\n", "    ret = torch.zeros((22,22),dtype=torch.long)\n", "    for i,(p_s,a_s) in enumerate(zip(predict,actual)):\n", "        for j,(p,a) in enumerate(zip(p_s,a_s)):\n", "            # attention_mask is 0 at padded positions, so padding is not counted\n", "            ret[p,a] += attention_mask[i,j]\n", "    return ret" ] }, { "cell_type": "markdown", "id": "a898cd34", "metadata": {}, "source": [ "A simple function that computes the confusion matrix. The number of classes is fixed at 22." 
] }, { "cell_type": "code", "execution_count": 39, "id": "15cd73a5", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", " [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", " [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", " [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", " [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", " [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", " [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", " [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", " [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", " [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", " [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", " [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", " [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", " [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", " [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", " [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", " [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", " [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", " [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", " [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", " [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", " [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])" ] }, "execution_count": 39, "metadata": {}, "output_type": "execute_result" } ], "source": [ "getConfusionMatrix(torch.tensor([[0,1]]),torch.tensor([[1,1]]),torch.tensor([[1,1]]))" ] }, { "cell_type": "code", "execution_count": 40, "id":
"de9c7932", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "average_loss : 0.16819471586681903, average_accuracy : 0.9595034718513489, size :500\n" ] } ], "source": [ "total_loss = 0\n", "total_accuracy = 0\n", "total_size = 0\n", "confusion = torch.zeros((22,22),dtype=torch.long)\n", "\n", "# Average the per-batch results, weighted by batch size\n", "for item in collect_list:\n", "    batch_size = item[\"batch_size\"]\n", "    total_loss += batch_size * item[\"loss\"]\n", "    total_accuracy += batch_size * item[\"accuracy\"]\n", "    total_size += batch_size\n", "    confusion += getConfusionMatrix(item[\"predict\"],item[\"actual\"],item[\"attention_mask\"])\n", "print(f\"\"\"average_loss : {total_loss/total_size}, average_accuracy : {total_accuracy/total_size}, size :{total_size}\"\"\")" ] }, { "cell_type": "markdown", "id": "24539b98", "metadata": {}, "source": [ "On the test set, accuracy comes out to about 96%. The F1 score has not been computed yet." ] }, { "cell_type": "code", "execution_count": 41, "id": "f6047991", "metadata": {}, "outputs": [], "source": [ "# Slice rows and columns together; confusion[0:21][0:21] would only drop the last row\n", "confusion = confusion[0:21, 0:21]" ] }, { "cell_type": "markdown", "id": "aeed90e3", "metadata": {}, "source": [ "Cut off the row and column that correspond to the Outside (O) token, which has id 21." 
] }, { "cell_type": "code", "execution_count": null, "id": "000d1e68", "metadata": {}, "outputs": [], "source": [ "import itertools\n", "\n", "import matplotlib.pyplot as plt\n", "\n", "plt.title(\"confusion matrix\")\n", "plt.imshow(confusion,cmap='Blues')\n", "\n", "plt.colorbar()\n", "# Write each count into its cell\n", "for i,j in itertools.product(range(confusion.shape[0]),range(confusion.shape[1])):\n", "    plt.text(j,i,\"{:}\".format(confusion[i,j].item()),horizontalalignment=\"center\",color=\"black\")\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "543ede4b", "metadata": {}, "source": [ "Judging from the confusion matrix, the result does not look great.\n", "The vertical axis (left) is the prediction and the horizontal axis (bottom) is the actual label." 
] }, { "cell_type": "markdown", "id": "9e85493d", "metadata": {}, "source": [ "$$Precision(class = a) = \\frac{TP(class = a)}{TP(class = a)+FP(class = a)}$$\n", "\n", "$$Recall(class = a) = \\frac{TP(class = a)}{TP(class = a)+FN(class = a)}$$\n", "\n", "$$F1Score(class = a) = \\frac{2}{\\frac{1}{Precision(class = a)}+\\frac{1}{Recall(class = a)} }$$" ] }, { "cell_type": "markdown", "id": "200c95b9", "metadata": {}, "source": [ "The F1 score is defined as above." ] }, { "cell_type": "markdown", "id": "0b23e7d5", "metadata": {}, "source": [ "Let's compute it for every class." ] }, { "cell_type": "code", "execution_count": 75, "id": "38b3eee6", "metadata": {}, "outputs": [], "source": [ "def getF1Score(confusion,c):\n", "    # Rows are predictions and columns are actual labels\n", "    TP = confusion[c,c]\n", "    FP = confusion[c].sum() - TP\n", "    FN = confusion[:,c].sum() - TP\n", "    # A class with no predictions or no samples gives 0/0 = nan\n", "    precision = TP / (TP + FP)\n", "    recall = TP / (TP + FN)\n", "\n", "    f1Score = (2*precision*recall)/(precision + recall)\n", "    return f1Score" ] }, { "cell_type": "code", "execution_count": 77, "id": "61fe2d6c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "class 0 f1 score : nan\n", "class 1 f1 score : nan\n", "class 2 f1 score : nan\n", "class 3 f1 score : nan\n", "class 4 f1 score : 0.9583332538604736\n", "class 5 f1 score : 0.9216590523719788\n", "class 6 f1 score : 0.9232480525970459\n", "class 7 f1 score : 0.9203747510910034\n", "class 8 f1 score : 0.8780487179756165\n", "class 9 f1 score : nan\n", "class 10 f1 score : nan\n", "class 11 f1 score : nan\n", "class 12 f1 score : nan\n", "class 13 f1 score : nan\n", "class 14 f1 score : 0.9240506887435913\n", "class 15 f1 score : 0.7439999580383301\n", "class 16 f1 score : 0.885114848613739\n", "class 17 f1 score : 0.9293712377548218\n", "class 18 f1 score : 0.932692289352417\n", "class 19 f1 score : nan\n", "class 20 f1 score : nan\n" ] } ], "source": [ "for i in range(21):\n", "    f1 = getF1Score(confusion,i)\n", "    print(f\"class {i} f1 score : {f1}\")" ] }, { "cell_type": "markdown",
"id": "7b59fc83", "metadata": {}, "source": [ "Let's think about the classes that came out as nan." ] }, { "cell_type": "code", "execution_count": 79, "id": "ba58322f", "metadata": {}, "outputs": [], "source": [ "import itertools\n", "import collections" ] }, { "cell_type": "code", "execution_count": 120, "id": "a5762fd2", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "5000it [00:00, 90912.13it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "token \t count frequency%\n", "O(21) \t 174832 79.596%\n", "I-PS(17) \t 12555 5.716%\n", "I-OG(16) \t 10927 4.975%\n", "B-PS(7) \t 4726 2.152%\n", "I-DT(14) \t 4407 2.006%\n", "B-OG(6) \t 3782 1.722%\n", "I-LC(15) \t 2365 1.077%\n", "B-DT(4) \t 2338 1.064%\n", "B-LC(5) \t 2217 1.009%\n", "I-TI(18) \t 1030 0.469%\n", "B-TI(8) \t 397 0.181%\n", "I-목소(19) \t 32 0.015%\n", "I-(11) \t 15 0.007%\n", "I-조선(20) \t 8 0.004%\n", "I-1(12) \t 5 0.002%\n", "B-(1) \t 4 0.002%\n", "I-<휠(13) \t 4 0.002%\n", "B-조선(10) \t 1 0.000%\n", "B-목소(9) \t 1 0.000%\n", "B-<휠(3) \t 1 0.000%\n", "B-1(2) \t 1 0.000%\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "# Count how often each tag occurs across the train, dev and test sets\n", "counter = collections.Counter()\n", "total_l = 0\n", "\n", "for item in tqdm(itertools.chain(datasetTrain,datasetDev,datasetTest)):\n", "    entities = item[\"entity\"]\n", "    for entity in entities:\n", "        counter[entity] += 1\n", "    total_l += len(entities)\n", "print(f\"{'token':<12}\\t{'count':>12} {'frequency%':>12}\")\n", "for token,count in counter.most_common():\n", "    tid = tagIdConverter.convert_tokens_to_ids([token])[0]\n", "    print(f\"{f'{token}({tid})':<12}\\t{count:>12}{count*100/total_l:>12.3f}%\")" ] }, { "cell_type": "markdown", "id": "9f068a82", "metadata": {}, "source": [ "Classes 19, 11, 20, 12, 1, 13, 10, 9, 3, and 2 are not meaningful relative to the size of the dataset. They have too few samples to be learned properly.\n", "As a result they never show up in the confusion matrix, so their counts are all 0 and the F1 score comes out as nan (0/0)." 
] }, { "cell_type": "markdown", "id": "38fc05df", "metadata": {}, "source": [ "Let's save the model." ] }, { "cell_type": "code", "execution_count": 121, "id": "54950483", "metadata": {}, "outputs": [], "source": [ "torch.save(model.state_dict(), \"model.zip\")" ] }, { "cell_type": "markdown", "id": "b8e9c3a4", "metadata": {}, "source": [ "The model's state_dict is saved as shown above." ] }, { "cell_type": "code", "execution_count": 11, "id": "ee39908f", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.load_state_dict(torch.load(\"model.zip\"))" ] }, { "cell_type": "markdown", "id": "37a2ba2d", "metadata": {}, "source": [ "Loading works as shown above." ] }, { "cell_type": "code", "execution_count": null, "id": "16bd6dff", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.11" } }, "nbformat": 4, "nbformat_minor": 5 }