nsmc-study/Batch.ipynb

154 lines
4.3 KiB
Plaintext
Raw Normal View History

2022-02-23 19:46:29 +09:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "c916dd3b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'C:\\\\Users\\\\Monoid\\\\anaconda3\\\\envs\\\\nn\\\\python.exe'"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Show which Python interpreter backs this kernel.\n",
"import sys\n",
"\n",
"sys.executable"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "d5861234",
"metadata": {},
"outputs": [],
"source": [
"from typing import List\n",
"from torch.utils.data import Dataset\n",
"import torch\n",
"from transformers import PreTrainedTokenizer\n",
"from ndata import readNsmcRawData, NsmcRawData"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "5accd3a9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"load bert tokenizer...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████████████████████████████████████████████████████████████████| 150000/150000 [00:00<00:00, 205761.43it/s]\n"
]
}
],
"source": [
"from transformers import BertTokenizer\n",
"# Name of the pretrained checkpoint handed to BertTokenizer.from_pretrained.\n",
"PRETAINED_MODEL_NAME = 'bert-base-multilingual-cased'\n",
"\n",
"print(\"load bert tokenizer...\")\n",
"tokenizer = BertTokenizer.from_pretrained(PRETAINED_MODEL_NAME)\n",
"\n",
"# Read the NSMC training ratings file; use_tqdm/total are forwarded to the reader\n",
"# (the recorded stderr output shows a 150000-step progress bar).\n",
"data = readNsmcRawData(\n",
"    \"nsmc/nsmc-master/ratings_train.txt\",\n",
"    use_tqdm=True,\n",
"    total=150000,\n",
")"
]
},
{
"cell_type": "markdown",
"id": "d10fcb83",
"metadata": {},
"source": [
"data를 준비"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "552fe555",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"({'input_ids': tensor([[ 101, 9519, 9074, 119005, 119, 119, 9708, 119235, 9715,\n",
" 119230, 16439, 77884, 48549, 9284, 22333, 12692, 102, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0],\n",
" [ 101, 100, 119, 119, 119, 9928, 58823, 30005, 11664,\n",
" 9757, 118823, 30858, 18227, 119219, 119, 119, 119, 119,\n",
" 9580, 41605, 25486, 12310, 20626, 23466, 8843, 118986, 12508,\n",
" 9523, 17196, 16439, 102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0],\n",
" [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
" 1, 1, 1, 1, 1, 1, 1]])}, tensor([0, 1]))\n"
]
}
],
"source": [
"def make_collate_fn(tokenzier: PreTrainedTokenizer):\n",
"    \"\"\"Build a collate function bound to the given tokenizer.\n",
"\n",
"    Args:\n",
"        tokenzier: Tokenizer used to encode each sample's ``document``.\n",
"            The parameter keeps its original spelling so that existing\n",
"            keyword callers continue to work.\n",
"\n",
"    Returns:\n",
"        A function that maps a list of samples to ``(encoded_batch, labels)``.\n",
"    \"\"\"\n",
"    def collate_fn(batch: List[NsmcRawData]):\n",
"        \"\"\"Tokenize the documents in ``batch`` and gather their labels into a tensor.\"\"\"\n",
"        documents = [sample.document for sample in batch]\n",
"        labels = [sample.label for sample in batch]\n",
"        # Encode with the tokenizer captured from the factory argument rather than\n",
"        # a notebook-level global, so the function honours whatever tokenizer it\n",
"        # was built with and does not depend on cell execution order.\n",
"        encoded = tokenzier(documents, return_tensors='pt', padding='longest', truncation=True)\n",
"        return encoded, torch.tensor(labels)\n",
"    return collate_fn\n",
"\n",
"# Quick check: collate the first two samples and print the resulting batch.\n",
"collate = make_collate_fn(tokenizer)\n",
"sample_batch = data[0:2]\n",
"print(collate(sample_batch))"
]
},
{
"cell_type": "markdown",
"id": "1cff8e03",
"metadata": {},
"source": [
"간단한 collate function"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "89eb64d8",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.11"
}
},
"nbformat": 4,
"nbformat_minor": 5
}