Compare commits
2 Commits
bc504fce74
...
8a1442995b
Author | SHA1 | Date | |
---|---|---|---|
8a1442995b | |||
9fcd0786b1 |
3
.gitignore
vendored
3
.gitignore
vendored
@ -2,4 +2,5 @@ nsmc/**/*
|
||||
nsmc.zip
|
||||
.ipynb_checkpoints/**/*
|
||||
__pycache__/**/*
|
||||
model.zip
|
||||
model.zip
|
||||
model/**/*
|
14
Batch.ipynb
14
Batch.ipynb
@ -3,7 +3,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "c916dd3b",
|
||||
"id": "5a4a1e30",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
@ -25,7 +25,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "d5861234",
|
||||
"id": "710cd5b2",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -39,7 +39,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "5accd3a9",
|
||||
"id": "da018ffe",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
@ -68,7 +68,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "d10fcb83",
|
||||
"id": "69f05cf6",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"data를 준비"
|
||||
@ -77,7 +77,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "552fe555",
|
||||
"id": "961edd10",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
@ -114,7 +114,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "1cff8e03",
|
||||
"id": "4178b576",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"간단한 collate function"
|
||||
@ -123,7 +123,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "89eb64d8",
|
||||
"id": "a5ff0049",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
|
732
Training.ipynb
732
Training.ipynb
File diff suppressed because one or more lines are too long
@ -7,13 +7,15 @@ from ndata import readNsmcRawData, NsmcRawData
|
||||
|
||||
def readNsmcDataAll():
|
||||
"""
|
||||
Returns: train, test
|
||||
Returns: train, dev, test
|
||||
"""
|
||||
print("read train set", file=sys.stderr)
|
||||
train = readNsmcRawData("nsmc/nsmc-master/ratings_train.txt",use_tqdm=True,total=150_000)
|
||||
print("read test set", file=sys.stderr)
|
||||
test = readNsmcRawData("nsmc/nsmc-master/ratings_test.txt",use_tqdm=True,total=50_000)
|
||||
return NsmcDataset(train),NsmcDataset(test)
|
||||
testBig = readNsmcRawData("nsmc/nsmc-master/ratings_test.txt",use_tqdm=True,total=50_000)
|
||||
test = testBig[:30_000]
|
||||
dev = testBig[30_000:]
|
||||
return NsmcDataset(train),NsmcDataset(dev),NsmcDataset(test)
|
||||
|
||||
class NsmcDataset(Dataset):
|
||||
def __init__(self, data: List[NsmcRawData]):
|
||||
|
Loading…
Reference in New Issue
Block a user