通过Pytorch无法更新线性回归中的偏差

问题描述

我想使用PyTorch构建线性回归来解决一个简单的任务。一维输出只是10维输入的平均值。奇怪的是,在训练过程中,权重将收敛到0.1,而偏差保持与认值相同。预期偏差应为0。代码和结果如下所示。

from torch import nn
from torch.autograd import Variable
from torch.optim.lr_scheduler import LambdaLR,ReduceLROnPlateau
from torch.utils.data import DataLoader,Dataset
from matplotlib import pyplot
import torch as th
import torch.nn.functional as F
import pandas as pd
import numpy as np
import sys

# Linear Regression Model
class linearRegression(nn.Module):
    def __init__(self):
        super(linearRegression,self).__init__()
        self.linear = nn.Linear(10,1)  # input and output is 1 dimension

    def forward(self,x):
        out = self.linear(x)
        return out

#th.manual_seed(2020)
data = pd.read_csv('./train.csv',dtype=np.float64).values[:,1:]
train_data_label = data[:,0].reshape(-1,1)
train_data = np.delete(data,axis=1)

x_train_tensor = th.tensor(train_data[0:9500,:]).float().to(device)
y_train_tensor = th.tensor(train_data_label[0:9500,:]).float().to(device)
x_valid_tensor = th.tensor(train_data[9500:,:]).float().to(device)
y_valid_tensor = th.tensor(train_data_label[9500:,:]).float().to(device)

model = linearRegression().float() # Transform the precision of model
model = model.to(device)

criterion = nn.MSELoss()
optimizer = th.optim.SGD(model.parameters(),lr=1e-8)

train_loss_history = []
valid_loss_history = []

print("initial weights: ",model.linear.weight.data," bias: ",model.linear.bias.data)
# Start training
num_epochs = 500
for epoch in range(num_epochs):
    inputs = x_train_tensor
    target = y_train_tensor

    # Forward
    out = model(inputs)
    loss = criterion(out,target)
    # Backward
    optimizer.zero_grad() # Set gradients of all parameters to zero
    loss.backward()
    optimizer.step()
    train_loss_history.append(loss.item())

    # For validation 
    model.eval() # Use pretrained weights and bias
    with th.no_grad(): # No calculation of gradient during validation
      predict = model(x_valid_tensor)
      valid_loss = criterion(predict,y_valid_tensor)
      valid_loss_history.append(valid_loss.item())

    if (epoch+1) % 1 == 0:
      print(f'Epoch[{epoch+1}/{num_epochs}],bias: {model.linear.bias.data},weight: {model.linear.weight.data}')

print("weight: ",model.linear.weight.data)
print("bias: ",model.linear.bias.data)

