问题描述
我正在研究Tensorflow中的图像分类。我正要从项目目录中将本地数据集加载到python文件中。我正在关注tensorflow文档(https://www.tensorflow.org/tutorials/images/classification),当我到达添加数据点时,该文档会使用Google数据集从互联网上导入数据。他们使用
dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
然后
data_dir = tf.keras.utils.get_file('flower_photos',origin=dataset_url,untar=True)
我该如何使用名为DataSet的本地目录来做同样的事情?
解决方法
假设您的数据集包含包含image.png的子文件夹。
import pathlib
data_dir = pathlib.Path('path/to/your/DataSet_folder')
list_ds = tf.data.Dataset.list_files(str(data_dir/'*/*.png'))
list_ds包含图像的所有路径。
, get_file
仅在不存在时下载。因此,您可以将fname
设置为本地文件,并像这样设置origin = ''
:
data_dir = tf.keras.utils.get_file(os.path.abspath('flower_photos'),origin='',untar=True)
os.path.abspath
是必需的,因为默认情况下keras
搜索cache_dir
来查找文件。
并且由于untar
已过时,您最好使用extract
代替:
data_dir = tf.keras.utils.get_file(os.path.abspath('flower_photos.tar.gz'),extract=True)