"""Display the learned first-layer weights of the p4 model as grayscale images.

Loads (or creates) the model, transposes its first weight parameter, and shows
the first ``N_UNITS`` rows one at a time, each reshaped to ``IMAGE_SHAPE``.
"""
import matplotlib.pyplot as plt

from p4_model import load_or_create_model

# Width of the layer requested from load_or_create_model. The same value is
# also the number of weight images displayed, so both uses share one constant.
N_UNITS = 10
# Shape each weight row is reshaped to for display.
# NOTE(review): presumably 28x28 MNIST-style inputs — confirm against p4_model.
IMAGE_SHAPE = (28, 28)


def show_weight_images(weights, count, image_shape=IMAGE_SHAPE):
    """Show the first *count* rows of *weights* as grayscale images.

    Each row is printed as ``"<i> index"`` and then drawn with ``plt.imshow``,
    followed by a ``plt.show()`` call per image.

    Args:
        weights: 2-D array-like indexed by row; each row must support
            ``.reshape(*image_shape)``.
        count: number of leading rows to display.
        image_shape: target shape for each row before display.
    """
    for i in range(count):
        print(f'{i} index')
        plt.imshow(weights[i].reshape(*image_shape), cmap='gray', interpolation='none')
        plt.show()


model = load_or_create_model([N_UNITS])
# NOTE(review): assumes param[0][0].x is the first layer's weight matrix with
# one column per unit, so the transpose has one row per unit — confirm in p4_model.
heat = model.param[0][0].x.T
show_weight_images(heat, N_UNITS)