initial weights:  tensor([[ 0.2636,-0.0230,-0.2399,0.3074,0.0735,0.2014,-0.2753,0.1922,-0.1379,0.2794]],device='cuda:0')  bias:  tensor([0.2830],device='cuda:0')
Epoch[1/500],bias: tensor([0.2830],device='cuda:0'),weight: tensor([[ 0.2604,-0.0206,-0.2334,0.3034,0.0741,0.1993,-0.2677,0.1907,-0.1335,0.2760]],device='cuda:0')
Epoch[2/500],weight: tensor([[ 0.2572,-0.0183,-0.2269,0.2995,0.0747,0.1973,-0.2603,0.1891,-0.1292,0.2726]],device='cuda:0')
Epoch[3/500],weight: tensor([[ 0.2540,-0.0160,-0.2206,0.2956,0.0753,0.1953,-0.2530,0.1876,-0.1250,0.2693]],device='cuda:0')
Epoch[4/500],weight: tensor([[ 0.2510,-0.0138,-0.2144,0.2918,0.0759,0.1934,-0.2458,0.1861,-0.1208,0.2660]],device='cuda:0')
Epoch[5/500],weight: tensor([[ 0.2480,-0.0116,-0.2084,0.2881,0.0765,0.1915,-0.2389,0.1846,-0.1167,0.2628]],device='cuda:0')
Epoch[6/500],weight: tensor([[ 0.2450,-0.0094,-0.2024,0.2845,0.0770,0.1896,-0.2320,0.1832,-0.1128,0.2597]],device='cuda:0')
Epoch[7/500],weight: tensor([[ 0.2421,-0.0073,-0.1966,0.2809,0.0776,0.1878,-0.2253,0.1818,-0.1088,0.2566]],device='cuda:0')
Epoch[8/500],weight: tensor([[ 0.2393,-0.0052,-0.1908,0.2774,0.0781,0.1860,-0.2187,0.1804,-0.1050,0.2536]],device='cuda:0')
Epoch[9/500],weight: tensor([[ 0.2365,-0.0032,-0.1852,0.2740,0.0786,0.1842,-0.2123,0.1790,-0.1012,0.2506]],device='cuda:0')
Epoch[10/500],weight: tensor([[ 0.2338,-0.0012,-0.1797,0.2706,0.0791,0.1825,-0.2060,0.1777,-0.0975,0.2477]],device='cuda:0')
Epoch[11/500],weight: tensor([[ 0.2311,0.0008,-0.1743,0.2673,0.0796,0.1808,-0.1998,0.1763,-0.0939,0.2449]],device='cuda:0')
Epoch[12/500],weight: tensor([[ 0.2285,0.0027,-0.1690,0.2641,0.0801,0.1792,-0.1937,0.1750,-0.0903,0.2421]],device='cuda:0')
Epoch[13/500],weight: tensor([[ 0.2260,0.0046,-0.1638,0.2609,0.0806,0.1776,-0.1878,0.1737,-0.0868,0.2393]],device='cuda:0')
Epoch[14/500],weight: tensor([[ 0.2235,0.0064,-0.1587,0.2578,0.0810,0.1760,-0.1820,0.1725,-0.0833,0.2367]],device='cuda:0')
Epoch[15/500],weight: tensor([[ 0.2210,0.0082,-0.1537,0.2547,0.0815,0.1744,-0.1763,0.1713,-0.0800,0.2340]],device='cuda:0')
Epoch[16/500],weight: tensor([[ 0.2186,0.0100,-0.1488,0.2517,0.0819,0.1729,-0.1707,0.1700,-0.0767,0.2314]],device='cuda:0')
Epoch[17/500],weight: tensor([[ 0.2162,0.0117,-0.1440,0.2488,0.0824,0.1714,-0.1652,0.1688,-0.0734,0.2289]],device='cuda:0')
Epoch[18/500],weight: tensor([[ 0.2139,0.0134,-0.1393,0.2459,0.0828,-0.1599,0.1677,-0.0702,0.2264]],device='cuda:0')
Epoch[19/500],weight: tensor([[ 0.2116,0.0151,-0.1347,0.2431,0.0832,0.1685,-0.1546,0.1665,-0.0671,0.2240]],device='cuda:0')
Epoch[20/500],weight: tensor([[ 0.2094,0.0168,-0.1301,0.2403,0.0836,0.1671,-0.1495,0.1654,-0.0640,0.2216]],device='cuda:0')
Epoch[21/500],weight: tensor([[ 0.2072,0.0184,-0.1257,0.2376,0.0840,0.1658,-0.1444,0.1642,-0.0610,0.2193]],device='cuda:0')
Epoch[22/500],weight: tensor([[ 0.2051,0.0199,-0.1213,0.2349,0.0844,0.1644,-0.1395,0.1631,-0.0580,0.2170]],device='cuda:0')
Epoch[23/500],weight: tensor([[ 0.2030,0.0215,-0.1171,0.2323,0.0847,-0.1346,0.1621,-0.0551,0.2147]],device='cuda:0')
Epoch[24/500],weight: tensor([[ 0.2010,0.0230,-0.1129,0.2297,0.0851,0.1618,-0.1299,0.1610,-0.0522,0.2125]],device='cuda:0')
Epoch[25/500],weight: tensor([[ 0.1989,0.0245,0.2272,0.0854,0.1606,-0.1253,0.1599,-0.0494,0.2103]],device='cuda:0')
Epoch[26/500],weight: tensor([[ 0.1970,0.0260,-0.1047,0.2248,0.0858,0.1593,-0.1207,0.1589,-0.0466,0.2082]],device='cuda:0')
Epoch[27/500],weight: tensor([[ 0.1950,0.0274,-0.1008,0.2223,0.0861,0.1581,-0.1162,0.1579,-0.0439,0.2061]],device='cuda:0')
Epoch[28/500],weight: tensor([[ 0.1932,0.0288,-0.0969,0.2200,0.0865,0.1569,-0.1119,-0.0413,0.2041]],device='cuda:0')
Epoch[29/500],weight: tensor([[ 0.1913,0.0302,-0.0931,0.2177,0.0868,0.1558,-0.1076,0.1559,-0.0387,0.2021]],device='cuda:0')
Epoch[30/500],weight: tensor([[ 0.1895,0.0315,-0.0894,0.2154,0.0871,0.1546,-0.1034,0.1550,-0.0361,0.2001]],device='cuda:0')
Epoch[31/500],weight: tensor([[ 0.1877,0.0328,-0.0857,0.2131,0.0874,0.1535,-0.0993,0.1540,-0.0336,0.1982]],device='cuda:0')
Epoch[32/500],weight: tensor([[ 0.1860,0.0341,-0.0821,0.2110,0.0877,0.1524,-0.0953,0.1531,-0.0311,0.1963]],device='cuda:0')
Epoch[33/500],weight: tensor([[ 0.1842,0.0354,-0.0786,0.2088,0.0880,0.1514,-0.0913,0.1522,-0.0287,0.1944]],device='cuda:0')
Epoch[34/500],weight: tensor([[ 0.1826,0.0367,-0.0752,0.2067,0.0883,0.1503,-0.0875,0.1513,-0.0263,0.1926]],device='cuda:0')
Epoch[35/500],weight: tensor([[ 0.1809,0.0379,-0.0718,0.2046,0.0886,0.1493,-0.0837,0.1504,-0.0240,0.1908]],device='cuda:0')
Epoch[36/500],weight: tensor([[ 0.1793,0.0391,-0.0685,0.2026,0.0889,0.1483,0.1495,-0.0217,0.1891]],device='cuda:0')
Epoch[37/500],weight: tensor([[ 0.1777,0.0403,-0.0652,0.2006,0.0891,0.1473,-0.0763,0.1487,-0.0195,0.1874]],device='cuda:0')
Epoch[38/500],weight: tensor([[ 0.1762,0.0414,-0.0620,0.1987,0.0894,0.1463,-0.0728,0.1479,-0.0173,0.1857]],device='cuda:0')
Epoch[39/500],weight: tensor([[ 0.1747,0.0425,-0.0589,0.1968,0.0896,0.1454,-0.0693,0.1470,-0.0151,0.1840]],device='cuda:0')
Epoch[40/500],weight: tensor([[ 0.1732,0.0437,-0.0558,0.1949,0.0899,0.1445,-0.0659,0.1462,-0.0130,0.1824]],device='cuda:0')
Epoch[41/500],weight: tensor([[ 0.1717,0.0447,-0.0528,0.1930,0.0901,0.1436,-0.0625,-0.0109,0.1808]],device='cuda:0')
Epoch[42/500],weight: tensor([[ 0.1703,0.0458,-0.0499,0.1912,0.0904,0.1427,-0.0592,0.1446,-0.0088,0.1793]],device='cuda:0')
Epoch[43/500],weight: tensor([[ 0.1689,0.0469,-0.0470,0.1895,0.0906,0.1418,-0.0560,0.1439,-0.0068,0.1778]],device='cuda:0')
Epoch[44/500],weight: tensor([[ 0.1675,0.0479,-0.0441,0.1877,0.0908,0.1409,-0.0529,0.1431,-0.0049,0.1763]],device='cuda:0')
Epoch[45/500],weight: tensor([[ 0.1662,0.0489,-0.0414,0.0911,0.1401,-0.0498,0.1424,-0.0029,0.1748]],device='cuda:0')
Epoch[46/500],weight: tensor([[ 0.1648,0.0499,-0.0386,0.1844,0.0913,0.1393,-0.0468,0.1416,-0.0010,0.1733]],device='cuda:0')
Epoch[47/500],weight: tensor([[ 0.1636,0.0508,-0.0360,0.1828,0.0915,0.1385,-0.0438,0.1719]],device='cuda:0')
Epoch[48/500],weight: tensor([[ 0.1623,0.0518,-0.0333,0.1811,0.0917,0.1377,-0.0409,0.1402,0.1706]],device='cuda:0')
Epoch[49/500],weight: tensor([[ 0.1611,0.0527,-0.0308,0.1796,0.0919,0.1369,-0.0380,0.1395,0.0045,0.1692]],device='cuda:0')
Epoch[50/500],weight: tensor([[ 0.1598,0.0536,-0.0282,0.1780,0.0921,0.1362,-0.0353,0.1388,0.0062,0.1679]],device='cuda:0')
Epoch[51/500],weight: tensor([[ 0.1586,0.0545,-0.0258,0.1765,0.0923,0.1354,-0.0325,0.1382,0.0080,0.1666]],device='cuda:0')
Epoch[52/500],weight: tensor([[ 0.1575,0.0554,-0.0233,0.0925,0.1347,-0.0299,0.1375,0.0097,0.1653]],device='cuda:0')
Epoch[53/500],weight: tensor([[ 0.1563,0.0563,-0.0210,0.1736,0.0927,0.1340,-0.0272,0.0113,0.1640]],device='cuda:0')
Epoch[54/500],weight: tensor([[ 0.1552,0.0571,-0.0186,0.1722,0.0928,0.1333,-0.0247,0.0130,0.1628]],device='cuda:0')
Epoch[55/500],weight: tensor([[ 0.1541,0.0579,-0.0163,0.1708,0.0930,0.1326,-0.0221,0.1356,0.0146,0.1616]],device='cuda:0')
Epoch[56/500],weight: tensor([[ 0.1530,0.0587,-0.0141,0.1694,0.0932,0.1320,-0.0197,0.1350,0.0162,0.1604]],device='cuda:0')
Epoch[57/500],weight: tensor([[ 0.1520,0.0595,-0.0119,0.1681,0.0934,0.1313,0.1344,0.0177,0.1592]],device='cuda:0')
Epoch[58/500],weight: tensor([[ 0.1509,0.0603,-0.0097,0.1667,0.0935,0.1307,-0.0149,0.1338,0.0192,0.1581]],device='cuda:0')
Epoch[59/500],weight: tensor([[ 0.1499,0.0611,-0.0076,0.0937,0.1301,-0.0126,0.1332,0.0207,0.1570]],device='cuda:0')
Epoch[60/500],weight: tensor([[ 0.1489,0.0618,-0.0055,0.0938,0.1294,-0.0103,0.0222,0.1559]],device='cuda:0')
Epoch[61/500],weight: tensor([[ 0.1480,0.0626,-0.0035,0.1629,0.0940,0.1288,-0.0081,0.0236,0.1548]],device='cuda:0')
Epoch[62/500],weight: tensor([[ 0.1470,0.0633,-0.0015,0.1617,0.0941,0.1283,-0.0059,0.1315,0.0251,0.1537]],device='cuda:0')
Epoch[63/500],weight: tensor([[ 0.1461,0.0640,0.0005,0.1605,0.0943,0.1277,-0.0038,0.1309,0.0264,0.1527]],device='cuda:0')
Epoch[64/500],weight: tensor([[ 0.1451,0.0647,0.0024,0.1594,0.0944,0.1271,-0.0017,0.1304,0.0278,0.1517]],device='cuda:0')
Epoch[65/500],weight: tensor([[0.1442,0.0654,0.0043,0.1582,0.0946,0.1266,0.0004,0.1299,0.0291,0.1507]],device='cuda:0')
Epoch[66/500],weight: tensor([[0.1434,0.0661,0.0061,0.1571,0.0947,0.1260,0.0304,0.1497]],device='cuda:0')
Epoch[67/500],weight: tensor([[0.1425,0.0667,0.0079,0.1560,0.0948,0.1255,0.0044,0.1289,0.0317,0.1488]],device='cuda:0')
Epoch[68/500],weight: tensor([[0.1417,0.0673,0.1549,0.0949,0.1250,0.0063,0.1284,0.0330,0.1478]],device='cuda:0')
Epoch[69/500],weight: tensor([[0.1408,0.0680,0.0114,0.1538,0.0951,0.1245,0.1279,0.0342,0.1469]],device='cuda:0')
Epoch[70/500],weight: tensor([[0.1400,0.0686,0.0132,0.1528,0.0952,0.1240,0.1274,0.0355,0.1460]],device='cuda:0')
Epoch[71/500],weight: tensor([[0.1392,0.0692,0.0148,0.1518,0.0953,0.1235,0.0118,0.1269,0.1451]],device='cuda:0')
Epoch[72/500],weight: tensor([[0.1384,0.0698,0.0165,0.1508,0.0954,0.1230,0.0136,0.1264,0.0378,0.1442]],device='cuda:0')
Epoch[73/500],weight: tensor([[0.1377,0.0704,0.0181,0.1498,0.0955,0.1225,0.0154,0.0390,0.1434]],device='cuda:0')
Epoch[74/500],weight: tensor([[0.1369,0.0709,0.0197,0.1488,0.0957,0.1221,0.0171,0.0401,0.1426]],device='cuda:0')
Epoch[75/500],weight: tensor([[0.1362,0.0715,0.0212,0.0958,0.1216,0.0188,0.1251,0.0412,0.1417]],device='cuda:0')
Epoch[76/500],weight: tensor([[0.1355,0.0721,0.0227,0.1469,0.0959,0.1212,0.0204,0.1246,0.0423,0.1409]],device='cuda:0')
Epoch[77/500],weight: tensor([[0.1348,0.0726,0.0242,0.1460,0.0960,0.1207,0.0220,0.1242,0.0434,0.1402]],device='cuda:0')
Epoch[78/500],weight: tensor([[0.1341,0.0731,0.0257,0.1451,0.0961,0.1203,0.1238,0.0444,0.1394]],device='cuda:0')
Epoch[79/500],weight: tensor([[0.1334,0.0736,0.0271,0.1443,0.0962,0.1199,0.1234,0.0454,0.1386]],device='cuda:0')
Epoch[80/500],weight: tensor([[0.1327,0.0742,0.0285,0.1434,0.0963,0.1195,0.0266,0.0465,0.1379]],device='cuda:0')
Epoch[81/500],weight: tensor([[0.1321,0.0746,0.0299,0.1426,0.0964,0.1191,0.0281,0.1226,0.0475,0.1371]],device='cuda:0')
Epoch[82/500],weight: tensor([[0.1314,0.0751,0.0313,0.0965,0.1187,0.0296,0.1222,0.0484,0.1364]],device='cuda:0')
Epoch[83/500],weight: tensor([[0.1308,0.0756,0.0326,0.1183,0.0310,0.1218,0.0494,0.1357]],device='cuda:0')
Epoch[84/500],weight: tensor([[0.1302,0.0761,0.0339,0.0966,0.1180,0.0324,0.1214,0.0503,0.1350]],device='cuda:0')
Epoch[85/500],weight: tensor([[0.1296,0.0352,0.1394,0.0967,0.1176,0.0337,0.1210,0.0512,0.1344]],device='cuda:0')
Epoch[86/500],weight: tensor([[0.1290,0.0364,0.1386,0.0968,0.1172,0.0351,0.0521,0.1337]],device='cuda:0')
Epoch[87/500],weight: tensor([[0.1284,0.0774,0.0376,0.1379,0.0969,0.1169,0.0530,0.1331]],device='cuda:0')
Epoch[88/500],weight: tensor([[0.1279,0.0779,0.0388,0.1371,0.0970,0.1165,0.0377,0.1200,0.0539,0.1324]],device='cuda:0')
Epoch[89/500],weight: tensor([[0.1273,0.0783,0.0400,0.1364,0.1162,0.0389,0.1196,0.0548,0.1318]],device='cuda:0')
Epoch[90/500],weight: tensor([[0.1268,0.0787,0.1357,0.0971,0.1159,0.1193,0.0556,0.1312]],device='cuda:0')
Epoch[91/500],weight: tensor([[0.1262,0.0972,0.1155,0.1189,0.0564,0.1306]],device='cuda:0')
Epoch[92/500],weight: tensor([[0.1257,0.0795,0.1343,0.0973,0.1152,0.1186,0.0572,0.1300]],device='cuda:0')
Epoch[93/500],weight: tensor([[0.1252,0.0799,0.0445,0.1337,0.1149,0.0580,0.1294]],device='cuda:0')
Epoch[94/500],weight: tensor([[0.1247,0.0803,0.0456,0.1330,0.0974,0.1146,0.0448,0.0588,0.1289]],device='cuda:0')
Epoch[95/500],weight: tensor([[0.1242,0.0807,0.0466,0.1324,0.0975,0.1143,0.0459,0.0596,0.1283]],device='cuda:0')
Epoch[96/500],weight: tensor([[0.1237,0.0811,0.0477,0.1318,0.1140,0.0470,0.1173,0.1278]],device='cuda:0')
Epoch[97/500],weight: tensor([[0.1232,0.0814,0.0487,0.1311,0.0976,0.1137,0.0481,0.1170,0.0610,0.1272]],device='cuda:0')
Epoch[98/500],weight: tensor([[0.1228,0.0818,0.0497,0.1305,0.0977,0.1135,0.0491,0.1167,0.1267]],device='cuda:0')
Epoch[99/500],weight: tensor([[0.1223,0.0821,0.0506,0.1300,0.1132,0.0502,0.1164,0.0625,0.1262]],device='cuda:0')
Epoch[100/500],weight: tensor([[0.1219,0.0825,0.0516,0.0978,0.1129,0.1161,0.0632,0.1257]],device='cuda:0')
...

