问题描述
为什么当我将网络设置为eval时,模型在验证数据上的误差函数在最初的几个时期就显得过高了。
如果我使用model.eval(),则在前4-5个时期内误差大于40-50k,然后迅速下降至3-4,但是如果我将网络留在model.train()上,则误差为只有5-6。
def eval_model(DataLoader,model,criterion,device,withStat,withImage):
model.eval()
eval_epochen_loss = 0
img = None
n_eval = 0
TP,TN,FP,FN = 0,0
stat = None
for i,data in enumerate(DataLoader):
dicoms,targets = data
dicoms,targets = Variable(dicoms.to(device)),Variable(targets.to(device))
assert targets.shape[1] - 1 == model.n_classes,\
f'Network has been defined with {model.n_classes} output classes,' \
f'but loaded target have {targets.shape[1] - 1} channels. Please check the labeled data or adjust ' \
f'the network classes. '
preds = model(dicoms)
loss = criterion(preds,targets[:,-1,:,:].long())
eval_epochen_loss += loss.item()
n_eval += dicoms.shape[0]
if withStat:
res = TP_TN_FP_FN_in_batch(targets[:,:].cpu().detach().long(),preds.detach().cpu())
TP += res[0]
TN += res[1]
FP += res[2]
FN += res[3]
if withStat:
stat = Statistic(TP,FN)
if withImage:
img_np = np.array(draw_images((dicoms,-1],preds),outline_bool=True)).transpose(0,3,1,2)
img = torch.from_numpy(img_np)
return eval_epochen_loss / n_eval,stat,img
解决方法
在不查看模型的情况下很难诊断问题,但是批处理规范在评估模式和训练模式下的行为会有所不同。这可能是您遇到此问题的原因。 Batchnorm使用整个训练批次在训练期间获取其参数,但使用在验证/测试期间使用的存储的运行平均值。似乎这个运行平均值需要几个纪元才能很好地收敛到您的验证集。您的数据是否正确归一化?