问题描述
我正在训练一个 CNN,并通过定义通过 Keras 层应用数据增强:
int characterCount = 0;
int i = 0,j = 0;
while (i < wordData.length) {
j = 0;
while (j < wordData[i].length) {
characterCount += (wordData[i][j]).length();
j++;
}
i++;
}
System.out.println(characterCount);
以下是相关代码的片段:
data_augmentation = keras.Sequential([
layers.experimental.preprocessing.Randomrotation(factor=0.4,fill_mode="wrap"),layers.experimental.preprocessing.RandomTranslation(height_factor=0.2,width_factor=0.2,layers.experimental.preprocessing.RandomFlip("horizontal"),layers.experimental.preprocessing.RandomContrast(factor=0.2),layers.experimental.preprocessing.RandomHeight(factor=0.2),layers.experimental.preprocessing.RandomWidth(factor=0.2)
])
如果我在 def process(x,y):
x = DATASETS_DIR + "/" + x + ".jpg"
x = tf.io.read_file(x)
x = tf.image.decode_jpeg(x,channels=3)
x = tf.image.resize(x,[299,299])
x = layers.experimental.preprocessing.Rescaling(1./255)(x)
return x,y
def process_with_augmentation(x,y):
x,y = process(x,y)
x = data_augmentation(x)
return x,y
train_ds = train_ds.map(process_with_augmentation,num_parallel_calls=tf.data.experimental.AUTOTUNE)
validation_ds = validation_ds.map(process,num_parallel_calls=tf.data.experimental.AUTOTUNE)
中注释掉 x = data_augmentation(x)
,代码工作正常。如果我不注释掉,我会收到以下错误:
process_with_augmentation()
关于如何解决这个问题有什么想法吗?
解决方法
keras 层需要批量大小。当您将函数映射到 tf.data.Dataset 时,图像将缺少该批次维度。您可以通过在调用 keras 预处理模型之前添加维度来解决此问题:
def process_with_augmentation(x,y):
x,y = process(x,y)
x = tf.expand_dims(x,axis=0) # adding the batch dimension
x = data_augmentation(x)
x = tf.squeeze(x,[0]) # removing the batch dimension
return x,y
另一种选择是在调用模型之前调用 tf.data.Dataset.batch
,即:
train_ds = train_ds.batch(BATCH_SIZE)
train_ds = train_ds.map(lambda x,y:(data_augmentation(x),y))
如果你想去掉那个额外的维度,你可以在增强的数据集上调用 unbatch
。
不过,这不是解决这个问题的最优雅的方法。我建议将您的预处理模型直接集成到您的神经网络中,或者使用 tf.image
模块中的函数重写您的预处理逻辑。