feat: train, dev, test

This commit is contained in:
monoid 2022-02-27 19:50:13 +09:00
parent 9fcd0786b1
commit 8a1442995b
3 changed files with 656 additions and 98 deletions

View File

@ -3,7 +3,7 @@
{
"cell_type": "code",
"execution_count": 1,
"id": "c916dd3b",
"id": "5a4a1e30",
"metadata": {},
"outputs": [
{
@ -25,7 +25,7 @@
{
"cell_type": "code",
"execution_count": 2,
"id": "d5861234",
"id": "710cd5b2",
"metadata": {},
"outputs": [],
"source": [
@ -39,7 +39,7 @@
{
"cell_type": "code",
"execution_count": 3,
"id": "5accd3a9",
"id": "da018ffe",
"metadata": {},
"outputs": [
{
@ -68,7 +68,7 @@
},
{
"cell_type": "markdown",
"id": "d10fcb83",
"id": "69f05cf6",
"metadata": {},
"source": [
"data를 준비"
@ -77,7 +77,7 @@
{
"cell_type": "code",
"execution_count": 7,
"id": "552fe555",
"id": "961edd10",
"metadata": {},
"outputs": [
{
@ -114,7 +114,7 @@
},
{
"cell_type": "markdown",
"id": "1cff8e03",
"id": "4178b576",
"metadata": {},
"source": [
"간단한 collate function"
@ -123,7 +123,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "89eb64d8",
"id": "a5ff0049",
"metadata": {},
"outputs": [],
"source": []

File diff suppressed because one or more lines are too long

View File

@ -7,13 +7,15 @@ from ndata import readNsmcRawData, NsmcRawData
def readNsmcDataAll():
    """
    Read the NSMC rating files and build the train, dev and test datasets.

    The test file is loaded once and then split by position: the first
    30,000 records become the test set and every record after them becomes
    the dev set.

    Returns: train, dev, test (each wrapped in NsmcDataset)
    """
    print("read train set", file=sys.stderr)
    # `total` is passed together with use_tqdm=True; presumably it is the
    # expected row count for the progress bar -- confirm in ndata.
    train = readNsmcRawData("nsmc/nsmc-master/ratings_train.txt",use_tqdm=True,total=150_000)
    print("read test set", file=sys.stderr)
    # Read the test file a single time; both held-out splits are sliced
    # from this one result (assumes the reader returns a sliceable
    # sequence such as a list -- TODO confirm).
    testBig = readNsmcRawData("nsmc/nsmc-master/ratings_test.txt",use_tqdm=True,total=50_000)
    test = testBig[:30_000]
    dev = testBig[30_000:]
    # Order matters for callers that unpack the result: train, dev, test.
    return NsmcDataset(train),NsmcDataset(dev),NsmcDataset(test)
class NsmcDataset(Dataset):
def __init__(self, data: List[NsmcRawData]):