sgd_hw/p4_simple_heatmap.py

11 lines
254 B
Python

from p4_model import *
import matplotlib.pyplot as plt
# Show each output unit's incoming weights from the first layer as a
# grayscale 28x28 "heatmap", one figure per unit.
#
# NOTE(review): assumes `model.param[0][0].x` is a NumPy-like array whose
# transpose has one row of 28 * 28 = 784 values per output unit; confirm
# against p4_model.

# Width of the single layer passed to load_or_create_model. The same value
# drives the visualization loop, so the two can never drift apart.
N_UNITS = 10
# Each weight row is reshaped back to the input image size for display.
IMAGE_SHAPE = (28, 28)

model = load_or_create_model([N_UNITS])
heat = model.param[0][0].x.T
for i in range(N_UNITS):
    print(f'{i} index')
    plt.imshow(heat[i].reshape(*IMAGE_SHAPE), cmap='gray', interpolation='none')
    # plt.show() typically blocks until the window is closed (non-interactive
    # backends), so the units are stepped through one figure at a time.
    plt.show()