Python在tensorflow中导入本地数据集

问题描述

我正在研究Tensorflow中的图像分类。我正要从项目目录中将本地数据集加载到python文件中。我正在关注tensorflow文档(https://www.tensorflow.org/tutorials/images/classification),当我到达添加数据点时,该文档会使用Google数据集从互联网上导入数据。他们使用

dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"

然后

data_dir = tf.keras.utils.get_file('flower_photos',origin=dataset_url,untar=True)

我该如何使用名为DataSet的本地目录来做同样的事情?

解决方法

假设您的数据集包含包含image.png的子文件夹。

 import pathlib

data_dir = pathlib.Path('path/to/your/DataSet_folder')

list_ds = tf.data.Dataset.list_files(str(data_dir/'*/*.png'))

list_ds包含图像的所有路径。

,

get_file仅在不存在时下载。因此,您可以将fname设置为本地文件,并像这样设置origin = ''

data_dir = tf.keras.utils.get_file(os.path.abspath('flower_photos'),origin='',untar=True)

os.path.abspath是必需的,因为默认情况下keras搜索cache_dir来查找文件。

并且由于untar已过时,您最好使用extract代替:

data_dir = tf.keras.utils.get_file(os.path.abspath('flower_photos.tar.gz'),extract=True)

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...