问题描述
例如,对于cifar10数据集,直接使用pytorch附带的数据集,在相同的网络结构下,准确率可以达到96%,但是在我将cifar10转换为图片之后,我对其进行了测试和准确性率只有92%。为什么?
这是先前的代码:
train_dataset = dset.CIFAR10(args.data_path,train=True,transform=train_transform,download=True)
test_dataset = dset.CIFAR10(args.data_path,train=False,transform=test_transform,download=True)
这是修改后的代码:
train_dataset = datasets.ImageFolder(root='/home/ubuntu/bigdisk/DataSets/cifar10/static/orig/train/',transform=train_transform
)
test_dataset = datasets.ImageFolder(root='/home/ubuntu/bigdisk/DataSets/cifar10/static/orig/test/',transform=test_transform
)
train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=args.batch_size,shuffle=True,num_workers=args.prefetch,pin_memory=True)
test_loader = torch.utils.data.DataLoader(test_dataset,batch_size=args.test_bs,shuffle=False,pin_memory=True)
解决方法
如果下载的数据集,超参数(例如批大小或学习率),数据集转换等相等,我认为是因为随机性。
您的数据加载器会随机随机播放数据集。每次重新组合后,重新组合后的数据集始终会有所不同,这可能会导致准确性差异。
此外,每次都会用不同的值初始化模型。 (除非您使用了一些总是使用相同值初始化模型的初始化方法。)
您可以检查https://pytorch.org/docs/stable/notes/randomness.html以获得更多信息。