问题描述
我正在尝试使用 softmax 来解决时尚 mnist。但是,成本函数总是在增加,并且运行训练需要花费大量时间。请帮助我改进这些代码行。我不知道我写的向量维度或函数是否正确。 x.shape 是 (784,60000),y.shape 是 (60000,)。非常感谢。
import numpy as np
num_px = x.shape[1]
x = np.array(x)
x_test = np.array(x_test)
x = (x.reshape(x.shape[0],-1).T) / 255
x_test = (x_test.reshape(x_test.shape[0],-1).T) / 255
print (y.shape)
print ("x's shape: " + str(x.shape))
print ("x_test's shape: " + str(x_test.shape))
def softmax(z):
z -= np.max(z)
sm = (np.exp(z).T / np.sum(np.exp(z),axis=1))
return sm
def initialize(dim1):
w = np.zeros((dim1,1))
b = 0
return w,b
def propagate(w,b,X,Y):
m = X.shape[1]
A = softmax((np.dot(w.T,X) + b))
cost = (-1 / m) * np.sum(Y * np.log(A))
# backwar prop
dw = (1 / m) * np.dot(X,(A - Y))
db = (1 / m) * np.sum(A - Y)
cost = np.squeeze(cost)
grads = {"dw": dw,"db": db}
return grads,cost
def optimize(w,Y,num_iters,alpha,print_cost=False):
costs = []
for i in range(num_iters):
grads,cost = propagate(w,Y)
dw = grads["dw"]
db = grads["db"]
w = w - alpha * dw
b = b - alpha * db
if i % 50 == 0:
costs.append(cost)
# Print the cost every 100 training examples
if print_cost and i % 50 == 0:
print("Cost after iteration %i: %f" % (i,cost))
params = {"w": w,"b": b}
grads = {"dw": dw,"db": db}
return params,grads,costs
def predict(w,X):
y_pred = np.argmax(softmax((np.dot(w.T,X) + b)),axis=1)
return y_pred
def model_LR(X_train,Y_train,test_x,test_y,print_cost):
w,b = initialize(X_train.shape[0])
parameters,costs = optimize(w,X_train,print_cost)
w = parameters["w"]
b = parameters["b"]
y_prediction_train = predict(w,X_train)
y_prediction_test = predict(w,test_x)
print("Train accuracy: {} %",sum(y_prediction_train == Y_train[0]) / (float(len(Y_train))) * 100)
print("Test accuracy: {} %",sum(y_prediction_test == test_y[0]) / (float(len(test_y))) * 100)
d = {"costs": costs,"w": w,"b": b,"learning_rate": alpha,"num_iterations": num_iters}
# Plot learning curve (with costs)
costs = np.squeeze(d['costs'])
plt.plot(costs)
plt.ylabel('cost')
plt.xlabel('iterations (per hundreds)')
plt.title("Learning rate =" + str(d["learning_rate"]))
plt.plot()
plt.show()
plt.close()
return d
解决方法
暂无找到可以解决该程序问题的有效方法,小编努力寻找整理中!
如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。
小编邮箱:dio#foxmail.com (将#修改为@)