# Python source (68 lines, 1.9 KiB)
from sklearn import datasets
|
|
import numpy as np
|
|
from layer import *
|
|
import os
|
|
import pickle
|
|
import matplotlib
|
|
import matplotlib.pyplot as plt
|
|
import random
|
|
|
|
# Use the Tk-based interactive backend for the plt.show() windows.
matplotlib.use("TkAgg")

# Local cache of the downloaded MNIST arrays, so later runs skip fetch_openml.
PICKLE_DATA_FILENAME = "mnist.pickle"

if not os.path.exists(PICKLE_DATA_FILENAME):
    # With as_frame=False both X (features) and y (labels) come back as NumPy
    # arrays; the labels are presumably strings such as '5' — TODO confirm.
    X, y = datasets.fetch_openml('mnist_784', return_X_y=True, cache=True, as_frame=False)

    # Write to a temporary file and atomically move it into place. If the run
    # is interrupted mid-dump, no truncated cache is left behind (the existence
    # check above would otherwise make every later run fail inside pickle.load).
    tmp_filename = PICKLE_DATA_FILENAME + ".tmp"
    with open(tmp_filename, "wb") as file:
        # Two sequential pickles (X, then y) — the layout the loader expects.
        pickle.dump(X, file)
        pickle.dump(y, file)
    os.replace(tmp_filename, PICKLE_DATA_FILENAME)
else:
    # NOTE(review): pickle.load can execute arbitrary code; only load a cache
    # file this script wrote itself, never one obtained from elsewhere.
    with open(PICKLE_DATA_FILENAME, "rb") as file:
        X = pickle.load(file)
        y = pickle.load(file)
# Scale raw pixel intensities (0..255) into [0, 1].
X = X / 255

# Integer class labels (the loaded labels are not numeric — they need an int
# conversion), then one-hot targets: Y[k] is row y[k] of the 10x10 identity.
y = y.astype(int)
Y = np.eye(10)[y]
# Unseeded random source for weight initialisation and mini-batch sampling.
gen: np.random.Generator = np.random.default_rng()

# Learning rate for the plain SGD update.
eta = 0.01

# Number of samples per mini-batch.
MiniBatchN = 100

# Single linear layer: logits = inputs @ weight1 + bias1.
# Shapes are taken from the data rather than hard-coded as 784 / 10.
n_features = X.shape[1]
n_classes = Y.shape[1]

# Scale the weight std by 1/sqrt(fan-in). With unit-variance weights and
# inputs in [0, 1], each logit sums up to n_features terms, so logits can
# spread by up to sqrt(n_features), which pushes the softmax into saturation.
weight1 = Variable(gen.normal(0, 1 / np.sqrt(n_features), size=(n_features, n_classes)))
bias1 = Variable(gen.normal(0, 1, size=n_classes))
# Mini-batch accuracy, recorded once every 5 iterations.
accuracy_list = []

# Mini-batches are drawn only from the first 60000 rows
# (conventionally MNIST's training split — TODO confirm for this dataset order).
N_TRAIN = 60000

for iteration in range(100):
    # Distinct row indices for this mini-batch. Passing the population size
    # directly avoids materialising a 60000-element array on every step.
    choiced_index = gen.choice(N_TRAIN, MiniBatchN, replace=False)
    batch_targets = Y[choiced_index]

    # Forward pass: linear layer followed by softmax + negative log-likelihood.
    input_var = Variable(X[choiced_index])
    U1 = input_var @ weight1 + bias1
    J = SoftmaxWithNegativeLogLikelihood(U1, batch_targets)

    # Seed back-propagation with a scalar upstream gradient of 1.
    J.backprop(np.ones(()))

    # Plain SGD step. Fresh Variables are built from the updated arrays
    # (presumably so no graph/gradient state carries over — TODO confirm
    # against layer.py).
    weight1 = Variable(weight1.numpy() - weight1.grad * eta)
    bias1 = Variable(bias1.numpy() - bias1.grad * eta)

    if iteration % 5 == 0:
        # J was computed before this step's update, so these figures describe
        # the parameters at the start of the iteration, on this mini-batch.
        print(iteration, 'iteration : avg(J) == ', np.average(J.numpy()))

        # Predicted class = argmax of the softmax output (one row per sample).
        # Rounding the probabilities instead would zero out every row whose top
        # probability is below 0.5 and under-count correct predictions.
        probs = J.softmax_numpy()
        predicted = np.eye(batch_targets.shape[1])[np.argmax(probs, axis=1)]

        # Rows: true class, columns: predicted class.
        confusion = batch_targets.T @ predicted
        accuracy = np.trace(confusion) / MiniBatchN
        print('accuracy : ', accuracy * 100, '%')
        accuracy_list.append(accuracy)
# Accuracy curve. The training loop records one value every 5 iterations
# (0, 5, 10, ...), so the points are placed at those iteration numbers.
accuracy_iterations = np.arange(len(accuracy_list)) * 5
plt.title("accuracy")
plt.plot(accuracy_iterations, accuracy_list)
plt.xlabel("iteration")
plt.ylabel("mini-batch accuracy")
plt.show()

# Confusion matrix of the last evaluated mini-batch
# (rows: true class, columns: predicted class).
plt.title("confusion matrix")
plt.imshow(confusion, cmap='gray')
plt.xlabel("predicted class")
plt.ylabel("true class")
plt.show()