问题描述
我正在PyTorch中编写一个众所周知的问题MNIST database of handwritten digits
的代码。我从主要网站下载了训练和测试数据集,包括标记的数据集。数据集格式为t10k-images-idx3-ubyte.gz
,提取后为t10k-images-idx3-ubyte
。我的数据集文件夹看起来像
MINST
Data
train-images-idx3-ubyte.gz
train-labels-idx1-ubyte.gz
t10k-images-idx3-ubyte.gz
t10k-labels-idx1-ubyte.gz
def load_dataset():
data_path = "/home/MNIST/Data/"
xy_trainPT = torchvision.datasets.ImageFolder(
root=data_path,transform=torchvision.transforms.ToTensor()
)
train_loader = torch.utils.data.DataLoader(
xy_trainPT,batch_size=64,num_workers=0,shuffle=True
)
return train_loader
我的代码显示Supported extensions are: .jpg,.jpeg,.png,.ppm,.bmp,.pgm,.tif,.tiff,.webp
如何解决此问题,我还想检查是否从数据集中加载了我的图像(只是一个数字包含前5个图像)?
解决方法
欢迎来到stackoverflow!
MNIST数据集不存储为图像,而是以二进制格式(如ubyte扩展名所示)存储。因此,ImageFolder
不是您想要的类型数据集。相反,您将需要使用MNIST dataset class。如果您还没有下载数据,它甚至可以下载:)
这是一个数据集类,因此只需使用正确的root
路径实例化,然后将其作为数据加载器的参数,一切就可以正常工作。
如果要检查图像,只需使用数据加载器的get
方法,然后将结果保存为png文件(您可能需要先将张量转换为numpy数组)
阅读此Extract images from .idx3-ubyte file or GZIP via Python
更新
您可以使用此格式导入数据
xy_trainPT = torchvision.datasets.MNIST(
root="~/Handwritten_Deep_L/",train=True,download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor()]),)
现在,download=True
发生的事情是,您的代码将首先在根目录(给定路径)中检查是否包含任何数据集。
如果no
,则将从网络上下载数据集。
如果yes
此路径已经包含一个数据集,那么您的代码将使用现有的数据集运行,而不会从互联网上下载。
您可以检查,首先给出一个路径without any dataset
(将从Internet下载数据),然后给出另一个路径which already contains dataset
则不会下载数据。