低级Tensorflow,dataset.as_numpy_iterator返回dicts而不是numpy数组

问题描述

当我尝试使用https://www.tensorflow.org/api_docs/python/tf/data/Dataset#batch中的方法导入和批处理数据集时
当我使用dataset.as_numpy_iterator()时,即使我应该得到一堆numpy数组,迭代对象也是字典。
我的代码如下:

dataset = tfds.load('mnist',split='train')
dataset.batch(batch_size,drop_remainder=False)
for i in dataset.as_numpy_iterator():
    print(type(i))  # prints <class 'dict'>

为什么会发生?

解决方法

使用as_supervised =真

import tensorflow_datasets as tfds
dataset = tfds.load('mnist',split='train',as_supervised=True)
dataset.batch(10,drop_remainder=False)
for image,label in tfds.as_numpy(dataset):
    print(type(image),type(label),label)

根据TensorFlow文档,如果as_supervised为False,您将获得字典值。 检查文档Here