154 lines
4.3 KiB
Plaintext
154 lines
4.3 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "c916dd3b",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"'C:\\\\Users\\\\Monoid\\\\anaconda3\\\\envs\\\\nn\\\\python.exe'"
|
|
]
|
|
},
|
|
"execution_count": 1,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"import sys\n",
|
|
"sys.executable"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"id": "d5861234",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from typing import List\n",
|
|
"from torch.utils.data import Dataset\n",
|
|
"import torch\n",
|
|
"from transformers import PreTrainedTokenizer\n",
|
|
"from ndata import readNsmcRawData, NsmcRawData"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"id": "5accd3a9",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"load bert tokenizer...\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"100%|██████████████████████████████████████████████████████████████████████| 150000/150000 [00:00<00:00, 205761.43it/s]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
"from transformers import BertTokenizer\n",
"\n",
"# Checkpoint name handed to BertTokenizer.from_pretrained below.\n",
"PRETAINED_MODEL_NAME = 'bert-base-multilingual-cased'\n",
"\n",
"print(\"load bert tokenizer...\")\n",
"tokenizer = BertTokenizer.from_pretrained(PRETAINED_MODEL_NAME)\n",
"\n",
"# Read the raw NSMC samples from the training ratings file.\n",
"# NOTE(review): `total` is presumably the progress-bar length used with use_tqdm - confirm in ndata.\n",
"data = readNsmcRawData(\n",
"    \"nsmc/nsmc-master/ratings_train.txt\",\n",
"    use_tqdm=True,\n",
"    total=150000,\n",
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "d10fcb83",
|
|
"metadata": {},
|
|
"source": [
|
|
"data를 준비"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"id": "552fe555",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"({'input_ids': tensor([[ 101, 9519, 9074, 119005, 119, 119, 9708, 119235, 9715,\n",
|
|
" 119230, 16439, 77884, 48549, 9284, 22333, 12692, 102, 0,\n",
|
|
" 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
|
" 0, 0, 0, 0],\n",
|
|
" [ 101, 100, 119, 119, 119, 9928, 58823, 30005, 11664,\n",
|
|
" 9757, 118823, 30858, 18227, 119219, 119, 119, 119, 119,\n",
|
|
" 9580, 41605, 25486, 12310, 20626, 23466, 8843, 118986, 12508,\n",
|
|
" 9523, 17196, 16439, 102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
|
" 0, 0, 0, 0, 0, 0, 0],\n",
|
|
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
|
" 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0,\n",
|
|
" 0, 0, 0, 0, 0, 0, 0],\n",
|
|
" [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
|
|
" 1, 1, 1, 1, 1, 1, 1]])}, tensor([0, 1]))\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
"def make_collate_fn(tokenzier: PreTrainedTokenizer):\n",
"    \"\"\"Build a collate function bound to the given tokenizer.\n",
"\n",
"    Args:\n",
"        tokenzier: Tokenizer used to encode every batch. The (misspelled)\n",
"            parameter name is kept unchanged so existing callers keep working.\n",
"\n",
"    Returns:\n",
"        A function that maps a list of NsmcRawData samples to a tuple\n",
"        ``(encoded_batch, labels)``.\n",
"    \"\"\"\n",
"    def collate_fn(batch: List[NsmcRawData]):\n",
"        \"\"\"Tokenize the documents of ``batch`` and stack their labels.\"\"\"\n",
"        labels = [s.label for s in batch]\n",
"        documents = [s.document for s in batch]\n",
"        # Use the tokenizer passed to the factory rather than the notebook-level\n",
"        # `tokenizer` global, so the returned function honors its argument.\n",
"        encoded = tokenzier(documents, return_tensors='pt', padding='longest', truncation=True)\n",
"        return encoded, torch.tensor(labels)\n",
"    return collate_fn\n",
"\n",
"collate = make_collate_fn(tokenizer)\n",
"print(collate(data[0:2]))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "1cff8e03",
|
|
"metadata": {},
|
|
"source": [
|
|
"간단한 collate function"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "89eb64d8",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.7.11"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|