simple softmax
This commit is contained in:
parent
ac5556f182
commit
3ec4704618
2
.gitignore
vendored
2
.gitignore
vendored
@ -139,3 +139,5 @@ dmypy.json
|
|||||||
cython_debug/
|
cython_debug/
|
||||||
|
|
||||||
sgd_hw/
|
sgd_hw/
|
||||||
|
|
||||||
|
mnist.pickle
|
8
layer.py
8
layer.py
@ -188,8 +188,10 @@ class SoftmaxWithNegativeLogLikelihood(OpTree):
|
|||||||
def __init__(self, i, y):
|
def __init__(self, i, y):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.i = i
|
self.i = i
|
||||||
self.s = softmaxHelp(i)
|
self.s = softmaxHelp(i.numpy())
|
||||||
self.v = -y*self.s
|
self.y = y
|
||||||
|
self.v = -y*np.log(self.s)
|
||||||
|
self.v = np.sum(self.v,axis=self.v.ndim-1)
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return f"SoftmaxWithNegativeLogLikelihoodOp"
|
return f"SoftmaxWithNegativeLogLikelihoodOp"
|
||||||
@ -205,7 +207,7 @@ class SoftmaxWithNegativeLogLikelihood(OpTree):
|
|||||||
return self.s
|
return self.s
|
||||||
|
|
||||||
def backprop(self,seed):
|
def backprop(self,seed):
|
||||||
self.i.backprop(seed * (self.s-y))
|
self.i.backprop(seed * (self.s-self.y))
|
||||||
|
|
||||||
class Variable(OpTree):
|
class Variable(OpTree):
|
||||||
def __init__(self,x):
|
def __init__(self,x):
|
||||||
|
68
p4.py
68
p4.py
@ -1,10 +1,68 @@
|
|||||||
from sklearn import datasets
|
from sklearn import datasets
|
||||||
import numpy as np
|
import numpy as np
|
||||||
X, y = datasets.fetch_openml('mnist_784', version=1, return_X_y=True, cache=True, as_frame= False)
|
from layer import *
|
||||||
print(X,y)
|
import os
|
||||||
|
import pickle
|
||||||
|
import matplotlib
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import random
|
||||||
|
|
||||||
|
matplotlib.use("TkAgg")
|
||||||
|
|
||||||
|
PICKLE_DATA_FILENAME = "mnist.pickle"
|
||||||
|
if not os.path.exists(PICKLE_DATA_FILENAME):
|
||||||
|
X, y = datasets.fetch_openml('mnist_784', return_X_y=True, cache=True, as_frame= False)
|
||||||
|
with open(PICKLE_DATA_FILENAME,"wb") as file:
|
||||||
|
pickle.dump(X,file)
|
||||||
|
pickle.dump(y,file)
|
||||||
|
else:
|
||||||
|
with open(PICKLE_DATA_FILENAME,"rb") as file:
|
||||||
|
X = pickle.load(file)
|
||||||
|
y = pickle.load(file)
|
||||||
|
|
||||||
|
i = random.randint(0,len(X) - 1)
|
||||||
|
#plt.imshow(X[0].reshape(28,28),cmap='gray',interpolation='none')
|
||||||
|
#plt.show()
|
||||||
|
|
||||||
|
#simple normalize
|
||||||
|
X = X / 255
|
||||||
|
|
||||||
|
y = np.array([int(i) for i in y])
|
||||||
|
Y = np.eye(10)[y]
|
||||||
|
|
||||||
gen:np.random.Generator = np.random.default_rng()
|
gen:np.random.Generator = np.random.default_rng()
|
||||||
|
eta = 0.01
|
||||||
|
|
||||||
input_var = Variable(X)
|
MiniBatchN = 100
|
||||||
weight = Variable(gen.normal(100,784))
|
|
||||||
bias = Variable(np.array([1]))
|
weight1 = Variable(gen.normal(0,1,size=(784,10)))
|
||||||
|
bias1 = Variable(gen.normal(0,1,size=(10)))
|
||||||
|
|
||||||
|
accuracy_list = []
|
||||||
|
|
||||||
|
for iteration in range(0,100):
|
||||||
|
choiced_index = gen.choice(range(0,60000),MiniBatchN)
|
||||||
|
input_var = Variable(X[choiced_index])
|
||||||
|
U1 = (input_var @ weight1 + bias1)
|
||||||
|
J = SoftmaxWithNegativeLogLikelihood(U1,Y[choiced_index])
|
||||||
|
|
||||||
|
J.backprop(np.ones(()))
|
||||||
|
#update variable
|
||||||
|
weight1 = Variable(weight1.numpy() - (weight1.grad) * eta)
|
||||||
|
bias1 = Variable(bias1.numpy() - (bias1.grad) * eta)
|
||||||
|
if iteration % 5 == 0:
|
||||||
|
print(iteration,'iteration : avg(J) == ',np.average(J.numpy()))
|
||||||
|
s = J.softmax_numpy()
|
||||||
|
#print(Y[0:1000].shape)
|
||||||
|
s = np.round(s)
|
||||||
|
confusion = (np.transpose(Y[choiced_index])@s)
|
||||||
|
accuracy = np.trace(confusion).sum() / MiniBatchN
|
||||||
|
print('accuracy : ',accuracy * 100,'%')
|
||||||
|
accuracy_list.append(accuracy)
|
||||||
|
|
||||||
|
plt.title("accuracy")
|
||||||
|
plt.plot(np.linspace(0,len(accuracy_list),len(accuracy_list)),accuracy_list)
|
||||||
|
plt.show()
|
||||||
|
plt.title("confusion matrix")
|
||||||
|
plt.imshow(confusion,cmap='gray')
|
||||||
|
plt.show()
|
Loading…
Reference in New Issue
Block a user