# sgd_hw/p2data.py
# 70 lines, 1.9 KiB, Python — last modified 2021-02-13 00:26:40 +09:00
import numpy as np
import pickle
import math
# --- Problem size -----------------------------------------------------------
# Number of input features per sample (name spelled "DIMENTION" (sic); kept
# unchanged so existing references keep working).
DIMENTION = 3
# Values are drawn from the symmetric interval [-VAR_RANGE, VAR_RANGE].
VAR_RANGE = 1
# Total number of generated samples.
N = 20

# --- Noise ------------------------------------------------------------------
# Noise standard deviation expressed as a fraction ALPHA of VAR_RANGE.
ALPHA = 0.1
SIGMA = VAR_RANGE * ALPHA

# --- Dataset splits (output file, fraction of N) ------------------------------
TrainingSetFilename = "trainingset.pickle"
TrainingSetRatio = 0.85

DevSetFilename = "devset.pickle"
DevSetRatio = 0.05

TestSetFilename = "testset.pickle"
TestSetRatio = 0.10

# File holding the ground-truth weight and bias used to generate the data.
AnswerFilename = "answer.pickle"
def LoadAnswer(filename=None):
    """Load the ground-truth weight and bias stored by ``SaveAnswer``.

    The file holds two consecutive pickles: the weight first, then the bias.

    Args:
        filename: Path of the pickle file to read. When ``None`` (the
            default), the module-level ``AnswerFilename`` is used.

    Returns:
        A ``(w, b)`` tuple of the two unpickled objects, in file order.

    Note:
        ``pickle.load`` can execute arbitrary code; only load files produced
        by this script, never untrusted input.
    """
    if filename is None:
        filename = AnswerFilename
    with open(filename, "rb") as fr:
        w = pickle.load(fr)
        b = pickle.load(fr)
    # Previously both values were read and then discarded, so callers always
    # received None.
    return w, b
def SaveAnswer(w, b, filename=None):
    """Pickle the ground-truth weight ``w`` and bias ``b`` into one file.

    Both objects go into the same file, ``w`` first and then ``b``. They must
    be read back with two successive ``pickle.load`` calls in that order.

    Args:
        w: Weight object to store.
        b: Bias object to store.
        filename: Destination path. When ``None`` (the default), the
            module-level ``AnswerFilename`` is used. An existing file is
            overwritten ("wb" mode).
    """
    if filename is None:
        filename = AnswerFilename
    with open(filename, "wb") as fw:
        pickle.dump(w, fw)
        pickle.dump(b, fw)
def LoadData(filename):
    """Read back an ``(x, y)`` pair stored by ``SaveData``.

    The file holds two consecutive pickles: the inputs first, then the
    targets.

    Args:
        filename: Path of the pickle file to open.

    Returns:
        Tuple ``(x, y)`` of the two unpickled objects, in file order.
    """
    with open(filename, "rb") as stream:
        features = pickle.load(stream)
        targets = pickle.load(stream)
    return features, targets
def SaveData(filename, x, y):
    """Write ``x`` and then ``y`` as two consecutive pickles in ``filename``.

    The file is created or truncated ("wb" mode). ``LoadData`` reads the two
    objects back in the same order.
    """
    with open(filename, "wb") as sink:
        for payload in (x, y):
            pickle.dump(payload, sink)
if __name__ == "__main__":
    # Unseeded generator: every run produces a different dataset.
    gen: np.random.Generator = np.random.default_rng()

    # Ground-truth linear model y = x @ w + b. The weight has shape
    # (DIMENTION, 1) and the bias is drawn with size=(); both are uniform on
    # [-VAR_RANGE, VAR_RANGE].
    or_weight = gen.uniform(high=VAR_RANGE, low=-VAR_RANGE, size=(DIMENTION, 1))
    or_bias = gen.uniform(high=VAR_RANGE, low=-VAR_RANGE, size=())

    # N input rows, each drawn independently from the same uniform range.
    input_x = gen.uniform(low=-VAR_RANGE, high=VAR_RANGE, size=(N, DIMENTION))
    y = input_x @ or_weight + or_bias

    # Additive Gaussian noise with standard deviation SIGMA, shape (N, 1).
    error = gen.normal(0, SIGMA, size=(N, 1))
    y += error

    # Rows are i.i.d. draws, so contiguous slicing already yields random
    # splits; no extra shuffle is needed. (The previous commented-out shuffle
    # referred to an undefined name `x`.)
    print("success to generate dataset")

    SaveAnswer(or_weight, or_bias)

    # Split boundaries. N * ratio is rounded to 9 decimals before ceil() so
    # that float representation noise (e.g. 0.07 * 100 == 7.000000000000001)
    # cannot push a boundary up by one sample.
    TrainingIndex = math.ceil(round(N * TrainingSetRatio, 9))
    DevSetIndex = math.ceil(round(TrainingIndex + N * DevSetRatio, 9))

    TrainingSetX = input_x[0:TrainingIndex]
    TrainingSetY = y[0:TrainingIndex]
    SaveData(TrainingSetFilename, TrainingSetX, TrainingSetY)

    DevSetX = input_x[TrainingIndex:DevSetIndex]
    DevSetY = y[TrainingIndex:DevSetIndex]
    SaveData(DevSetFilename, DevSetX, DevSetY)

    # The test split is whatever remains after train and dev; TestSetRatio is
    # not read here.
    TestSetX = input_x[DevSetIndex:N]
    TestSetY = y[DevSetIndex:N]
    SaveData(TestSetFilename, TestSetX, TestSetY)

    print("success to save dataset")