sgd_hw/p4.py

10 lines
303 B
Python
Raw Normal View History

2021-02-13 13:20:59 +09:00
from sklearn import datasets
import numpy as np
# Fetch MNIST (OpenML dataset 'mnist_784', version 1). return_X_y=True yields
# a (data, target) pair; as_frame=False asks for NumPy arrays instead of
# pandas objects; cache=True lets scikit-learn cache the download.
X, y = datasets.fetch_openml(
    'mnist_784', version=1, return_X_y=True, cache=True, as_frame=False
)
print(X, y)

# Unseeded generator: the random initial weights differ from run to run.
gen: np.random.Generator = np.random.default_rng()

# NOTE(review): `Variable` is not defined or imported anywhere in this file;
# presumably it is the project's own autograd/tensor wrapper -- confirm and
# import it, otherwise the next line raises NameError.
input_var = Variable(X)

# Draw a (100, 784) matrix of standard-normal samples. The shape must be
# passed as `size`: positionally, normal(100, 784) means loc=100, scale=784
# and returns a single scalar rather than a matrix.
# NOTE(review): 784 presumably matches the MNIST feature count and 100 the
# number of output units -- confirm the intended orientation.
weight = Variable(gen.normal(size=(100, 784)))
bias = Variable(np.array([1]))