feat: train, dev, test

This commit is contained in:
monoid 2022-02-27 19:50:13 +09:00
parent 9fcd0786b1
commit 8a1442995b
3 changed files with 656 additions and 98 deletions

View File

@ -3,7 +3,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 1, "execution_count": 1,
"id": "c916dd3b", "id": "5a4a1e30",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
@ -25,7 +25,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 2, "execution_count": 2,
"id": "d5861234", "id": "710cd5b2",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -39,7 +39,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 3, "execution_count": 3,
"id": "5accd3a9", "id": "da018ffe",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
@ -68,7 +68,7 @@
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "d10fcb83", "id": "69f05cf6",
"metadata": {}, "metadata": {},
"source": [ "source": [
"data를 준비" "data를 준비"
@ -77,7 +77,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 7, "execution_count": 7,
"id": "552fe555", "id": "961edd10",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
@ -114,7 +114,7 @@
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "1cff8e03", "id": "4178b576",
"metadata": {}, "metadata": {},
"source": [ "source": [
"간단한 collate function" "간단한 collate function"
@ -123,7 +123,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"id": "89eb64d8", "id": "a5ff0049",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [] "source": []

File diff suppressed because one or more lines are too long

View File

@ -7,13 +7,15 @@ from ndata import readNsmcRawData, NsmcRawData
def readNsmcDataAll(): def readNsmcDataAll():
""" """
Returns: train, test Returns: train, dev, test
""" """
print("read train set", file=sys.stderr) print("read train set", file=sys.stderr)
train = readNsmcRawData("nsmc/nsmc-master/ratings_train.txt",use_tqdm=True,total=150_000) train = readNsmcRawData("nsmc/nsmc-master/ratings_train.txt",use_tqdm=True,total=150_000)
print("read test set", file=sys.stderr) print("read test set", file=sys.stderr)
test = readNsmcRawData("nsmc/nsmc-master/ratings_test.txt",use_tqdm=True,total=50_000) testBig = readNsmcRawData("nsmc/nsmc-master/ratings_test.txt",use_tqdm=True,total=50_000)
return NsmcDataset(train),NsmcDataset(test) test = testBig[:30_000]
dev = testBig[30_000:]
return NsmcDataset(train),NsmcDataset(dev),NsmcDataset(test)
class NsmcDataset(Dataset): class NsmcDataset(Dataset):
def __init__(self, data: List[NsmcRawData]): def __init__(self, data: List[NsmcRawData]):