无法弄清楚梯度下降线性回归

问题描述

我目前正在研究梯度下降项目。

我选择了 nba stats 作为我的数据,所以我从篮球参考下载了 3Pts 数据和 pts 数据,并成功绘制了散点图。然而,结果似乎并不正确。

我的散点图向右上方(因为更多的 3 点通常意味着得分更多,所以这是有道理的)

但是我的梯度下降线要左上,我不知道怎么了。

import pandas as pd
import numpy as np
from sklearn import linear_model
from matplotlib import pyplot as plt


data = pd.read_csv('C:/Users/jeehw/Documents/FG3M_PTS_2021.csv')


X = data.iloc[:,1]
Y = data.iloc[:,2]

plt.figure(figsize=(8,6))
plt.xlabel('FG3M')                                  
plt.ylabel('PTS')
plt.scatter(X,Y)
plt.show()

m = 0
c = 0

L = 0.001
epochs = 200

n = float(len(X))

for i in range(len(X)):
Y_pred = m*X + c
m_Grad = (1/n) * sum(X * (Y_pred - Y))
c_Grad = (1/n) * sum(Y_pred - Y)


m = m - L* m_Grad
c = c - L* c_Grad

Y_pred = m*X + c

plt.scatter(X,Y)
plt.scatter(X,Y_pred)
plt.show()

解决方法

此代码中的一些内容实际上没有意义。您是否要从头开始进行回归?因为您确实导入了 scikit 学习但从不应用它。您可以参考此链接了解如何使用 scikit 学习回归 here。我也会考虑使用其他算法。

我相信这就是您在这里尝试做的:

import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn import linear_model
from sklearn.metrics import mean_squared_error,r2_score
from matplotlib import pyplot as plt


#data = pd.read_csv('C:/Users/jeehw/Documents/FG3M_PTS_2021.csv')
raw_data = pd.read_html('https://www.basketball-reference.com/leagues/NBA_2021_totals.html')[0]
raw_data  = raw_data[raw_data['Rk'].ne('Rk')]

data = raw_data[['Player','3P','PTS']]
data[['3P','PTS']] = data[['3P','PTS']].astype(int)

X = data.iloc[:]['3P'].values
y = data.iloc[:]['PTS'].values

plt.figure(figsize=(8,6))
plt.xlabel('FG3M')                                  
plt.ylabel('PTS')
plt.scatter(X,y)

plt.xticks(np.arange(min(X),max(X)+1,20))
plt.yticks(np.arange(min(y),max(y)+1,100))
plt.show()


# Split data into test and Train
X_train,X_test,y_train,y_test = train_test_split(X,y,test_size=0.20,random_state=42)

# Create linear regression object
regr = linear_model.LinearRegression()

# Train the model using the training sets
regr.fit(X_train.reshape(-1,1),y_train)

# Make predictions using the testing set
y_pred = regr.predict(X_test.reshape(-1,1))



# The coefficients
print('Coefficients: \n',regr.coef_)
# The mean squared error
print('Mean squared error: %.2f'
      % mean_squared_error(y_test,y_pred))
# The coefficient of determination: 1 is perfect prediction
print('Coefficient of determination: %.2f'
      % r2_score(y_test,y_pred))

# Plot outputs
plt.scatter(X_test,y_test,color='black')
plt.plot(X_test,y_pred,color='red',linewidth=3)

plt.xticks(np.arange(min(X_test),max(X_test)+1,20))
plt.yticks(np.arange(min(y_pred),max(y_pred)+1,100))

plt.xlabel('FG3M')                                  
plt.ylabel('PTS')

plt.show()

enter image description here

enter image description here

里面有一些噪音。你有很多得分很高的球员,他们从不投三分,更不用说投进三分了。所以我会考虑先做大量的数据清理(也许只拿至少有 50 次 3 分尝试的球员?或者摆脱中心?另外,如果球员改变球队,他们可能会在数据集中几次每个团队都有他们的总数,所以那里有一些冗余......但我不会花时间清理它,因为它超出了问题的范围)。我还会测试其他机器学习算法。但是上面的代码至少应该让你开始玩。玩得开心!