feat: batch and collate function
This commit is contained in:
parent
58fba0bd3c
commit
fbc8a25e30
153
Batch.ipynb
Normal file
153
Batch.ipynb
Normal file
@ -0,0 +1,153 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"id": "c916dd3b",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"'C:\\\\Users\\\\Monoid\\\\anaconda3\\\\envs\\\\nn\\\\python.exe'"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 1,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"import sys\n",
|
||||||
|
"sys.executable"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"id": "d5861234",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from typing import List\n",
|
||||||
|
"from torch.utils.data import Dataset\n",
|
||||||
|
"import torch\n",
|
||||||
|
"from transformers import PreTrainedTokenizer\n",
|
||||||
|
"from ndata import readNsmcRawData, NsmcRawData"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"id": "5accd3a9",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"load bert tokenizer...\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"100%|██████████████████████████████████████████████████████████████████████| 150000/150000 [00:00<00:00, 205761.43it/s]\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"from transformers import BertTokenizer\n",
|
||||||
|
"print(\"load bert tokenizer...\")\n",
|
||||||
|
"PRETRAINED_MODEL_NAME = 'bert-base-multilingual-cased'\n",
|
||||||
|
"tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)\n",
|
||||||
|
"\n",
|
||||||
|
"data = readNsmcRawData(\"nsmc/nsmc-master/ratings_train.txt\",use_tqdm=True,total=150000)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "d10fcb83",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"data를 준비"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 7,
|
||||||
|
"id": "552fe555",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"({'input_ids': tensor([[ 101, 9519, 9074, 119005, 119, 119, 9708, 119235, 9715,\n",
|
||||||
|
" 119230, 16439, 77884, 48549, 9284, 22333, 12692, 102, 0,\n",
|
||||||
|
" 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||||||
|
" 0, 0, 0, 0],\n",
|
||||||
|
" [ 101, 100, 119, 119, 119, 9928, 58823, 30005, 11664,\n",
|
||||||
|
" 9757, 118823, 30858, 18227, 119219, 119, 119, 119, 119,\n",
|
||||||
|
" 9580, 41605, 25486, 12310, 20626, 23466, 8843, 118986, 12508,\n",
|
||||||
|
" 9523, 17196, 16439, 102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||||||
|
" 0, 0, 0, 0, 0, 0, 0],\n",
|
||||||
|
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||||||
|
" 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0,\n",
|
||||||
|
" 0, 0, 0, 0, 0, 0, 0],\n",
|
||||||
|
" [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
|
||||||
|
" 1, 1, 1, 1, 1, 1, 1]])}, tensor([0, 1]))\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"def make_collate_fn(tokenizer: PreTrainedTokenizer):\n",
|
||||||
|
" def collate_fn(batch: List[NsmcRawData]):\n",
|
||||||
|
" labels = [s.label for s in batch]\n",
|
||||||
|
" return tokenizer([s.document for s in batch], return_tensors='pt', padding='longest', truncation=True), torch.tensor(labels)\n",
|
||||||
|
" return collate_fn\n",
|
||||||
|
"\n",
|
||||||
|
"collate = make_collate_fn(tokenizer)\n",
|
||||||
|
"print(collate(data[0:2]))"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "1cff8e03",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"간단한 collate function"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "89eb64d8",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": []
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3 (ipykernel)",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.7.11"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
41
ndataset.py
Normal file
41
ndataset.py
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
import sys
|
||||||
|
from typing import List
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
import torch
|
||||||
|
from transformers import PreTrainedTokenizer
|
||||||
|
from ndata import readNsmcRawData, NsmcRawData
|
||||||
|
|
||||||
|
def readNsmcDataAll():
    """
    Read the NSMC train and test splits and wrap them as Datasets.

    Returns: train, test
    """
    print("read train set", file=sys.stderr)
    train_split = readNsmcRawData("nsmc/nsmc-master/ratings_train.txt", use_tqdm=True, total=150_000)

    print("read test set", file=sys.stderr)
    test_split = readNsmcRawData("nsmc/nsmc-master/ratings_test.txt", use_tqdm=True, total=50_000)

    return NsmcDataset(train_split), NsmcDataset(test_split)
|
||||||
|
|
||||||
|
class NsmcDataset(Dataset):
    """Thin map-style Dataset over a list of NsmcRawData samples."""

    def __init__(self, data: List[NsmcRawData]):
        # Keep a reference to the raw sample list; items are returned as-is.
        self.x = data

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return self.x[idx]
|
||||||
|
|
||||||
|
def make_collate_fn(tokenizer: "PreTrainedTokenizer"):
    """
    Build a DataLoader collate_fn that tokenizes a batch of NSMC samples.

    Bug fix: the parameter was misspelled 'tokenzier', so the closure below
    silently captured the module-level 'tokenizer' global instead of the
    argument — any tokenizer passed in was ignored. The parameter is now
    spelled 'tokenizer' so the closure binds the argument.

    Args:
        tokenizer: tokenizer applied to each sample's document text.

    Returns:
        A collate_fn mapping List[NsmcRawData] to a pair of
        (tokenizer output as padded 'pt' tensors, label tensor).
    """
    def collate_fn(batch: "List[NsmcRawData]"):
        labels = [s.label for s in batch]
        encoded = tokenizer([s.document for s in batch], return_tensors='pt',
                            padding='longest', truncation=True)
        return encoded, torch.tensor(labels)
    return collate_fn
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Smoke test: tokenize the first two training samples with the collate fn.
    from transformers import BertTokenizer

    print("load bert tokenizer...")
    PRETRAINED_MODEL_NAME = 'bert-base-multilingual-cased'
    tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)

    data = readNsmcRawData("nsmc/nsmc-master/ratings_train.txt", use_tqdm=True, total=150000)

    collate = make_collate_fn(tokenizer)
    sample_batch = data[0:2]
    print(collate(sample_batch))
|
Loading…
Reference in New Issue
Block a user