From fbc8a25e30a4d11621dcbabc7d902044b5ac4574 Mon Sep 17 00:00:00 2001
From: monoid
Date: Wed, 23 Feb 2022 19:46:29 +0900
Subject: [PATCH] feat: batch and collate function

---
 Batch.ipynb | 153 ++++++++++++++++++++++++++++++++++++++++++++++++++++
 ndataset.py |  41 ++++++++++++++
 2 files changed, 194 insertions(+)
 create mode 100644 Batch.ipynb
 create mode 100644 ndataset.py

diff --git a/Batch.ipynb b/Batch.ipynb
new file mode 100644
index 0000000..6eef219
--- /dev/null
+++ b/Batch.ipynb
@@ -0,0 +1,153 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "c916dd3b",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "'C:\\\\Users\\\\Monoid\\\\anaconda3\\\\envs\\\\nn\\\\python.exe'"
+      ]
+     },
+     "execution_count": 1,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "import sys\n",
+    "sys.executable"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "id": "d5861234",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from typing import List\n",
+    "from torch.utils.data import Dataset\n",
+    "import torch\n",
+    "from transformers import PreTrainedTokenizer\n",
+    "from ndata import readNsmcRawData, NsmcRawData"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "5accd3a9",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "load bert tokenizer...\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "100%|██████████████████████████████████████████████████████████████████████| 150000/150000 [00:00<00:00, 205761.43it/s]\n"
+     ]
+    }
+   ],
+   "source": [
+    "from transformers import BertTokenizer\n",
+    "print(\"load bert tokenizer...\")\n",
+    "PRETRAINED_MODEL_NAME = 'bert-base-multilingual-cased'\n",
+    "tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)\n",
+    "\n",
+    "data = readNsmcRawData(\"nsmc/nsmc-master/ratings_train.txt\", use_tqdm=True, total=150000)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "d10fcb83",
+   "metadata": {},
+   "source": [
+    "Prepare the data"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "id": "552fe555",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "({'input_ids': tensor([[ 101, 9519, 9074, 119005, 119, 119, 9708, 119235, 9715,\n",
+      " 119230, 16439, 77884, 48549, 9284, 22333, 12692, 102, 0,\n",
+      " 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
+      " 0, 0, 0, 0],\n",
+      " [ 101, 100, 119, 119, 119, 9928, 58823, 30005, 11664,\n",
+      " 9757, 118823, 30858, 18227, 119219, 119, 119, 119, 119,\n",
+      " 9580, 41605, 25486, 12310, 20626, 23466, 8843, 118986, 12508,\n",
+      " 9523, 17196, 16439, 102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
+      " 0, 0, 0, 0, 0, 0, 0],\n",
+      " [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
+      " 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0,\n",
+      " 0, 0, 0, 0, 0, 0, 0],\n",
+      " [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
+      " 1, 1, 1, 1, 1, 1, 1]])}, tensor([0, 1]))\n"
+     ]
+    }
+   ],
+   "source": [
+    "def make_collate_fn(tokenizer: PreTrainedTokenizer):\n",
+    "    def collate_fn(batch: List[NsmcRawData]):\n",
+    "        labels = [s.label for s in batch]\n",
+    "        return tokenizer([s.document for s in batch], return_tensors='pt', padding='longest', truncation=True), torch.tensor(labels)\n",
+    "    return collate_fn\n",
+    "\n",
+    "collate = make_collate_fn(tokenizer)\n",
+    "print(collate(data[0:2]))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "1cff8e03",
+   "metadata": {},
+   "source": [
+    "A simple collate function"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "89eb64d8",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.7.11"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/ndataset.py b/ndataset.py
new file mode 100644
index 0000000..c06b7ba
--- /dev/null
+++ b/ndataset.py
@@ -0,0 +1,41 @@
+import sys
+from typing import List
+from torch.utils.data import Dataset
+import torch
+from transformers import PreTrainedTokenizer
+from ndata import readNsmcRawData, NsmcRawData
+
+def readNsmcDataAll():
+    """
+    Returns: train, test
+    """
+    print("read train set", file=sys.stderr)
+    train = readNsmcRawData("nsmc/nsmc-master/ratings_train.txt", use_tqdm=True, total=150_000)
+    print("read test set", file=sys.stderr)
+    test = readNsmcRawData("nsmc/nsmc-master/ratings_test.txt", use_tqdm=True, total=50_000)
+    return NsmcDataset(train), NsmcDataset(test)
+
+class NsmcDataset(Dataset):
+    def __init__(self, data: List[NsmcRawData]):
+        self.x = data
+    def __len__(self):
+        return len(self.x)
+    def __getitem__(self, idx):
+        return self.x[idx]
+
+def make_collate_fn(tokenizer: PreTrainedTokenizer):
+    def collate_fn(batch: List[NsmcRawData]):
+        labels = [s.label for s in batch]
+        return tokenizer([s.document for s in batch], return_tensors='pt', padding='longest', truncation=True), torch.tensor(labels)
+    return collate_fn
+
+if __name__ == "__main__":
+    from transformers import BertTokenizer
+    print("load bert tokenizer...")
+    PRETRAINED_MODEL_NAME = 'bert-base-multilingual-cased'
+    tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)
+
+    data = readNsmcRawData("nsmc/nsmc-master/ratings_train.txt", use_tqdm=True, total=150000)
+
+    collate = make_collate_fn(tokenizer)
+    print(collate(data[0:2]))
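Usage note: make_collate_fn returns a closure intended for the collate_fn argument of
torch.utils.data.DataLoader, so each batch of raw NsmcRawData samples is tokenized and
padded to the longest document in that batch. A minimal sketch, assuming the nsmc data
files are laid out at the paths used above; the batch size is illustrative, not part of
the patch:

    from torch.utils.data import DataLoader
    from transformers import BertTokenizer

    from ndataset import readNsmcDataAll, make_collate_fn

    # Same multilingual BERT tokenizer as in the patch.
    tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')

    train_set, test_set = readNsmcDataAll()
    train_loader = DataLoader(
        train_set,
        batch_size=32,   # illustrative choice; the patch does not fix a batch size
        shuffle=True,
        collate_fn=make_collate_fn(tokenizer),
    )

    # Each batch is (BatchEncoding, labels): the encoding carries input_ids,
    # token_type_ids, and attention_mask padded to the batch's longest document.
    inputs, labels = next(iter(train_loader))
    print(inputs['input_ids'].shape, labels.shape)

Because padding='longest' pads per batch rather than to a fixed length, tensor widths
vary from batch to batch, keeping memory proportional to the documents actually drawn.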