# sgd_hw/layer.py
# Python, 192 lines (5.2 KiB), dated 2021-02-12 15:50:20 +09:00
import numpy as np
from contextlib import contextmanager
from typing import Generator, Dict, Union
import io
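# Only scalar gradients are supported.
# The op graph must be a tree: supporting a general graph would require
# traversing it in topological order, which is not implemented here.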
class NonExistVarableError(ValueError):
    """Error type presumably meant for references to a variable that does not exist.

    NOTE(review): the purpose is inferred from the name only; nothing in this
    part of the file raises it. The original spelling ("Varable") is kept so
    existing imports and ``except`` clauses keep working.
    """
def make_mermaid_graph(result):
    """Render the op tree rooted at *result* as Mermaid flowchart text.

    The output starts with a ``graph TD`` header, followed by whatever
    ``result.mermaid_graph`` writes. It ends with an edge from *result*
    (identified by ``id(result)``) to a final ``Result`` node.

    Args:
        result: root node; must provide ``mermaid_graph(writer)``.

    Returns:
        The complete Mermaid source as a string.
    """
    buffer = io.StringIO()
    try:
        buffer.write("graph TD\n")
        result.mermaid_graph(buffer)
        buffer.write(f"{id(result)}-->Result\n")
        return buffer.getvalue()
    finally:
        buffer.close()
class OpTree:
    """Base class for nodes of the expression tree.

    Operator overloads build new nodes instead of computing values directly:
    ``x @ y`` creates a ``MatMulOp``, ``x + y`` an ``AddOp``, and ``x.T`` a
    ``TransposeOp``.
    """

    def __init__(self):
        super().__init__()

    def __matmul__(self, a):
        # self @ a  ->  new MatMulOp node
        return MatMulOp(self, a)

    def __add__(self, a):
        # self + a  ->  new AddOp node
        return AddOp(self, a)

    @property
    def T(self):
        """Transpose of this node, as a new ``TransposeOp``."""
        return TransposeOp(self)
class MatMulOp(OpTree):
    """Tree node for the matrix product ``a @ b``.

    ``a`` and ``b`` may each be another ``OpTree`` node or a plain value
    (e.g. a NumPy array). Only ``OpTree`` operands receive gradients.
    The product is computed eagerly in the constructor and cached in ``v``.
    """

    def __init__(self, a, b):
        super().__init__()
        self.a = a
        self.b = b
        self.v = self._value(self.a) @ self._value(self.b)

    @staticmethod
    def _value(x):
        """Return ``x.numpy()`` for tree nodes, otherwise *x* itself."""
        return x.numpy() if isinstance(x, OpTree) else x

    def __str__(self):
        return "MatmulOp"

    def mermaid_graph(self, writer):
        """Write each operand subtree, then its edge into this node (a first, then b)."""
        for operand in (self.a, self.b):
            if isinstance(operand, OpTree):
                operand.mermaid_graph(writer)
                writer.write(f'{id(operand)}-->{id(self)}[MatmulOp]\n')

    def numpy(self):
        return self.v

    def backprop(self, seed):
        """Propagate the upstream gradient *seed* into the ``OpTree`` operands.

        For ``C = A @ B`` with upstream gradient ``G``:
        ``dA = G @ B.T`` and ``dB = A.T @ G``.

        A 0-d *seed* stands for the same gradient on every element of ``C``.
        For a 2-D ``C`` it is expanded to ``C``'s shape before the rule above
        is applied. Multiplying by ``B.T`` / ``A.T`` directly would only be
        correct when ``C`` is 1x1.
        """
        a = self._value(self.a)
        b = self._value(self.b)
        seed = np.asarray(seed)
        if seed.shape == () and np.ndim(self.v) == 2:
            seed = np.broadcast_to(seed, np.shape(self.v))
        scalar_seed = seed.shape == ()
        if isinstance(self.a, OpTree):
            grad_a = seed * np.transpose(b) if scalar_seed else seed @ np.transpose(b)
            self.a.backprop(grad_a)
        if isinstance(self.b, OpTree):
            grad_b = np.transpose(a) * seed if scalar_seed else np.transpose(a) @ seed
            self.b.backprop(grad_b)
def matmul(a, b):
    """Functional spelling of ``a @ b``: build and return a ``MatMulOp`` node."""
    node = MatMulOp(a, b)
    return node
class AddOp(OpTree):
    """Tree node for the element-wise sum ``a + b`` (NumPy broadcasting applies).

    ``a`` and ``b`` may each be another ``OpTree`` node or a plain value.
    Only ``OpTree`` operands receive gradients. The sum is computed eagerly
    in the constructor and cached in ``v``.
    """

    def __init__(self, a, b):
        super().__init__()
        self.a = a
        self.b = b
        va = self.a.numpy() if isinstance(self.a, OpTree) else self.a
        vb = self.b.numpy() if isinstance(self.b, OpTree) else self.b
        self.v = va + vb

    def __str__(self):
        return "AddOp"

    def mermaid_graph(self, writer):
        """Write each operand subtree, then its edge into this node (a first, then b)."""
        for operand in (self.a, self.b):
            if isinstance(operand, OpTree):
                operand.mermaid_graph(writer)
                writer.write(f'{id(operand)}-->{id(self)}[AddOp]\n')

    def numpy(self):
        return self.v

    @staticmethod
    def _unbroadcast(grad, shape):
        """Sum *grad* over the axes that broadcasting expanded, so it fits *shape*.

        A 0-d *grad*, or one that already has *shape*, is returned unchanged
        (the same object).
        """
        arr = np.asarray(grad)
        shape = tuple(shape)
        if arr.ndim == 0 or arr.shape == shape or arr.ndim < len(shape):
            return grad
        # Leading axes that broadcasting prepended to a lower-rank operand.
        lead = arr.ndim - len(shape)
        if lead:
            arr = arr.sum(axis=tuple(range(lead)))
        # Axes where the operand had length 1 but the result did not.
        stretched = tuple(
            axis for axis, size in enumerate(shape)
            if size == 1 and arr.shape[axis] != 1
        )
        if stretched:
            arr = arr.sum(axis=stretched, keepdims=True)
        return arr

    def backprop(self, seed):
        """Pass *seed* to both ``OpTree`` operands (the local derivative is 1).

        If an operand was broadcast in the forward pass (e.g. a ``(1, n)`` bias
        added to an ``(m, n)`` matrix), its gradient is summed over the
        broadcast axes so it has the operand's own shape.
        """
        for operand in (self.a, self.b):
            if isinstance(operand, OpTree):
                operand.backprop(self._unbroadcast(seed, np.shape(operand.numpy())))
def addmul(a, b):
    """Functional spelling of ``a + b``: build and return an ``AddOp`` node."""
    node = AddOp(a, b)
    return node
class FunctionOp(OpTree):
    """Tree node applying a scalar function element-wise to one input node.

    Args:
        f: scalar function, applied element-wise through ``np.vectorize``.
        f_grad: scalar derivative of ``f``, applied element-wise the same way.
        f_name: label used by ``__str__`` and the Mermaid graph.
        i: input node; must provide ``numpy()``, ``backprop()`` and
            ``mermaid_graph()``.
    """

    def __init__(self, f, f_grad, f_name, i):
        super().__init__()
        # NOTE(review): np.vectorize without ``otypes`` infers the output dtype
        # from the result for the first element, so an ``f`` that returns an
        # int there would truncate later float results.
        self.f = np.vectorize(f)
        self.f_grad = np.vectorize(f_grad)
        self.f_name = f_name
        self.i = i
        self.v = self.f(i.numpy())

    def __str__(self):
        return f"Function{self.f_name}Op"

    def mermaid_graph(self, writer):
        """Write the input's subtree, then the edge from the input into this node."""
        self.i.mermaid_graph(writer)
        label = f"Function{self.f_name}Op"
        writer.write(f'{id(self.i)}-->{id(self)}[{label}]\n')

    def numpy(self):
        return self.v

    def backprop(self, seed):
        # Chain rule: upstream gradient times f'(input), element-wise.
        local_grad = self.f_grad(self.i.numpy())
        self.i.backprop(seed * local_grad)
class TransposeOp(OpTree):
    """Tree node for ``np.transpose`` of a single input node (computed eagerly)."""

    def __init__(self, i):
        super().__init__()
        self.i = i
        self.v = np.transpose(i.numpy())

    def __str__(self):
        return "TransposeOp"

    def mermaid_graph(self, writer):
        """Write the input's subtree, then the edge from the input into this node."""
        self.i.mermaid_graph(writer)
        writer.write(f'{id(self.i)}-->{id(self)}[TransposeOp]\n')

    def numpy(self):
        return self.v

    def backprop(self, seed):
        # np.transpose reverses the axes, and reversing twice restores the
        # original order, so transposing the upstream gradient maps it back
        # onto the input's layout.
        upstream = np.transpose(seed)
        self.i.backprop(upstream)
def transposemul(a):
    """Functional spelling of ``a.T``: build and return a ``TransposeOp`` node."""
    node = TransposeOp(a)
    return node
def relu(v):
    """Build a ``FunctionOp`` computing the element-wise ReLU ``max(x, 0)`` of node *v*.

    The derivative used in backprop is 1 where ``x > 0`` and 0 otherwise
    (including at ``x == 0``).
    """
    def _relu(x):
        return np.max([x, 0])

    def _relu_grad(x):
        return 1 if x > 0 else 0

    return FunctionOp(_relu, _relu_grad, "Relu", v)
class Variable(OpTree):
    """Leaf node of the expression tree that wraps a value and receives a gradient."""

    def __init__(self, x):
        super().__init__()
        self.x = x        # wrapped value, returned as-is by numpy()
        self.grad = None  # set by backprop()

    def numpy(self):
        return self.x

    def mermaid_graph(self, writer):
        """Write this leaf as a labelled Mermaid node, e.g. ``Variable(1, 3)``."""
        # np.shape also handles Python scalars and lists, which have no
        # ``.shape`` attribute; for ndarrays it returns the same tuple.
        writer.write(f'{id(self)}["Variable{np.shape(self.x)}"]\n')

    def backprop(self, seed):
        """Store the incoming gradient *seed* in ``self.grad``.

        The value is overwritten, not accumulated. The file assumes the
        expression is a tree, so each leaf should be reached at most once per
        backprop. If a Variable is used twice, only the last gradient is kept.
        """
        self.grad = seed
"""
input_var = Variable(np.array([[1],[2],[3]]))
weight = Variable(np.array([[2,-1,1]]))
v = relu(weight @ input_var)
print(f"result : {v.numpy()}")
v.backprop(np.ones(()))
print(f"grad input : {input_var.grad}, w : {weight.grad}")
"""
#input_diff = Variable(np.array([[1.01],[2],[3]]))
#v_diff = relu(weight @ input_diff)
#print(f"diff 1 : {(np.sum(v_diff.numpy()) - v.numpy()) / 0.01}")
#i -= grad * delta
"""
graph TD
2284612545696["Variable(1, 3)"]
2284612545696-->2284612624880[MatmulOp]
2284612544496["Variable(3, 2)"]
2284612544496-->2284612624880[MatmulOp]
2284612624880-->2284612625072[FunctionReluOp]
2284612625072-->2284612627856[MatmulOp]
2284612627856-->Result
"""