Problem implementing a custom gradient descent function

Problem description

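I'm implementing my own (custom) gradient descent algorithm in Python, but the weight and bias it returns each have 10 values (shape=(10,)). My input data has only 1 column, so I expected it to return 1 weight and 1 bias.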

Code

import numpy as np
import matplotlib.pyplot as plt

def SGD(X,y,learning_rate=0.01,max_iter=1000):
    w = np.random.randn(X.shape[1])
    b = np.random.randn(1,)
    print(w,b)
    n = len(X)
    loss_list = []

    for i in range(max_iter):
        y_pred = w*X + b

        Lw = -(2/n)*sum(X*(y - y_pred))
        Lb = -(2/n)*sum(y - y_pred)
        w = w - learning_rate*Lw
        b = b - learning_rate*Lb

        loss = np.square(np.subtract(y,y_pred)).mean()
        loss_list.append(loss)

        print(f"Epoch: {i},loss: {loss}")

    return w,b

x = list(range(1,11))
y = []
for i in x:
    y.append(i**2)

x,y = np.array(x).reshape(-1,1),np.array(y)
w,b = SGD(x,y)

print("\n\n\n\n")
print(w)
print(b)

Loss at the last iteration:

Epoch: 999,loss: 0.11521764208740602

The returned weight and bias are, respectively:

w: [0.00149535 0.00777379 0.01823786 0.03288755 0.05172286 0.07474381
 0.10195038 0.13334257 0.1689204  0.20868384] # giving 10 values

b: [ 0.98958964  3.94588026  8.87303129 15.77104274 24.63991461 35.47964689
 48.29023958 63.07169269 79.82400621 98.54718014] # giving 10 values

I don't understand the reason. How does this happen? Thanks!
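A minimal shape trace of one update step, using the same data, shows where the ten values come from. This is an illustrative sketch, not code from the original post; it starts w and b at zeros instead of random values so the trace is deterministic (only the shapes matter here).

```python
import numpy as np

# Same data as in the question: X is a (10, 1) column, y is 1-D with shape (10,)
X = np.arange(1, 11).reshape(-1, 1)
y = np.arange(1, 11) ** 2

w = np.zeros(X.shape[1])  # shape (1,); zeros instead of randn for a deterministic trace
b = np.zeros(1)           # shape (1,)

y_pred = w * X + b
print(y_pred.shape)       # (10, 1)

residual = y - y_pred     # (10,) minus (10, 1) broadcasts to (10, 10)
print(residual.shape)     # (10, 10)

# The built-in sum() adds the 10 rows together, leaving one value per column
Lw = -(2 / len(X)) * sum(X * residual)
print(Lw.shape)           # (10,)

# Reshaping y into a column keeps every intermediate result at (10, 1)
y_col = y.reshape(-1, 1)
Lw_fixed = -(2 / len(X)) * sum(X * (y_col - y_pred))
print(Lw_fixed.shape)     # (1,)
```

Because w = w - learning_rate*Lw, a (10,) gradient silently turns the (1,) weight into a (10,) array after the first iteration, and the same happens to b.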

Solution

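I think this is because your y is a 1-D array of shape (n,), but y_pred is a column of shape (n, 1), so subtracting them broadcasts to an (n, n) matrix, which you don't want. The built-in sum() then adds up the rows of that matrix and returns n values, which is why w and b each end up with 10 entries. The fix is to reshape y into a column before calling the function, like this: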

import numpy as np
import matplotlib.pyplot as plt

def SGD(X,y,learning_rate=0.01,max_iter=1000):
    w = np.random.randn(X.shape[1])
    b = np.random.randn(1,)
    print(w,b)
    n = len(X)
    loss_list = []

    for i in range(max_iter):
        y_pred = w*X + b

        Lw = -(2/n)*sum(X*(y - y_pred))
        Lb = -(2/n)*sum(y - y_pred)
        w = w - learning_rate*Lw
        b = b - learning_rate*Lb

        loss = np.square(np.subtract(y,y_pred)).mean()
        loss_list.append(loss)

        print(f"Epoch: {i},loss: {loss}")

    return w,b

x = list(range(1,11))
y = []
for i in x:
    y.append(i**2)

x,y = np.array(x).reshape(-1,1),np.array(y).reshape((-1,1)) # Change is here: reshape y into an (n, 1) column to match y_pred
w,b = SGD(x,y)

print("\n\n\n\n")
print(w)
print(b)

Then w and b are, respectively:

[10.94655101]
[-21.6278976]
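As a sanity check (not part of the original answer), the exact least-squares line for this data can be computed in closed form with NumPy's np.polyfit. For y = x**2 on x = 1..10, the best-fit slope is 11 and the intercept is -22, so the gradient descent estimates above should move toward these values as max_iter grows.

```python
import numpy as np

x = np.arange(1, 11)
y = x ** 2

# np.polyfit with deg=1 returns [slope, intercept] of the least-squares line
slope, intercept = np.polyfit(x, y, 1)
print(slope, intercept)  # approximately 11.0 and -22.0
```

By hand: the mean of x is 5.5 and the mean of y is 38.5, so the slope is 907.5 / 82.5 = 11 and the intercept is 38.5 - 11 * 5.5 = -22.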