您如何使用PyTorch数据集访问S3和其他对象存储提供程序上的CSV数据?

问题描述

我的数据集作为CSV文件的集合存储在Amazon Web Services(AWS)简单存储服务(S3)存储桶中。我想根据此数据训练PyTorch模型,但内置的Dataset类不提供对对象存储服务(如S3或Google Cloud Storage(GCS),Azure Blob存储等)的本地支持。我在https://pytorch.org/docs/stable/data.html#处查看了PyTorch文档中有关可用的数据集类的信息,当涉及到公共云对象存储支持时,会显得很短。

看来我必须根据以下说明创建自己的自定义数据集:https://pytorch.org/tutorials/beginner/data_loading_tutorial.html#dataset-class,但工作似乎不胜枚举:我需要弄清楚如何将数据从对象存储下载到本地节点,并解析CSV文件将其读取到PyTorch张量中,然后处理由于我的数据集为100 s GB而导致磁盘空间用尽的可能性。

由于PyTorch模型是使用梯度下降训练的,我只需要一次将一小部分数据(小于1GB)存储在内存中,是否有一个自定义的数据集实现方式可以帮助您?

解决方法

查看支持诸如S3和GCS osds.readthedocs.io/en/latest/gcs.html之类的对象存储服务的ObjectStorage数据集

您可以运行

pip install osds

安装它,然后将其指向您的S3存储桶,以使用类似

的方法实例化PyTorch Dataset和DataLoader。
from osds.utils import ObjectStorageDataset
from torch.utils.data import DataLoader


ds = ObjectStorageDataset(f"gcs://gs://cloud-training-demos/taxifare/large/taxi-train*.csv",storage_options = {'anon' : False },batch_size = 32768,worker = 4,eager_load_batches = False)

dl = DataLoader(ds,batch_size=None)

使用S3位置路径代替gcs://gs://cloud-training-demos/taxifare/large/taxi-train*.csv。因此,根据存储桶和存储数据集CSV对象的存储桶目录,S3的全局名称类似于s3://<bucket name>/<object path>/*.csv

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...