# Plot the mean-squared-error surface of the linear model y_pred = x * w + b over a grid of (w, b) values.
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 -- registers the "3d" projection on Matplotlib < 3.2
# Training samples. Every (x, y) pair lies exactly on the line y = 3x + 2,
# so the mean squared error is zero at w = 3, b = 2.
x_data = np.array((1.0, 2.0, 3.0))
y_data = np.array((5.0, 8.0, 11.0))
def forward(x, w=None, b=None):
    """Evaluate the linear model ``y_pred = x * w + b``.

    Parameters
    ----------
    x : float or numpy.ndarray
        Input value(s).
    w, b : float or numpy.ndarray, optional
        Slope and intercept. When omitted, the module-level ``w`` and ``b``
        (the ``np.meshgrid`` arrays this script builds) are used, so
        ``forward(x)`` behaves exactly as before. Passing them explicitly
        evaluates the model at any parameters without touching module state.

    Returns
    -------
    float or numpy.ndarray
        ``x * w + b``, broadcast according to NumPy rules.
    """
    # The parameter names shadow the globals, so fall back through globals().
    module_scope = globals()
    if w is None:
        w = module_scope["w"]
    if b is None:
        b = module_scope["b"]
    return x * w + b
def loss(x, y):
    """Return the squared error between the model prediction for *x* and *y*.

    The prediction comes from ``forward(x)``. The result has the same shape
    as that prediction (element-wise when it is an array).
    """
    residual = forward(x) - y
    return residual ** 2
# Candidate parameters: 41 evenly spaced values on [0, 4] for each of w and b.
# np.linspace fixes the point count and includes the 4.0 endpoint explicitly,
# instead of relying on a float-step np.arange with a fudge stop value (4.1).
W = np.linspace(0.0, 4.0, 41)
B = np.linspace(0.0, 4.0, 41)
# 2-D grids of shape (len(B), len(W)). forward() reads the module-level
# names ``w`` and ``b``, so these must keep exactly those names.
w, b = np.meshgrid(W, B)

# Sum each sample's squared error; every loss() call covers the whole grid.
loss_sum = np.zeros_like(w)
for x_val, y_val in zip(x_data, y_data):
    loss_sum += loss(x_val, y_val)

# Mean squared error over the samples: one value per (w, b) grid point.
mse = loss_sum / len(x_data)
# Render the MSE over the (w, b) grid as a 3-D surface and display it.
fig, ax = plt.subplots(subplot_kw={"projection": "3d"})
ax.plot_surface(w, b, mse, cmap="viridis")
ax.set(xlabel="w", ylabel="b", zlabel="MSE")
plt.show()