使用tf.data.dataset为序列模型创建数据生成器

问题描述

我有一个包含RGB图像的图像数据集:img1.png,img2.png ... img250.png。我从每个图像中提取了100个大小分别为[64,64,3]的小补丁。因此,我现在有了img1_1.png,img1_2.png ... img1_100.png,img2_1.png,img2_2.png,... img2_100.png,img3_1,....

我想用tf.data.dataset.from_tensor_slices创建一个数据生成器,以将每个图像的所有补丁传递到RNN模型。所以,我想生成器创建如下输出:[batch_size,100,64,64,3]

我该怎么做?

解决方法

代码:

# generating data
x = tf.constant(np.random.randint(256,size =(250,64,3)),dtype = tf.int32)

# Creating a dataset with sequence length
dataset = tf.data.Dataset.from_tensor_slices(x).batch(100,drop_remainder= True)
for i in dataset:
    print(i.shape)

输出:

(100,3)
(100,3)

确保drop_remainders = True

最后,创建所需长度的批量。

# creating dataset with batch_size
dataset = dataset.batch(32)
for i in dataset:
    print(i.shape)

输出:

(2,100,3)

如果您的数据大小为(250,64,64,3):

dataset = tf.data.Dataset.from_tensor_slices(x).batch(32)
for i in dataset:
    print(i.shape)

输出:

(32,3)
(32,3)
(26,3)

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...