{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "811f97e3", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'C:\\\\Users\\\\Monoid\\\\anaconda3\\\\envs\\\\nn\\\\python.exe'" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import sys\n", "sys.executable" ] }, { "cell_type": "markdown", "id": "9993f053", "metadata": {}, "source": [ "파이썬 환경 확인.\n", "envs\\\\nn\\\\python.exe 으로 끝나기를 기대합니다." ] }, { "cell_type": "markdown", "id": "6748199b", "metadata": {}, "source": [ "### DataLoader에 대해서\n", "Batch로 묶을 DataLoader를 다뤄봐요." ] }, { "cell_type": "code", "execution_count": 2, "id": "1fbc6696", "metadata": {}, "outputs": [], "source": [ "import torch\n", "from torch.utils.data import Dataset, DataLoader\n", "from read_data import TagIdConverter\n", "from preprocessing import readPreporcssedDataAll" ] }, { "cell_type": "code", "execution_count": 3, "id": "a24ef0d4", "metadata": {}, "outputs": [], "source": [ "tagIdConverter = TagIdConverter()" ] }, { "cell_type": "code", "execution_count": 4, "id": "5f2e0d22", "metadata": {}, "outputs": [], "source": [ "train, dev, test = readPreporcssedDataAll()" ] }, { "cell_type": "code", "execution_count": 6, "id": "08ab5f11", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[{'tokens': ['특히',\n", " '김',\n", " '##병',\n", " '##현',\n", " '은',\n", " '4',\n", " '회',\n", " '말',\n", " '에',\n", " '무',\n", " '기',\n", " '##력',\n", " '하',\n", " '게',\n", " '6',\n", " '실',\n", " '##점',\n", " '하',\n", " '면',\n", " '##서'],\n", " 'ids': [39671,\n", " 8935,\n", " 73380,\n", " 30842,\n", " 9632,\n", " 125,\n", " 9998,\n", " 9251,\n", " 9559,\n", " 9294,\n", " 8932,\n", " 28143,\n", " 9952,\n", " 8872,\n", " 127,\n", " 9489,\n", " 34907,\n", " 9952,\n", " 9279,\n", " 12424],\n", " 'entity': ['O',\n", " 'B-PS',\n", " 'I-PS',\n", " 'I-PS',\n", " 'O',\n", " 'O',\n", " 'O',\n", " 'O',\n", " 'O',\n", " 'O',\n", " 'O',\n", " 'O',\n", " 'O',\n", " 'O',\n", " 'O',\n", " 'O',\n", " 'O',\n", " 'O',\n", " 'O',\n", " 'O'],\n", " 'entity_ids': [21,\n", " 7,\n", " 17,\n", " 17,\n", " 21,\n", " 21,\n", " 21,\n", " 21,\n", " 21,\n", " 21,\n", " 21,\n", " 21,\n", " 21,\n", " 21,\n", " 21,\n", " 21,\n", " 21,\n", " 21,\n", " 21,\n", " 21]},\n", " {'tokens': ['빅',\n", " '##비',\n", " '가',\n", " '2',\n", " '루',\n", " '도',\n", " '##루',\n", " '를',\n", " '시',\n", " '##도',\n", " '하',\n", " '다',\n", " '아',\n", " '##웃',\n", " '됐',\n", " '고'],\n", " 'ids': [9380,\n", " 29455,\n", " 8843,\n", " 123,\n", " 9213,\n", " 9087,\n", " 35866,\n", " 9233,\n", " 9485,\n", " 12092,\n", " 9952,\n", " 9056,\n", " 9519,\n", " 119170,\n", " 9097,\n", " 8888],\n", " 'entity': ['B-PS',\n", " 'I-PS',\n", " 'O',\n", " 'O',\n", " 'O',\n", " 'O',\n", " 'O',\n", " 'O',\n", " 'O',\n", " 'O',\n", " 'O',\n", " 'O',\n", " 'O',\n", " 'O',\n", " 'O',\n", " 'O'],\n", " 'entity_ids': [7,\n", " 17,\n", " 21,\n", " 21,\n", " 21,\n", " 21,\n", " 21,\n", " 21,\n", " 21,\n", " 21,\n", " 21,\n", " 21,\n", " 21,\n", " 21,\n", " 21,\n", " 21]},\n", " {'tokens': ['서',\n", " '##호',\n", " '##프',\n", " '와',\n", " '파',\n", " '##사',\n", " '##노',\n", " '에',\n", " '##게',\n", " '연속',\n", " '안',\n", " '##타',\n", " '를',\n", " '내',\n", " '##주',\n", " '며'],\n", " 'ids': [9425,\n", " 20309,\n", " 28396,\n", " 9590,\n", " 9901,\n", " 12945,\n", " 28981,\n", " 9559,\n", " 14153,\n", " 100208,\n", " 9521,\n", " 22695,\n", " 9233,\n", " 8996,\n", " 16323,\n", " 9278],\n", " 'entity': ['B-PS',\n", " 'I-PS',\n", " 'I-PS',\n", " 'O',\n", " 'B-PS',\n", " 'I-PS',\n", " 'I-PS',\n", " 'O',\n", " 'O',\n", " 'O',\n", " 'O',\n", " 'O',\n", " 'O',\n", " 'O',\n", " 'O',\n", " 'O'],\n", " 'entity_ids': [7,\n", " 17,\n", " 17,\n", " 21,\n", " 7,\n", " 17,\n", " 17,\n", " 21,\n", " 21,\n", " 21,\n", " 21,\n", " 21,\n", " 21,\n", " 21,\n", " 21,\n", " 21]},\n", " {'tokens': ['우',\n", " '##리',\n", " '금',\n", " '##융',\n", " '그',\n", " '##룹',\n", " '은',\n", " '19',\n", " '일',\n", " '지',\n", " '##난',\n", " '##주',\n", " '부',\n", " '##터',\n", " '네',\n", " '##덜',\n", " '##란드',\n", " '프로',\n", " '축구',\n", " '에',\n", " '##인',\n", " '##트',\n", " '##호',\n", " '##벤',\n", " '의',\n", " '미',\n", " '##드',\n", " '##필',\n", " '##더',\n", " '박',\n", " '##지',\n", " '##성',\n", " '과',\n", " '미국',\n", " '프로',\n", " '야구',\n", " '텍',\n", " '##사',\n", " '##스',\n", " '레',\n", " '##인',\n", " '##저',\n", " '##스',\n", " '의',\n", " '투',\n", " '##수',\n", " '박',\n", " '##찬',\n", " '##호',\n", " '를',\n", " '모',\n", " '##델',\n", " '로',\n", " '기',\n", " '##용',\n", " ',',\n", " '공',\n", " '##중',\n", " '##파',\n", " '텔레비전',\n", " '을',\n", " '통해',\n", " '광',\n", " '##고',\n", " '를',\n", " '재',\n", " '##개',\n", " '했',\n", " '다',\n", " '##고',\n", " '밝',\n", " '##혔',\n", " '다',\n", " '.'],\n", " 'ids': [9604,\n", " 12692,\n", " 8928,\n", " 119184,\n", " 8924,\n", " 87114,\n", " 9632,\n", " 10270,\n", " 9641,\n", " 9706,\n", " 33305,\n", " 16323,\n", " 9365,\n", " 21876,\n", " 9011,\n", " 118783,\n", " 61592,\n", " 102574,\n", " 37905,\n", " 9559,\n", " 12030,\n", " 15184,\n", " 20309,\n", " 118979,\n", " 9637,\n", " 9309,\n", " 15001,\n", " 119416,\n", " 54141,\n", " 9319,\n", " 12508,\n", " 17138,\n", " 8898,\n", " 23545,\n", " 102574,\n", " 106603,\n", " 9867,\n", " 12945,\n", " 12605,\n", " 9186,\n", " 12030,\n", " 48387,\n", " 12605,\n", " 9637,\n", " 9881,\n", " 15891,\n", " 9319,\n", " 119249,\n", " 20309,\n", " 9233,\n", " 9283,\n", " 118791,\n", " 9202,\n", " 8932,\n", " 24974,\n", " 117,\n", " 8896,\n", " 41693,\n", " 46150,\n", " 97232,\n", " 9633,\n", " 25605,\n", " 8903,\n", " 11664,\n", " 9233,\n", " 9659,\n", " 21789,\n", " 9965,\n", " 9056,\n", " 11664,\n", " 9324,\n", " 119436,\n", " 9056,\n", " 119],\n", " 'entity': ['B-OG',\n", " 'I-OG',\n", " 'I-OG',\n", " 'I-OG',\n", " 'I-OG',\n", " 'I-OG',\n", " 'O',\n", " 'B-DT',\n", " 'I-DT',\n", " 'B-DT',\n", " 'I-DT',\n", " 'I-DT',\n", " 'I-DT',\n", " 'I-DT',\n", " 'B-LC',\n", " 'I-LC',\n", " 'I-LC',\n", " 'O',\n", " 'O',\n", " 'B-OG',\n", " 'I-OG',\n", " 'I-OG',\n", " 'I-OG',\n", " 'I-OG',\n", " 'O',\n", " 'O',\n", " 'O',\n", " 'O',\n", " 'O',\n", " 'B-PS',\n", " 'I-PS',\n", " 'I-PS',\n", " 'O',\n", " 'B-LC',\n", " 'O',\n", " 'O',\n", " 'B-OG',\n", " 'I-OG',\n", " 'I-OG',\n", " 'I-OG',\n", " 'I-OG',\n", " 'I-OG',\n", " 'I-OG',\n", " 'O',\n", " 'O',\n", " 'O',\n", " 'B-PS',\n", " 'I-PS',\n", " 'I-PS',\n", " 'O',\n", " 'O',\n", " 'O',\n", " 'O',\n", " 'O',\n", " 'O',\n", " 'O',\n", " 'O',\n", " 'O',\n", " 'O',\n", " 'O',\n", " 'O',\n", " 'O',\n", " 'O',\n", " 'O',\n", " 'O',\n", " 'O',\n", " 'O',\n", " 'O',\n", " 'O',\n", " 'O',\n", " 'O',\n", " 'O',\n", " 'O',\n", " 'O'],\n", " 'entity_ids': [6,\n", " 16,\n", " 16,\n", " 16,\n", " 16,\n", " 16,\n", " 21,\n", " 4,\n", " 14,\n", " 4,\n", " 14,\n", " 14,\n", " 14,\n", " 14,\n", " 5,\n", " 15,\n", " 15,\n", " 21,\n", " 21,\n", " 6,\n", " 16,\n", " 16,\n", " 16,\n", " 16,\n", " 21,\n", " 21,\n", " 21,\n", " 21,\n", " 21,\n", " 7,\n", " 17,\n", " 17,\n", " 21,\n", " 5,\n", " 21,\n", " 21,\n", " 6,\n", " 16,\n", " 16,\n", " 16,\n", " 16,\n", " 16,\n", " 16,\n", " 21,\n", " 21,\n", " 21,\n", " 7,\n", " 17,\n", " 17,\n", " 21,\n", " 21,\n", " 21,\n", " 21,\n", " 21,\n", " 21,\n", " 21,\n", " 21,\n", " 21,\n", " 21,\n", " 21,\n", " 21,\n", " 21,\n", " 21,\n", " 21,\n", " 21,\n", " 21,\n", " 21,\n", " 21,\n", " 21,\n", " 21,\n", " 21,\n", " 21,\n", " 21,\n", " 21]},\n", " {'tokens': ['새',\n", " '##미',\n", " '소',\n", " '##사',\n", " '(',\n", " '36',\n", " '.',\n", " '볼',\n", " '##티',\n", " '##모',\n", " '##어',\n", " '오',\n", " '##리',\n", " '##올',\n", " '##스',\n", " ')',\n", " '가',\n", " '은',\n", " '##퇴',\n", " '한',\n", " \"'\",\n", " '홈',\n", " '##런',\n", " '##왕',\n", " \"'\",\n", " '마',\n", " '##크',\n", " '맥',\n", " '##과',\n", " '##이어',\n", " '(',\n", " '전',\n", " '세',\n", " '##이트',\n", " '##루',\n", " '##이스',\n", " '카',\n", " '##디',\n", " '##널',\n", " '##스',\n", " ')',\n", " '와',\n", " '개',\n", " '##인',\n", " '통',\n", " '##산',\n", " '홈',\n", " '##런',\n", " '타',\n", " '##이',\n", " '를',\n", " '이',\n", " '##뤘',\n", " '다',\n", " '.'],\n", " 'ids': [9415,\n", " 22458,\n", " 9448,\n", " 12945,\n", " 113,\n", " 11055,\n", " 119,\n", " 9359,\n", " 45725,\n", " 39420,\n", " 12965,\n", " 9580,\n", " 12692,\n", " 119153,\n", " 12605,\n", " 114,\n", " 8843,\n", " 9632,\n", " 119362,\n", " 9954,\n", " 112,\n", " 9989,\n", " 56710,\n", " 40991,\n", " 112,\n", " 9246,\n", " 20308,\n", " 9259,\n", " 11882,\n", " 86732,\n", " 113,\n", " 9665,\n", " 9435,\n", " 41620,\n", " 35866,\n", " 48653,\n", " 9786,\n", " 48446,\n", " 49881,\n", " 12605,\n", " 114,\n", " 9590,\n", " 8857,\n", " 12030,\n", " 9879,\n", " 21386,\n", " 9989,\n", " 56710,\n", " 9845,\n", " 10739,\n", " 9233,\n", " 9638,\n", " 118896,\n", " 9056,\n", " 119],\n", " 'entity': ['B-PS',\n", " 'I-PS',\n", " 'I-PS',\n", " 'I-PS',\n", " 'O',\n", " 'O',\n", " 'O',\n", " 'B-OG',\n", " 'I-OG',\n", " 'I-OG',\n", " 'I-OG',\n", " 'I-OG',\n", " 'I-OG',\n", " 'I-OG',\n", " 'I-OG',\n", " 'O',\n", " 'O',\n", " 'O',\n", " 'O',\n", " 'O',\n", " 'O',\n", " 'O',\n", " 'O',\n", " 'O',\n", " 'O',\n", " 'B-PS',\n", " 'I-PS',\n", " 'I-PS',\n", " 'I-PS',\n", " 'I-PS',\n", " 'O',\n", " 'O',\n", " 'B-OG',\n", " 'I-OG',\n", " 'I-OG',\n", " 'I-OG',\n", " 'I-OG',\n", " 'I-OG',\n", " 'I-OG',\n", " 'I-OG',\n", " 'O',\n", " 'O',\n", " 'O',\n", " 'O',\n", " 'O',\n", " 'O',\n", " 'O',\n", " 'O',\n", " 'O',\n", " 'O',\n", " 'O',\n", " 'O',\n", " 'O',\n", " 'O',\n", " 'O'],\n", " 'entity_ids': [7,\n", " 17,\n", " 17,\n", " 17,\n", " 21,\n", " 21,\n", " 21,\n", " 6,\n", " 16,\n", " 16,\n", " 16,\n", " 16,\n", " 16,\n", " 16,\n", " 16,\n", " 21,\n", " 21,\n", " 21,\n", " 21,\n", " 21,\n", " 21,\n", " 21,\n", " 21,\n", " 21,\n", " 21,\n", " 7,\n", " 17,\n", " 17,\n", " 17,\n", " 17,\n", " 21,\n", " 21,\n", " 6,\n", " 16,\n", " 16,\n", " 16,\n", " 16,\n", " 16,\n", " 16,\n", " 16,\n", " 21,\n", " 21,\n", " 21,\n", " 21,\n", " 21,\n", " 21,\n", " 21,\n", " 21,\n", " 21,\n", " 21,\n", " 21,\n", " 21,\n", " 21,\n", " 21,\n", " 21]}]" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train[0:5]" ] }, { "cell_type": "markdown", "id": "d824e042", "metadata": {}, "source": [ "미리 처리된 데이터를 가지고 왔어요. `tokens`와 `entity`, `ids`, `entity_ids` 로 이루어져 있어요.\n" ] }, { "cell_type": "code", "execution_count": 7, "id": "9b14e6b5", "metadata": {}, "outputs": [], "source": [ "class DatasetArray(Dataset):\n", " def __init__(self, data):\n", " self.x = data\n", " def __len__(self):\n", " return len(self.x)\n", " def __getitem__(self, idx):\n", " return self.x[idx]" ] }, { "cell_type": "markdown", "id": "8e659322", "metadata": {}, "source": [ "아무 Array를 Dataset으로 받는 것" ] }, { "cell_type": "code", "execution_count": 8, "id": "ce7f1185", "metadata": {}, "outputs": [], "source": [ "from typing import Any, List\n", "def padding_array(data: List[List[Any]], padding_value = 0, max_length = None):\n", " \"\"\"\n", " padding array of array\n", "\n", " >>> padding_array([[1,2],[3]])\n", " [[1,2],[3,0]]\n", " \"\"\"\n", " if max_length is None:\n", " max_length = max([len(lst) for lst in data])\n", " return [lst + [padding_value] * (max_length - len(lst)) for lst in data]" ] }, { "cell_type": "markdown", "id": "2244b614", "metadata": {}, "source": [ "padding 하는 함수" ] }, { "cell_type": "code", "execution_count": 9, "id": "6a6c9a65", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[[1, 2], [3, 0]]" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "padding_array([[1,2],[3]])" ] }, { "cell_type": "code", "execution_count": 10, "id": "86486b49", "metadata": {}, "outputs": [], "source": [ "def my_collate_fn(batch):\n", " words = [item[\"tokens\"] for item in batch]\n", " entities = [item[\"entity\"] for item in batch]\n", " words = padding_array(words,padding_value= \"[PAD]\")\n", " entities = padding_array(entities,padding_value=\"[PAD]\")\n", " return words, entities" ] }, { "cell_type": "markdown", "id": "69b40db0", "metadata": {}, "source": [ "batch를 합치는 함수" ] }, { "cell_type": "code", "execution_count": 11, "id": "0dcf6cfa", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "([['특히', '김', '##병', '##현', '은', '4', '회', '말', '에', '무', '기', '##력', '하', '게', '6', '실', '##점', '하', '면', '##서'], ['빅', '##비', '가', '2', '루', '도', '##루', '를', '시', '##도', '하', '다', '아', '##웃', '됐', '고', '[PAD]', '[PAD]', '[PAD]', '[PAD]']], [['O', 'B-PS', 'I-PS', 'I-PS', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O'], ['B-PS', 'I-PS', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', '[PAD]', '[PAD]', '[PAD]', '[PAD]']])\n" ] } ], "source": [ "print(my_collate_fn(train[0:2]))" ] }, { "cell_type": "markdown", "id": "c7b00df0", "metadata": {}, "source": [ "잘 작동함." ] }, { "cell_type": "code", "execution_count": 12, "id": "8ee0d8dd", "metadata": {}, "outputs": [], "source": [ "train_loader = DataLoader(\n", " DatasetArray(train),\n", " batch_size=10,\n", " shuffle=True,\n", " collate_fn=my_collate_fn\n", ")" ] }, { "cell_type": "markdown", "id": "929f8e72", "metadata": {}, "source": [ "데이터 로더는 다음과 같이 사용" ] }, { "cell_type": "code", "execution_count": 13, "id": "129aa7e7", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "('동', 'B-OG') ('##부', 'I-OG') ('는', 'O') ('5', 'B-DT') ('일', 'I-DT') ('홈', 'O') ('##구', 'O') ('##장', 'O') ('원', 'B-LC') ('##주', 'I-LC') ('치', 'O') ('##악', 'O') ('체', 'O') ('##육', 'O') ('##관', 'O') ('에서', 'O') ('열린', 'O') ('2007', 'B-DT') ('-', 'I-DT') ('08', 'I-DT') ('SK', 'O') ('텔', 'O') ('##레', 'O') ('##콤', 'O') ('T', 'O') ('프로', 'O') ('농', 'O') ('##구', 'O') ('안', 'B-OG') ('##양', 'I-OG') ('K', 'I-OG') ('##T', 'I-OG') ('&', 'I-OG') ('G', 'I-OG') ('카', 'I-OG') ('##이', 'I-OG') ('##츠', 'I-OG') ('와', 'O') ('의', 'O') ('4', 'O') ('강', 'O') ('플', 'O') ('##레', 'O') ('##이', 'O') ('##오', 'O') ('##프', 'O') ('1', 'O') ('차', 'O') ('전', 'O') ('에서', 'O') ('26', 'O') ('점', 'O') ('7', 'O') ('리', 'O') ('##바', 'O') ('##운드', 'O') ('4', 'O') ('블', 'O') ('##록', 'O') ('##슛', 'O') ('을', 'O') ('기', 'O') ('##록', 'O') ('한', 'O') ('김', 'B-PS') ('##주', 'I-PS') ('##성', 'I-PS') ('과', 'O') ('3', 'O') ('점', 'O') ('슛', 'O') ('3', 'O') ('개', 'O') ('를', 'O') ('터', 'O') ('##뜨', 'O') ('##린', 'O') ('양', 'B-PS') ('##경', 'I-PS') ('##민', 'I-PS') ('의', 'O') ('활', 'O') ('##약', 'O') ('으로', 'O') ('73', 'O') ('-', 'O') ('62', 'O') ('로', 'O') ('승', 'O') ('##리', 'O') ('했', 'O') ('다', 'O') ('.', 'O')\n", "10\n", "10\n" ] } ], "source": [ "for inputs,label in train_loader:\n", " print(*zip(inputs[0],label[0]))\n", " print(len(inputs))\n", " print(len(label))\n", " break" ] }, { "cell_type": "code", "execution_count": null, "id": "7fa76eb5", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.11" } }, "nbformat": 4, "nbformat_minor": 5 }