Epoch[480/500],weight: tensor([[0.1000,0.1000,0.1000]],device='cuda:0')
Epoch[481/500],device='cuda:0')
Epoch[482/500],device='cuda:0')
Epoch[483/500],device='cuda:0')
Epoch[484/500],device='cuda:0')
Epoch[485/500],device='cuda:0')
Epoch[486/500],device='cuda:0')
Epoch[487/500],device='cuda:0')
Epoch[488/500],device='cuda:0')
Epoch[489/500],device='cuda:0')
Epoch[490/500],device='cuda:0')
Epoch[491/500],device='cuda:0')
Epoch[492/500],device='cuda:0')
Epoch[493/500],device='cuda:0')
Epoch[494/500],device='cuda:0')
Epoch[495/500],device='cuda:0')
Epoch[496/500],device='cuda:0')
Epoch[497/500],device='cuda:0')
Epoch[498/500],device='cuda:0')
Epoch[499/500],device='cuda:0')
Epoch[500/500],device='cuda:0')
weight:  tensor([[0.1000,device='cuda:0')
bias:  tensor([0.2830],device='cuda:0')

解决方法

暂无找到可以解决该程序问题的有效方法,小编努力寻找整理中!

如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。

小编邮箱:dio#foxmail.com (将#修改为@)