前几个时期的误差极高Pytorch-图像分割

问题描述

为什么当我将网络设置为eval时,模型在验证数据上的误差函数在最初的几个时期就显得过高了。

如果我使用model.eval(),则在前4-5个时期内误差大于40-50k,然后迅速下降至3-4,但是如果我将网络留在model.train()上,则误差为只有5-6。

def eval_model(DataLoader,model,criterion,device,withStat,withImage):
    model.eval()
    eval_epochen_loss = 0
    img = None
    n_eval = 0
    TP,TN,FP,FN = 0,0
    stat = None

    for i,data in enumerate(DataLoader):
        dicoms,targets = data
        dicoms,targets = Variable(dicoms.to(device)),Variable(targets.to(device))

        assert targets.shape[1] - 1 == model.n_classes,\
            f'Network has been defined with {model.n_classes} output classes,' \
            f'but loaded target have {targets.shape[1] - 1} channels. Please check the labeled data or adjust ' \
            f'the network classes. '

        preds = model(dicoms)
        loss = criterion(preds,targets[:,-1,:,:].long())

        eval_epochen_loss += loss.item()
        n_eval += dicoms.shape[0]

        if withStat:
            res = TP_TN_FP_FN_in_batch(targets[:,:].cpu().detach().long(),preds.detach().cpu())
            TP += res[0]
            TN += res[1]
            FP += res[2]
            FN += res[3]

    if withStat:
        stat = Statistic(TP,FN)

    if withImage:
        img_np = np.array(draw_images((dicoms,-1],preds),outline_bool=True)).transpose(0,3,1,2)
        img = torch.from_numpy(img_np)

    return eval_epochen_loss / n_eval,stat,img

解决方法

在不查看模型的情况下很难诊断问题,但是批处理规范在评估模式和训练模式下的行为会有所不同。这可能是您遇到此问题的原因。 Batchnorm使用整个训练批次在训练期间获取其参数,但使用在验证/测试期间使用的存储的运行平均值。似乎这个运行平均值需要几个纪元才能很好地收敛到您的验证集。您的数据是否正确归一化?