nsmc-study/Training.ipynb

1749 lines
107 KiB
Plaintext
Raw Normal View History

2022-02-24 01:24:20 +09:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "4c31f5ad",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'C:\\\\Users\\\\Monoid\\\\anaconda3\\\\envs\\\\nn\\\\python.exe'"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import sys\n",
"sys.executable"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "2b9e11e7",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"load bert tokenizer...\n"
]
}
],
"source": [
"from transformers import BertTokenizer\n",
"print(\"load bert tokenizer...\")\n",
"PRETAINED_MODEL_NAME = 'bert-base-multilingual-cased'\n",
"tokenizer = BertTokenizer.from_pretrained(PRETAINED_MODEL_NAME)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "82bf44a2",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"cuda available : True\n",
"available device count : 1\n",
"device name: NVIDIA GeForce RTX 3070\n"
]
}
],
"source": [
"import torch\n",
"print(\"cuda available :\",torch.cuda.is_available())\n",
"print(\"available device count :\",torch.cuda.device_count())\n",
"if torch.cuda.is_available():\n",
" device_index = torch.cuda.current_device()\n",
" print(\"device name:\",torch.cuda.get_device_name(device_index))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "38dcf62d",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"read train set\n",
"100%|██████████████████████████████████████████████████████████████████████| 150000/150000 [00:00<00:00, 208333.74it/s]\n",
"read test set\n",
"100%|████████████████████████████████████████████████████████████████████████| 50000/50000 [00:00<00:00, 260420.34it/s]\n"
]
}
],
"source": [
"from ndataset import readNsmcDataAll, make_collate_fn\n",
"dataTrain, dataTest = readNsmcDataAll()\n",
"collate_fn = make_collate_fn(tokenizer)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "650c8a19",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Some weights of the model checkpoint at bert-base-multilingual-cased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight']\n",
"- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
"- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
]
}
],
"source": [
"from transformers import BertModel\n",
"PRETAINED_MODEL_NAME = 'bert-base-multilingual-cased'\n",
"bert = BertModel.from_pretrained(PRETAINED_MODEL_NAME)"
]
},
{
"cell_type": "markdown",
"id": "30d69b45",
"metadata": {},
"source": [
"BERT 로딩"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "7583b0d1",
"metadata": {},
"outputs": [],
"source": [
"import torch.nn as nn\n",
"\n",
"class MyModel(nn.Module):\n",
" def __init__(self,bert):\n",
" super().__init__()\n",
" self.bert = bert\n",
" self.dropout = nn.Dropout(p=0.1)\n",
" self.lin1 = nn.Linear(768,256) #[batch_size,768] -> [batch_size,256]\n",
" self.lin2 = nn.Linear(256,1) #[batch_size,256] -> [batch_size,1]\n",
"\n",
" def forward(self,**kargs):\n",
" emb = self.bert(**kargs)\n",
" e1 = self.dropout(emb['pooler_output'])\n",
" e2 = self.lin1(e1)\n",
" w = self.lin2(e2)\n",
" return w.squeeze() #[batch_size]"
]
},
{
"cell_type": "markdown",
"id": "befe62b0",
"metadata": {},
"source": [
"모델 선언. 비슷하게 감."
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "36585e76",
"metadata": {},
"outputs": [],
"source": [
"model = MyModel(bert)"
]
},
{
"cell_type": "markdown",
"id": "7969fead",
"metadata": {},
"source": [
"학습 과정에서 벌어지는 일"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "8c2a4bc9",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"MyModel(\n",
" (bert): BertModel(\n",
" (embeddings): BertEmbeddings(\n",
" (word_embeddings): Embedding(119547, 768, padding_idx=0)\n",
" (position_embeddings): Embedding(512, 768)\n",
" (token_type_embeddings): Embedding(2, 768)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (encoder): BertEncoder(\n",
" (layer): ModuleList(\n",
" (0): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (1): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (2): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (3): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (4): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (5): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (6): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (7): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (8): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (9): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (10): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (11): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" )\n",
" )\n",
" (pooler): BertPooler(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (activation): Tanh()\n",
" )\n",
" )\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (lin1): Linear(in_features=768, out_features=256, bias=True)\n",
" (lin2): Linear(in_features=256, out_features=1, bias=True)\n",
")"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.cpu()"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "e027b926",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([4, 768])"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"hidden = bert(**tokenizer([\"사랑해요.\",\"무서워요.\",\"슬퍼요.\",\"재미있어요.\"], return_tensors = 'pt', padding='longest'))['pooler_output']\n",
"hidden.size()"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "ae9f8fba",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([0.1623, 0.1365, 0.1949, 0.1491], grad_fn=<SqueezeBackward0>)"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"w = model.lin2(model.lin1(hidden)).squeeze()\n",
"w"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "5470c3f8",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([0.5405, 0.5341, 0.5486, 0.5372], grad_fn=<SigmoidBackward0>)"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.sigmoid(w)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "b7eb8e67",
"metadata": {},
"outputs": [],
"source": [
"labels = torch.tensor([1,0,0,1])"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "7a324ed7",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(0.6989, dtype=torch.float64,\n",
" grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"nn.BCEWithLogitsLoss()(w,labels.double())"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "cb54294d",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([ True, False, False, True])"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"(w > 0).long() == labels"
]
},
{
"cell_type": "markdown",
"id": "596b89bd",
"metadata": {},
"source": [
"이런 일이 벌어짐. sigmoid 는 나중에"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "769c4290",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"MyModel(\n",
" (bert): BertModel(\n",
" (embeddings): BertEmbeddings(\n",
" (word_embeddings): Embedding(119547, 768, padding_idx=0)\n",
" (position_embeddings): Embedding(512, 768)\n",
" (token_type_embeddings): Embedding(2, 768)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (encoder): BertEncoder(\n",
" (layer): ModuleList(\n",
" (0): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (1): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (2): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (3): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (4): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (5): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (6): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (7): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (8): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (9): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (10): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (11): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" )\n",
" )\n",
" (pooler): BertPooler(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (activation): Tanh()\n",
" )\n",
" )\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (lin1): Linear(in_features=768, out_features=256, bias=True)\n",
" (lin2): Linear(in_features=256, out_features=1, bias=True)\n",
")\n"
]
}
],
"source": [
"device = torch.device('cuda')\n",
"model.to(device)\n",
"print(model)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "b9380dcd",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"device(type='cuda', index=0)"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"bert.device"
]
},
{
"cell_type": "markdown",
"id": "5e82df0e",
"metadata": {},
"source": [
"모델을 모두 gpu로 보냄"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "74c4becc",
"metadata": {},
"outputs": [],
"source": [
"from torch.utils.data import Dataset, DataLoader\n",
"BATCH_SIZE = 16\n",
"train_loader = DataLoader(\n",
" dataTrain,\n",
" batch_size=BATCH_SIZE,\n",
" shuffle=True,\n",
" collate_fn=collate_fn\n",
")\n",
"test_loader = DataLoader(\n",
" dataTest,\n",
" batch_size=BATCH_SIZE,\n",
" shuffle=True,\n",
" collate_fn=collate_fn\n",
")"
]
},
{
"cell_type": "markdown",
"id": "4153b2e7",
"metadata": {},
"source": [
"데이터 모델 준비"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "3cd5bf7b",
"metadata": {},
"outputs": [],
"source": [
"from torch.optim import AdamW\n",
"from groupby_index import groupby_index\n",
"from tqdm import tqdm"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "65b5ccde",
"metadata": {},
"outputs": [],
"source": [
"optimizer = AdamW(model.parameters(), lr=1.0e-5)\n",
"BCELoss = nn.BCEWithLogitsLoss()"
]
},
{
"cell_type": "markdown",
"id": "79607e81",
"metadata": {},
"source": [
"학습 준비"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "4835a0d3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 0 start:\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 0: 100%|███████████████████████████████████| 9375/9375 [12:35<00:00, 12.41minibatch/s, accuracy=0.875, loss=2.58]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 1 start:\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 1: 100%|███████████████████████████████████| 9375/9375 [12:35<00:00, 12.41minibatch/s, accuracy=0.898, loss=2.18]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 2 start:\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 2: 1%|▎ | 82/9375 [00:06<12:30, 12.39minibatch/s, accuracy=0.867, loss=2.08]\n"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"\u001b[1;32m~\\AppData\\Local\\Temp/ipykernel_10708/1191029387.py\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m 21\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 22\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mmini_i\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mmini_l\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mbatch\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 23\u001b[1;33m \u001b[0mbatch_inputs\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m{\u001b[0m\u001b[0mk\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mv\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mto\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mk\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mv\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mlist\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmini_i\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mitems\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m}\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 24\u001b[0m \u001b[0mbatch_labels\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mmini_l\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mto\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 25\u001b[0m \u001b[0mattention_mask\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mbatch_inputs\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;34m\"attention_mask\"\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32m~\\AppData\\Local\\Temp/ipykernel_10708/1191029387.py\u001b[0m in \u001b[0;36m<dictcomp>\u001b[1;34m(.0)\u001b[0m\n\u001b[0;32m 21\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 22\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mmini_i\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mmini_l\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mbatch\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 23\u001b[1;33m \u001b[0mbatch_inputs\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m{\u001b[0m\u001b[0mk\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mv\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mto\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mk\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mv\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mlist\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmini_i\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mitems\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m}\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 24\u001b[0m \u001b[0mbatch_labels\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mmini_l\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mto\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 25\u001b[0m \u001b[0mattention_mask\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mbatch_inputs\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;34m\"attention_mask\"\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;31mKeyboardInterrupt\u001b[0m: "
]
}
],
"source": [
"TRAIN_EPOCH = 5\n",
"# Number of mini-batches whose gradients are accumulated before one optimizer step.\n",
"# NOTE(review): assumes groupby_index(iterable, n) yields groups of up to n consecutive items - confirm.\n",
"GRAD_ACCUM_STEPS = 8\n",
"\n",
"# One record per optimizer step: {\"iter\", \"loss\", \"accuracy\"}.\n",
"result = []\n",
"iteration = 0\n",
"\n",
"model.zero_grad()\n",
"\n",
"for epoch in range(TRAIN_EPOCH):\n",
"    model.train()\n",
"    print(f\"epoch {epoch} start:\")\n",
"    with tqdm(train_loader, unit=\"minibatch\") as tepoch:\n",
"        tepoch.set_description(f\"Epoch {epoch}\")\n",
"\n",
"        for batch in groupby_index(tepoch, GRAD_ACCUM_STEPS):\n",
"            corrects = 0\n",
"            totals = 0\n",
"            losses = 0.0\n",
"            minibatches = 0\n",
"\n",
"            optimizer.zero_grad()\n",
"            for mini_i, mini_l in batch:\n",
"                batch_inputs = {k: v.to(device) for k, v in mini_i.items()}\n",
"                batch_labels = mini_l.to(device)\n",
"\n",
"                output = model(**batch_inputs)\n",
"                loss = BCELoss(output, batch_labels.double())\n",
"\n",
"                # A positive logit is equivalent to sigmoid(logit) > 0.5.\n",
"                prediction = (output > 0).long()\n",
"                corrects += (prediction == batch_labels).sum().item()\n",
"                totals += prediction.size(0)\n",
"                losses += loss.item()\n",
"                minibatches += 1\n",
"                # Scale before backward so the accumulated gradient is an average over the\n",
"                # group instead of a sum that grows with the accumulation count.\n",
"                (loss / GRAD_ACCUM_STEPS).backward()\n",
"\n",
"            optimizer.step()\n",
"            accuracy = corrects / totals\n",
"            # Report the mean mini-batch loss so it is on the same scale as the per-batch\n",
"            # loss measured during evaluation (previously this was the sum over the group).\n",
"            mean_loss = losses / minibatches\n",
"            result.append({\"iter\": iteration, \"loss\": mean_loss, \"accuracy\": accuracy})\n",
"            tepoch.set_postfix(loss=mean_loss, accuracy=accuracy)\n",
"            iteration += 1"
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "81b69931",
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "c3a73c68",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYEAAAD4CAYAAAAKA1qZAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8/fFQqAAAACXBIWXMAAAsTAAALEwEAmpwYAABU1UlEQVR4nO2dd7zURNfHf+c2uPTekd6ld1REBQQsKCrFgqKiCIj4CthARBQRHwQLPGAFFeURFRUFRRGxoRQFadJVepcOt837x2x2k2zKJJvdze6d7/3s56ZMZibZzTkzZ86cIcYYJBKJRJI/SYl3BSQSiUQSP6QSkEgkknyMVAISiUSSj5FKQCKRSPIxUglIJBJJPiYtGpmmpKSwzMzMaGQtkUgkScmZM2cYYyzmDfOoKIHMzEycPn06GllLJBJJUkJEZ+NRrjQHSSQSST5GKgGJRCLJx0glIJFIJPkYqQQkEokkHyOVgEQikeRjpBKQSCSSfIxUAhKJRJKP8ZUSWLt/LX7e9XO8qyGRSPzMp58C+/fHuxZJA0VjPYHChQszN5PFaBwBANhYucaBRCIxIDsbyMgA6tUD/vwz3rXxFCI6wxgrHOtyfdMTyGN58a6CRCLxO0qjdfv2+NYjifCNEiBQcHvK8ilxrIkk6ZgxA/jnH+D774Evv4x3bYw5dAiYMiUk5BKJzz8Hfo6xGTcvxo1G5TdkBGPA888D//4b0yp5hS/NQYA0CUk84uhRoHRpbj7YvJkf86Og7dEDWLQIWLkSaNUq3rVxBgXe21g813PngMxMXmasFMGxY0CpUkCdOsCWLeHnlywBOncG+vYF3n/fdTH53hwkkUSFnBz+/8iR+NbDDqUVef58XKvhexRFE0tFriibw4eNzyvfWYL2BHylBCZ3nRzcPpV1Ko41kSQdRPZpRDl7FnjgAeD4ce/yTFQOHbJPs2oVMHmyfToRvBL+jAFPPQVs2mSe5plngPXr7Xs6c+fy/7E2UXmEr5RA/6b9g9szVs2IY00kEgtefRV46SVg4kTv8oylScVLhg61T9O6NTBihDflefV8jh8Hxo4FLrvM+Pz588Do0UCHDvYNiHfe8aZOccJXSqBMoTLB7ePnZCtL4lNOnOD/0zxcjsPLnkosycqKbXleKQGl1W5nflOftys70RR4AF8pATWr962OdxXyB88+G9+WzL//Aj17AgcPus/jnXeACRMir8uZM0CvXsZeILt28VZhly7Avn382LRpkZWXnQ306QNs3Bg69sEHoe3584Fbbw0Jl5Mn+bNSytfz5pvAf/4TWZ30/PQTMHCgNwIuN5f/Z4zn+dNPxun+/JM/l6ws3oNYuDB0TqQex4/z57R/P7B1K9CtG9C9e+i5/fYbMGAA3zZTvko5WVnAt99qjzEG3HUX8Msv2mu+/tq+bn6EMeb5p1ChQswtE76fwPAkGJ6E6zwkDuA/6fiV/9xzvPwRI9znYXUP+/fzc2XL2t/re+/x8336hJ+79dbQ9amp3jy35ct5Hm3aMHbJJeF5pqTw/XPn+P706Xx/0CDj/KLxXSr3mpVlfP766+3LVc6fPcv3z57l+wUKGKdv356f//HH8LyPH7cvb+pUfn7YMMYuvzyUXnlulSuHjpUsaZzHmTOhNMqnWDF+7t9/+X7Rotr7AxjLzTWvlw0ATrMoyGO7j+96Ag91eCjeVZDEg2ibQ0TyTwm8DrHq1ivlEBnXTzFZ+MHM4MX3o5iOlLyUnoGTMr0YfBVZ/9zomau/L7M02dnu6xUnfKcEMlIz4l0F5/z4I/dDP3Mm3jWJHQsWAC1a2L/ICosXA02bhr8k0RZwRvmfOsW/L313Xnm57QSNOs9LL+X/c3P58/j88/D02dn83s0mqumF3fjxwO23m9+Dsp+dDTRpAnz1lfb86dP8/ozMLX/9BVSrxs1beiZMAG67TXtM+X6XLTOuu5o6dbgnkBmKElDqb/Tbyc0Fli/n22PGaM99/TVQsm
Ron8iZCWbGDG52LFTIOt0ddwDjxoUfP3kyVK4ZiktyAuE7JZCQPPQQn0Tyxx/xrkns6N8f+P330CCpHQMH8uezd6/x+Vj2BFav5t/XqFHaNKI9AfX577/n///9lz+P/v3D0+/bx+994EDjfPQ9gSeeAN5+2zidmj17gHXrgHvu0R5fu5bf38iR4XV57TU+5jF7dvi5xx8H3n03/DgAPPig8XE127Zxjxs9+pa/la+/+vek2OIV9PcJAIMHW9dJ/8w++UQ7oG/0u5s9G5g0yTpfQPYEJCoS1b0vEry613g8s/R0/l//wrrpCSgogsWqJagXOEo5dgrQTGimpmrzUVCUmdF9RPu3qjxbozKjYd5y03iItHyrZyh7At5wUdWLAABZuTF2P3NLorr3idCxY7iJAAi9ADVq8BakmsaN+WQqJ0yaxH3OmzZ1V087DhwIbSueSNnZQNWqvOV9+jRw4438uPrlHjIEaNZMvJycHODvv/lvQm9uMuPnn8XMLQozZ3LTo/K7270bKKyKNqCYU379NdTLIOL3ZSbAjL5jI5S8cnO5V9P8+drz+la2OrzDzJlcQakFpahAbtqUm7KM6mPGSy/xkA5q7r2X99j0VK7MTU127zKRtSdblHsCRNSNiDYT0TYiesTgfEkimk9EfxDRCiK60C5PXyqBn3ZxW+a1718b55o4JBl7Aj/8YG4iALg7nt49c/16/gKKoH5m06Z5b1Iz+k4UIZmdzQXo+PFat1B1C3r6dG5eES0nOztkp37tNe05vYAR/b0YXb9nT6jFD2jHo5Ry9eTlmSsBq+9YXw+Ax/CZMyc8jVFPQGHcOJ7HKVU0ANGBXrPfRaQNMOX6vXvFwz4o4x4xNgcRUSqAaQC6A2gIoB8RNdQlewzAGsZYEwD9Abxol68vlcCgloMAAF9t/8ompU+Q5iD/YlRPM3OQ1TVWwkZxEAR4K9fMiyRSJaCvj1mdUkxea6uegNN6EBk7QigmKivUz13UscCMePbCY28OagNgG2NsB2MsC8BcAD11aRoCWMKrx/4EUJ2Iyltl6ksl0LJSy9gXumcP/0EpcUCc4EclUK0a0KaN9hgRULCg+QsMADt28POffuptfZSWtvoZVa8OPPZYeNoXXuB1EIlEq45JozZ9KHzzTfg1zzzD/6uF0S23hLYXLOC/A9HZooxpW7TKAKaRgJowQWueEaFCBf7/vvvCyzViwwbj4y+8wOPlAHxCllIPfQuYKHwAf926cAF/9mx4GUa9Az2dO4e2P/iA/w5at+blliplf72+rgCfvEXEezTDhzu/3gl9+vD/586FX1+rlvP8QqQR0SrVRz8SXhmA2q1rd+CYmrUAegEAEbUBUA1AFatCfakEzufEIZKi8uK8+abza/2oBP75h4clVlDqZhfxULnG7mX24l7//tv4+JQp/P/Ro/Z5vPCC9Xn1DFwr9HbiWbNCLoF25OWJP4/nngtti15jpLBTUpz7zD/9dGhb8WoCuElMj1VgNat6ibBtW2j7gw/478DKtdQK5d1T3lszU1hikMMYa6X6vKo7b6Sx9D+iiQBKEtEaAPcD+B2AZffEl0pAPSB8zwIDt7BoEIkgV1+bne0uHPCZM9qu8dmzIUF97pz3tkazFpByXC8As7L4fZ09qzV5OEXkunPnwo8p5errZPVccnKMW6sidUlNFf8e8/KM0545o31WjGldII3u0wx92pMnnZse1EJbHXrCTNlZRUndssXcPfjIEfHfa6STv/S/Y6fv3rFjkZUfW3YDqKrarwJA02VjjJ1gjA1gjDUDHxMoC2CnVaa+VwKv/RZjzR6pEmjUiJtcnFK4sNYkUagQ91Y4fpzPcIx0oRH9fdkpgS+/1LYWy5UDihbl9eqpN0M6QOSlV+K2q+tcqBBw1VXadGXLWocy7tYt3NdclJQUoIqqF21lNsjL06ZVeO89oGvX0H3oYxJ17y5eH/0s12bNgEGDxK8HzO3vHTqEHyMCatc2z6t5c3OTYZ
kyfB1gEbxw11RPivv1V2fX5+YaT5zzJysB1CGiGkSUAaAvgM/UCYioROAcANwN4HvGmOVkHl8qgS61usS+UK96Alu
"text/plain": [
"<Figure size 432x288 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"iters = [item[\"iter\"] for item in result]\n",
"fig, ax1 = plt.subplots()\n",
"# Loss on the left y-axis (green).\n",
"ax1.plot(iters, [item[\"loss\"] for item in result], 'g')\n",
"# Label the primary axes: twinx() hides the twin's x-axis, so an x label set after\n",
"# twinx() through plt.xlabel() lands on the twin and is never drawn.\n",
"ax1.set_xlabel(\"iter\")\n",
"ax1.set_ylabel(\"loss\", color='g')\n",
"# Accuracy on the right y-axis (red), sharing the same x-axis.\n",
"ax2 = ax1.twinx()\n",
"ax2.plot(iters, [item[\"accuracy\"] for item in result], 'r')\n",
"ax2.set_ylabel(\"accuracy\", color='r')\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 91,
"id": "cab7889f",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"gpu allocated : 1776 MB\n",
"gpu reserved : 1910 MB\n"
]
}
],
"source": [
"# Release cached blocks first, then report both counters in whole mebibytes.\n",
"torch.cuda.empty_cache()\n",
"allocated_mb = torch.cuda.memory_allocated() // 1024**2\n",
"reserved_mb = torch.cuda.memory_reserved() // 1024 ** 2\n",
"print(f\"gpu allocated : {allocated_mb} MB\")\n",
"print(f\"gpu reserved : {reserved_mb} MB\")"
]
},
{
"cell_type": "code",
"execution_count": 38,
"id": "29ffab84",
"metadata": {},
"outputs": [],
"source": [
"# Persist only the weights (state_dict), not the model object itself.\n",
"torch.save(obj=model.state_dict(), f=\"model.zip\")"
]
},
{
"cell_type": "code",
"execution_count": 39,
"id": "4b9b9579",
"metadata": {},
"outputs": [],
"source": [
"# Drop the names left over from training so the objects they reference can be freed.\n",
"del batch_inputs, batch_labels, loss, optimizer"
]
},
{
"cell_type": "code",
"execution_count": 42,
"id": "fff7a7d0",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|███████████████████████████████████████████████████████████████████████████| 3125/3125 [01:26<00:00, 36.25batch/s]\n"
]
}
],
"source": [
"model.eval()\n",
"# One record per test batch; aggregated in a later cell.\n",
"collect_list = []\n",
"# Evaluation only needs forward passes, so gradient tracking is switched off.\n",
"with torch.no_grad():\n",
"    with tqdm(test_loader, unit=\"batch\") as tepoch:\n",
"        for batch_i, batch_l in tepoch:\n",
"            # .to(device) matches the training loop and also works for non-CUDA devices.\n",
"            batch_inputs = {k: v.to(device) for k, v in batch_i.items()}\n",
"            batch_labels = batch_l.to(device)\n",
"            output = model(**batch_inputs)\n",
"            loss = BCELoss(output, batch_labels.double())\n",
"\n",
"            # A positive logit is equivalent to sigmoid(logit) > 0.5.\n",
"            prediction = (output > 0).long()\n",
"            correct = (prediction == batch_labels).sum().item()\n",
"            accuracy = correct / prediction.size(0)\n",
"\n",
"            collect_list.append({\n",
"                \"loss\": loss.item(),\n",
"                \"accuracy\": accuracy,\n",
"                \"batch_size\": batch_labels.size(0),\n",
"                # Keep CPU copies so the collected results do not hold device memory.\n",
"                \"predict\": prediction.cpu(),\n",
"                \"actual\": batch_labels.cpu(),\n",
"            })"
]
},
{
"cell_type": "code",
"execution_count": 43,
"id": "4e9a90b5",
"metadata": {},
"outputs": [],
"source": [
"def getConfusionMatrix(predict,actual,attention_mask=None):\n",
"    \"\"\"Count (predicted, actual) label pairs for a binary classifier.\n",
"\n",
"    Args:\n",
"        predict: integer tensor of predicted labels, each 0 or 1.\n",
"        actual: integer tensor of true labels, each 0 or 1, with as many\n",
"            elements as `predict`.\n",
"        attention_mask: unused. Kept, and now optional, only so existing\n",
"            three-argument calls keep working.\n",
"\n",
"    Returns:\n",
"        A 2x2 torch.long CPU tensor `m` where `m[p, a]` is the number of\n",
"        samples predicted as `p` whose true label is `a`.\n",
"    \"\"\"\n",
"    # Encode every (p, a) pair as the flat index 2*p + a and count all pairs in one pass\n",
"    # instead of updating the matrix one sample at a time.\n",
"    pair_index = predict.reshape(-1).long() * 2 + actual.reshape(-1).long()\n",
"    return torch.bincount(pair_index, minlength=4).reshape(2, 2).cpu()"
]
},
{
"cell_type": "code",
"execution_count": 44,
"id": "b7a513c9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"average_loss : 0.3252932393981423, average_accuracy : 0.86136, size :50000\n"
]
}
],
"source": [
"total_loss = 0\n",
"total_accuracy = 0\n",
"total_size = 0\n",
"confusion = torch.zeros((2,2),dtype=torch.long)\n",
"\n",
"for item in collect_list:\n",
"    batch_size = item[\"batch_size\"]\n",
"    # Weight the per-batch means by batch size so a smaller final batch is averaged correctly.\n",
"    total_loss += batch_size * item[\"loss\"]\n",
"    total_accuracy += batch_size * item[\"accuracy\"]\n",
"    total_size += batch_size\n",
"    # The evaluation loop stores no \"attention_mask\" entry, so item[\"attention_mask\"] raised\n",
"    # KeyError on a fresh run. getConfusionMatrix ignores that argument, so None is fine.\n",
"    confusion += getConfusionMatrix(item[\"predict\"],item[\"actual\"],item.get(\"attention_mask\"))\n",
"print(f\"\"\"average_loss : {total_loss/total_size}, average_accuracy : {total_accuracy/total_size}, size :{total_size}\"\"\")"
]
},
{
"cell_type": "code",
"execution_count": 45,
"id": "1ac327de",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[21382, 3487],\n",
" [ 3445, 21686]])"
]
},
"execution_count": 45,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Layout: confusion[predicted, actual], i.e. rows are predictions and columns are true labels.\n",
"confusion"
]
},
{
"cell_type": "code",
"execution_count": 46,
"id": "3e71d4d2",
"metadata": {},
"outputs": [],
"source": [
"def getF1Score(confusion,c):\n",
"    \"\"\"Return the F1 score of class `c` from a confusion matrix.\n",
"\n",
"    `confusion` is indexed as confusion[predicted, actual]: row `c` holds\n",
"    everything predicted as `c`, column `c` everything whose true label is `c`.\n",
"    There is no guard against zero denominators.\n",
"    \"\"\"\n",
"    true_pos = confusion[c,c]\n",
"    # Predicted as c but actually another class.\n",
"    false_pos = confusion[c].sum() - true_pos\n",
"    # Actually c but predicted as another class.\n",
"    false_neg = confusion[:,c].sum() - true_pos\n",
"\n",
"    precision = true_pos / (true_pos + false_pos)\n",
"    recall = true_pos / (true_pos + false_neg)\n",
"    # Harmonic mean of precision and recall.\n",
"    return (2*precision*recall)/(precision + recall)"
]
},
{
"cell_type": "code",
"execution_count": 49,
"id": "6756408c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"f1 score : 0.862197756767273\n"
]
}
],
"source": [
"# F1 for class 1.\n",
"f1_class1 = getF1Score(confusion, 1)\n",
"print(f\"f1 score : {f1_class1}\")"
]
},
{
"cell_type": "code",
"execution_count": 50,
"id": "f28f64e9",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"MyModel(\n",
" (bert): BertModel(\n",
" (embeddings): BertEmbeddings(\n",
" (word_embeddings): Embedding(119547, 768, padding_idx=0)\n",
" (position_embeddings): Embedding(512, 768)\n",
" (token_type_embeddings): Embedding(2, 768)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (encoder): BertEncoder(\n",
" (layer): ModuleList(\n",
" (0): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (1): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (2): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (3): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (4): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (5): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (6): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (7): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (8): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (9): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (10): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (11): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" )\n",
" )\n",
" (pooler): BertPooler(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (activation): Tanh()\n",
" )\n",
" )\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (lin1): Linear(in_features=768, out_features=256, bias=True)\n",
" (lin2): Linear(in_features=256, out_features=1, bias=True)\n",
")"
]
},
"execution_count": 50,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Switch to inference mode; the trailing ';' keeps the very long module repr out of the cell output.\n",
"model.eval();"
]
},
{
"cell_type": "markdown",
"id": "2da5789b",
"metadata": {},
"source": [
"한번 테스트해보기"
]
},
{
"cell_type": "code",
"execution_count": 90,
"id": "cc727fd9",
"metadata": {},
"outputs": [
{
"ename": "KeyboardInterrupt",
"evalue": "Interrupted by user",
"output_type": "error",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"\u001b[1;32m~\\AppData\\Local\\Temp/ipykernel_10708/4160447663.py\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[1;32m----> 1\u001b[1;33m \u001b[0msen\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 2\u001b[0m \u001b[0minputs\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtokenizer\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0msen\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mreturn_tensors\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;34m'pt'\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mpadding\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;34m'longest'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 3\u001b[0m \u001b[0moutput\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m**\u001b[0m\u001b[1;33m{\u001b[0m\u001b[0mk\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mv\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mto\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mk\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mv\u001b[0m \u001b[1;32min\u001b[0m \u001b[0minputs\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mitems\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m}\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 4\u001b[0m \u001b[0mprob\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msigmoid\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0moutput\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mitem\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 5\u001b[0m \u001b[0mprint\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"긍정적 output :\"\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mprob\u001b[0m \u001b[1;33m*\u001b[0m 
\u001b[1;36m100\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;34m\"%\"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32m~\\anaconda3\\envs\\nn\\lib\\site-packages\\ipykernel\\kernelbase.py\u001b[0m in \u001b[0;36mraw_input\u001b[1;34m(self, prompt)\u001b[0m\n\u001b[0;32m 1008\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_parent_ident\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;34m\"shell\"\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1009\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mget_parent\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"shell\"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1010\u001b[1;33m \u001b[0mpassword\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mFalse\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1011\u001b[0m )\n\u001b[0;32m 1012\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32m~\\anaconda3\\envs\\nn\\lib\\site-packages\\ipykernel\\kernelbase.py\u001b[0m in \u001b[0;36m_input_request\u001b[1;34m(self, prompt, ident, parent, password)\u001b[0m\n\u001b[0;32m 1049\u001b[0m \u001b[1;32mexcept\u001b[0m \u001b[0mKeyboardInterrupt\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1050\u001b[0m \u001b[1;31m# re-raise KeyboardInterrupt, to truncate traceback\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1051\u001b[1;33m \u001b[1;32mraise\u001b[0m \u001b[0mKeyboardInterrupt\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"Interrupted by user\"\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mfrom\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1052\u001b[0m \u001b[1;32mexcept\u001b[0m \u001b[0mException\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1053\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlog\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mwarning\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"Invalid Message:\"\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mexc_info\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mTrue\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;31mKeyboardInterrupt\u001b[0m: Interrupted by user"
]
}
],
"source": [
"# Read one sentence from the user and score it with the trained model.\n",
"sen = input()\n",
"inputs = tokenizer(sen, return_tensors='pt', padding='longest')\n",
"# Inference only: skip autograd bookkeeping for the forward pass.\n",
"with torch.no_grad():\n",
"    output = model(**{k: v.to(device) for k, v in inputs.items()})\n",
"# The model emits a single logit; the sigmoid turns it into the probability of the positive class.\n",
"prob = torch.sigmoid(output).item()\n",
"print(\"긍정적 output :\",prob * 100,\"%\")\n",
"print(\"부정적 output :\", (1-prob) * 100,\"%\")"
]
},
{
"cell_type": "markdown",
"id": "2faa8141",
"metadata": {},
"source": [
"```\n",
"5471412\t맘에 들어요~ 0\n",
"```\n",
"라벨이 잘못 붙어 있는 것들이 있다. 별점만 가지고 긍정, 부정을 매긴 것 같다."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b40f071c",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.11"
}
},
"nbformat": 4,
"nbformat_minor": 5
}