Pytorch 模型无法在多批次或大批次上学习

问题描述

背景信息

我的数据是 (11423,2,10,34) 的形状。我正在处理体育数据,其中每个实例都是 两支 球队的对决,每支球队有 10 名球员,每位球员的统计数据为 34我有来自 11423 场比赛的数据。

我想使用一个 10x1 的 1D 卷积滤波器,在两个通道上所有 10 个玩家的 34 个统计数据中向右滑动。 我的目标变量是游戏的点差。

我试图粗略地重新创建这个架构: CNN Architecture

我的问题

我可以在最多 39 个训练示例的单批上进行训练并接近零损失。 然而,一旦我增加到 40 个训练样本,我的模型就不会学习,因为每个时期的损失都保持不变,并且输出都是相同的数字(我发现这是所有基本事实的平均值目标)。 此外,如果我尝试将一组 32 个训练示例分成两批 16 个,我的模型将无法学习。我遇到了同样的问题,即损失不会减少并为每个训练示例预测相同的输出。什么可能导致这种行为?

示例

torch.manual_seed(0)
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

batch_size = 39

X_train_tensor = torch.from_numpy(X_train).float()[:batch_size]
y_train_tensor = torch.from_numpy(y_train).float()[:batch_size]

print("X_train_tensor shape:",X_train_tensor.shape,"\nX_test_tensor shape:",y_test_tensor.shape,"\n")

# Model Architecture
class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        
        self.conv1 = nn.Conv1d(2,kernel_size=(10,1),stride=1)       
        self.fc1 = nn.Linear(72,32)
        self.fc2 = nn.Linear(32,1)        
        
    def forward(self,x):
        
        x = F.tanh(self.conv1(x))
        x = x.reshape(-1,72)
        x = F.tanh(self.fc1(x))
        x = self.fc2(x)        
        return x
    

# Instantiating Model
net = Net()
loss_function = nn.MSELoss()
optimizer = optim.Adam(net.parameters(),lr=0.1)

epochs = 300

# Training Loop
for e in range(epochs):
    net.train()
    
    for i in range(0,len(X_train_tensor),batch_size):
        batch_X = X_train_tensor[i:i+batch_size]
        batch_y = y_train_tensor[i:i+batch_size]
        
        net.zero_grad()

        output = net(batch_X)


        loss= loss_function(output,batch_y)

        loss.backward()
        optimizer.step()
    if e %10 == 0:
        print("loss:",loss)
        
print("model predictions:\n",output.detach().numpy().flatten(),"\nGround Truth:\n",batch_y.detach().numpy().flatten()) 

Loss and Outputs when batch_size = 39

将 batch_size 更改为 40 即可解决

Loss and Outputs when batch_size = 40

注意模型对所有示例的预测和输出为 2.1,这是基本事实的平均值。

我尝试改变学习率,使用批量标准化,并仔细检查我的输入是否真的不同(它们是!)。

关于导致此问题的原因以及如何解决此问题的任何想法?感谢您提供的任何意见!

解决方法

暂无找到可以解决该程序问题的有效方法,小编努力寻找整理中!

如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。

小编邮箱:dio#foxmail.com (将#修改为@)