在具有自定义损失的Pytorch中训练模型如何设置优化器并进行训练?

问题描述

我是pytorch的新手,我正在尝试运行找到的github模型并对其进行测试。因此,作者提供了模型和损失函数

像这样:

#1. Inference the model
model = PhysNet_padding_Encoder_Decoder_MAX(frames=128)
rPPG,x_visual,x_visual3232,x_visual1616 = model(inputs)

#2. normalized the Predicted rPPG signal and GroundTruth BVP signal
rPPG = (rPPG-torch.mean(rPPG)) /torch.std(rPPG)     # normalize
BVP_label = (BVP_label-torch.mean(BVP_label)) /torch.std(BVP_label)     # normalize

#3. Calculate the loss
loss_ecg = Neg_Pearson(rPPG,BVP_label)

数据加载

    train_loader = torch.utils.data.DataLoader(train_set,batch_size = 20,shuffle = True)

    batch = next(iter(train_loader))

    data,label1,label2 = batch

    inputs= data

假设我想训练这个模型15个纪元。 所以这就是我到目前为止: 我正在尝试设置优化程序和训练,但是我不确定如何将自定义损失和数据加载与模型联系起来并正确设置15个时期的训练。

optimizer = optim.SGD(model.parameters(),lr=0.001,momentum=0.9)

for epoch in range(15):
  ....

有什么建议吗?

解决方法

我假设BVP_label是train_loader的label

train_loader = torch.utils.data.DataLoader(train_set,batch_size = 20,shuffle = True)

# Using GPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = PhysNet_padding_Encoder_Decoder_MAX(frames=128)
model.to(device)

optimizer = optim.SGD(model.parameters(),lr=0.001,momentum=0.9)

for epoch in range(15):
    model.train()
    for inputs,label1,label2 in train_loader:
        rPPG,x_visual,x_visual3232,x_visual1616 = model(inputs)
        BVP_label = label1 # assumed BVP_label is label1

        rPPG = (rPPG-torch.mean(rPPG)) /torch.std(rPPG)
        BVP_label = (BVP_label-torch.mean(BVP_label)) /torch.std(BVP_label)
        
        loss_ecg = Neg_Pearson(rPPG,BVP_label)
        
        optimizer.zero_grad()
        loss_ecg.backward()
        optimizer.step()

PyTorch培训步骤如下。

  • 创建DataLoader
  • 初始化模型和优化器
  • 创建设备对象并将模型移至设备

在火车圈中

  • 选择一个小批量数据
  • 使用模型进行预测
  • 计算损失
  • loss.backward()更新模型的梯度
  • 使用优化器更新参数

如您所知,您还可以查看PyTorch教程。

Learning PyTorch with Examples

What is torch.nn really?

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...