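# sgd_hw/mnist_load.py
# Loads MNIST via scikit-learn's OpenML fetcher, caches it locally, and splits
# it into train/dev/test sets.
# Snapshot metadata: 41 lines, 1.1 KiB, Python; last changed 2021-02-25 21:34:10 +09:00.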
import os
import pickle
import random
from sklearn import datasets
import numpy as np
# Name of the pickle file (relative to the current working directory) that
# caches the raw MNIST arrays between runs.
PICKLE_DATA_FILENAME = "mnist.pickle"

# Dataset splits; each stays None until load_mnistdata() assigns it.
train_x = train_y = None
dev_x = dev_y = None
test_x = test_y = None
def _fetch_mnist_cached(cache_path):
    """Return the raw MNIST ``(X, y)`` arrays, using a pickle cache at *cache_path*.

    The cache holds two consecutive pickles: ``X`` followed by ``y``. When the
    cache is missing, or exists but cannot be unpickled (for example because an
    earlier write was interrupted), the data is fetched with
    ``datasets.fetch_openml`` and the cache is rewritten.
    """
    if os.path.exists(cache_path):
        try:
            # NOTE(review): pickle.load can execute arbitrary code; this is only
            # acceptable because the cache is produced by this module. Never
            # point the cache at a file from an untrusted source.
            with open(cache_path, "rb") as file:
                X = pickle.load(file)
                y = pickle.load(file)
            return X, y
        except (EOFError, pickle.UnpicklingError):
            pass  # truncated/corrupt cache: fall through and download again

    X, y = datasets.fetch_openml('mnist_784', return_X_y=True, cache=True, as_frame=False)
    # Dump to a sibling temp file and swap it in with os.replace, so an
    # interrupted dump can never leave a truncated cache behind.
    tmp_path = os.fspath(cache_path) + ".tmp"
    with open(tmp_path, "wb") as file:
        pickle.dump(X, file)
        pickle.dump(y, file)
    os.replace(tmp_path, cache_path)
    return X, y


def load_mnistdata():
    """Load MNIST, normalize it, one-hot encode its labels, and split it.

    Raw data comes from the pickle cache named by ``PICKLE_DATA_FILENAME``
    (resolved against the current working directory). On first use, or when
    that cache is unreadable, the cache is (re)created by downloading
    ``mnist_784`` with ``datasets.fetch_openml``.

    Side effect: sets the module-level globals ``train_x``, ``train_y``,
    ``dev_x``, ``dev_y``, ``test_x`` and ``test_y`` to the returned arrays.

    Returns:
        ``((train_x, train_y), (dev_x, dev_y), (test_x, test_y))``.

        - The ``*_x`` arrays are the raw feature rows divided by 255.
        - The ``*_y`` arrays are one-hot rows taken from ``np.eye(10)``, so
          labels are expected to be integers (or integer strings) in 0..9.
        - Rows ``[0, 59500)`` form the training set, ``[59500, 63000)`` the
          dev set and ``[63000, 70000)`` the test set. Any rows past 70000
          are unused.
    """
    global train_x, train_y, dev_x, dev_y, test_x, test_y

    X, y = _fetch_mnist_cached(PICKLE_DATA_FILENAME)

    X = X / 255  # simple normalization
    labels = np.array([int(label) for label in y])
    Y = np.eye(10)[labels]  # one-hot: row i of the 10x10 identity for label i

    # Split boundaries in units of 3500 rows: 17 train, 1 dev, 2 test.
    # NOTE(review): OpenML's mnist_784 is usually split at row 60000 into the
    # official train/test sets. This dev slice straddles that boundary, so
    # dev/test here are not the official MNIST test set. Confirm this is intended.
    train_end, dev_end, test_end = 3500 * 17, 3500 * 18, 3500 * 20
    train_x, train_y = X[:train_end], Y[:train_end]
    dev_x, dev_y = X[train_end:dev_end], Y[train_end:dev_end]
    test_x, test_y = X[dev_end:test_end], Y[dev_end:test_end]
    return (train_x, train_y), (dev_x, dev_y), (test_x, test_y)