import io

import numpy as np

# Only scalar gradients are supported: backprop is seeded with a scalar.
# Ops must form a tree, not a general graph. Supporting a graph would require
# a topological sort before traversal, so that is deliberately avoided here.


class NonExistVariableError(ValueError):
    pass


def make_mermaid_graph(result):
    """Render the op tree rooted at `result` as a Mermaid `graph TD` diagram."""
    with io.StringIO() as graph:
        graph.write("graph TD\n")
        result.mermaid_graph(graph)
        graph.write(f"{id(result)}-->Result\n")
        return graph.getvalue()


class OpTree:
    """Base class for all nodes; overloads @, + and .T to build the tree."""

    def __init__(self):
        super().__init__()

    def __matmul__(self, a):
        return MatMulOp(self, a)

    def __add__(self, a):
        return AddOp(self, a)

    @property
    def T(self):
        return TransposeOp(self)


class MatMulOp(OpTree):
    def __init__(self, a, b):
        super().__init__()
        self.a = a
        self.b = b
        va = self.a.numpy() if isinstance(self.a, OpTree) else self.a
        vb = self.b.numpy() if isinstance(self.b, OpTree) else self.b
        self.v = va @ vb

    def __str__(self):
        return "MatmulOp"

    def mermaid_graph(self, writer):
        if isinstance(self.a, OpTree):
            self.a.mermaid_graph(writer)
            writer.write(f"{id(self.a)}-->{id(self)}[MatmulOp]\n")
        if isinstance(self.b, OpTree):
            self.b.mermaid_graph(writer)
            writer.write(f"{id(self.b)}-->{id(self)}[MatmulOp]\n")

    def numpy(self):
        return self.v

    def backprop(self, seed):
        # y = a @ b: dL/da = seed @ b^T and dL/db = a^T @ seed.
        # A scalar seed (shape ()) falls back to element-wise multiplication.
        a = self.a.numpy() if isinstance(self.a, OpTree) else self.a
        b = self.b.numpy() if isinstance(self.b, OpTree) else self.b
        if isinstance(self.a, OpTree):
            s = seed * np.transpose(b) if seed.shape == () else seed @ np.transpose(b)
            self.a.backprop(s)
        if isinstance(self.b, OpTree):
            s = np.transpose(a) * seed if seed.shape == () else np.transpose(a) @ seed
            self.b.backprop(s)


def matmul(a, b):
    return MatMulOp(a, b)


class AddOp(OpTree):
    def __init__(self, a, b):
        super().__init__()
        self.a = a
        self.b = b
        va = self.a.numpy() if isinstance(self.a, OpTree) else self.a
        vb = self.b.numpy() if isinstance(self.b, OpTree) else self.b
        self.v = va + vb

    def __str__(self):
        return "AddOp"

    def mermaid_graph(self, writer):
        if isinstance(self.a, OpTree):
            self.a.mermaid_graph(writer)
            writer.write(f"{id(self.a)}-->{id(self)}[AddOp]\n")
        if isinstance(self.b, OpTree):
            self.b.mermaid_graph(writer)
            writer.write(f"{id(self.b)}-->{id(self)}[AddOp]\n")

    def numpy(self):
        return self.v

    def backprop(self, seed):
        # y = a + b: the seed passes through unchanged to both operands.
        # Broadcasting (a.shape != b.shape) is not handled; the seed would
        # have to be summed over the broadcast axes.
        if isinstance(self.a, OpTree):
            self.a.backprop(seed)
        if isinstance(self.b, OpTree):
            self.b.backprop(seed)


def addmul(a, b):
    return AddOp(a, b)


class FunctionOp(OpTree):
    """Element-wise function application with a user-supplied derivative."""

    def __init__(self, f, f_grad, f_name, i):
        super().__init__()
        self.f = np.vectorize(f)
        self.f_grad = np.vectorize(f_grad)
        self.f_name = f_name
        self.i = i
        self.v = self.f(i.numpy())

    def __str__(self):
        return f"Function{self.f_name}Op"

    def mermaid_graph(self, writer):
        self.i.mermaid_graph(writer)
        writer.write(f"{id(self.i)}-->{id(self)}[Function{self.f_name}Op]\n")

    def numpy(self):
        return self.v

    def backprop(self, seed):
        # Chain rule for an element-wise f: seed * f'(x).
        self.i.backprop(seed * self.f_grad(self.i.numpy()))


class TransposeOp(OpTree):
    def __init__(self, i):
        super().__init__()
        self.i = i
        self.v = np.transpose(i.numpy())

    def __str__(self):
        return "TransposeOp"

    def mermaid_graph(self, writer):
        self.i.mermaid_graph(writer)
        writer.write(f"{id(self.i)}-->{id(self)}[TransposeOp]\n")

    def numpy(self):
        return self.v

    def backprop(self, seed):
        # The gradient of a transpose is the transposed seed.
        self.i.backprop(np.transpose(seed))


def transposemul(a):
    return TransposeOp(a)


def relu(v):
    relu_f = lambda x: np.max([x, 0])
    relu_diff = lambda x: 1 if x > 0 else 0
    return FunctionOp(relu_f, relu_diff, "Relu", v)


class Variable(OpTree):
    """A leaf node wrapping a NumPy array; backprop stores the gradient."""

    def __init__(self, x):
        super().__init__()
        self.x = x
        self.grad = None

    def numpy(self):
        return self.x

    def mermaid_graph(self, writer):
        writer.write(f'{id(self)}["Variable{self.x.shape}"]\n')

    def backprop(self, seed):
        self.grad = seed


"""
input_var = Variable(np.array([[1],[2],[3]]))
weight = Variable(np.array([[2,-1,1]]))
v = relu(weight @ input_var)
print(f"result : {v.numpy()}")
v.backprop(np.ones(()))
print(f"grad input : {input_var.grad}, w : {weight.grad}")
"""

# Finite-difference sanity check (see the runnable sketch below):
# input_diff = Variable(np.array([[1.01],[2],[3]]))
# v_diff = relu(weight @ input_diff)
# print(f"diff 1 : {(np.sum(v_diff.numpy()) - v.numpy()) / 0.01}")
# i -= grad * delta

# Sample Mermaid output from an earlier run:
"""
graph TD
2284612545696["Variable(1, 3)"]
2284612545696-->2284612624880[MatmulOp]
2284612544496["Variable(3, 2)"]
2284612544496-->2284612624880[MatmulOp]
2284612624880-->2284612625072[FunctionReluOp]
2284612625072-->2284612627856[MatmulOp]
2284612627856-->Result
"""
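
# --- Example usage ----------------------------------------------------------
# A minimal sketch tying the pieces together: it runs the docstring example
# above, compares the analytic gradient with the finite-difference estimate
# hinted at in the commented-out `input_diff` lines, and takes one descent
# step in the spirit of `i -= grad * delta`. The names `eps`, `numeric`, and
# `lr` are illustrative and not part of the module.
if __name__ == "__main__":
    input_var = Variable(np.array([[1.0], [2.0], [3.0]]))
    weight = Variable(np.array([[2.0, -1.0, 1.0]]))
    v = relu(weight @ input_var)
    print(f"result : {v.numpy()}")
    v.backprop(np.ones(()))
    print(f"grad input : {input_var.grad}, w : {weight.grad}")

    # Finite-difference check on the first input coordinate.
    eps = 1e-4
    shifted = Variable(np.array([[1.0 + eps], [2.0], [3.0]]))
    v_shift = relu(weight @ shifted)
    numeric = (np.sum(v_shift.numpy()) - np.sum(v.numpy())) / eps
    print(f"analytic : {input_var.grad[0, 0]}, numeric : {numeric}")

    # One gradient-descent step on the weight (i -= grad * delta).
    lr = 0.1
    weight.x = weight.x - lr * weight.grad

    print(make_mermaid_graph(v))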