PyTorch模型精度

问题描述

我下面有一个模型,该模型可以训练对图片中的多个对象进行计数的图像。我想计算模型的准确性。此刻的代码计算出训练损失,错误以及测试损失和错误。为了找到最终精度,我需要多少%?非常感谢

# main training loop
global_step = 0
best_test_error = 10000
for epoch in range(15):
    print("Epoch %d" % epoch)
    model.train()
    for images,paths in tqdm(loader_train):
        images = images.to(device)
        targets = torch.tensor([metadata['count'][os.path.split(path)[-1]] for path in paths]) # B
        targets = targets.float().to(device)

        # forward pass:
        output = model(images) # B x 1 x 9 x 9 (analogous to a heatmap)
        preds = output.sum(dim=[1,2,3]) # predicted cell counts (vector of length B)
        
        # backward pass:
        loss = torch.mean((preds - targets)**2)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # logging:
        count_error = torch.abs(preds - targets).mean()
        writer.add_scalar('train_loss',loss.item(),global_step=global_step)
        writer.add_scalar('train_count_error',count_error.item(),global_step=global_step)

        print("Step %d,loss=%f,count error=%f" % (global_step,count_error.item()))

        global_step += 1
    
    mean_test_error = 0
    model.eval()
    for images,paths in tqdm(loader_test):
        images = images.to(device)
        targets = torch.tensor([metadata['count'][os.path.split(path)[-1]] for path in paths]) # B
        targets = targets.float().to(device)

        # forward pass:
        output = model(images) # B x 1 x 9 x 9 (analogous to a heatmap)
        preds = output.sum(dim=[1,3]) # predicted cell counts (vector of length B)

        # logging:
        loss = torch.mean((preds - targets)**2)
        count_error = torch.abs(preds - targets).mean()
        mean_test_error += count_error
        writer.add_scalar('test_loss',global_step=global_step)
        writer.add_scalar('test_count_error',global_step=global_step)
        
        global_step += 1
    
    mean_test_error = mean_test_error / len(loader_test)
    print("Test count error: %f" % mean_test_error)
    if mean_test_error < best_test_error:
        best_test_error = mean_test_error
        torch.save({'state_dict':model.state_dict(),'optimizer_state_dict':optimizer.state_dict(),'globalStep':global_step,'train_paths':dataset_train.files,'test_paths':dataset_test.files},checkpoint_path)

解决方法

暂无找到可以解决该程序问题的有效方法,小编努力寻找整理中!

如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。

小编邮箱:dio#foxmail.com (将#修改为@)

相关问答

错误1:Request method ‘DELETE‘ not supported 错误还原:...
错误1:启动docker镜像时报错:Error response from daemon:...
错误1:private field ‘xxx‘ is never assigned 按Alt...
报错如下,通过源不能下载,最后警告pip需升级版本 Requirem...