EfficientNets上的转移学习如何处理灰度图像?

问题描述

我的问题更多地涉及算法如何工作。我已经成功地为灰度图像实现了EfficientNet集成和建模,现在我想了解它为什么起作用。

最重要的方面是灰度及其1通道。当我放入channels=1时,该算法不起作用,因为如果我理解正确,它是在3通道图像上制作的。当我放channels=3时,它可以正常工作。

所以我的问题是,当我放置channels = 3并用channels=1向模型提供经过预处理的图像时,为什么它继续起作用?

EfficientNetB5的代码

# Variable assignments
num_classes = 9
img_height = 84
img_width = 112
channels = 3
batch_size = 32

# Make the input layer
new_input = Input(shape=(img_height,img_width,channels),name='image_input')

# Download and use EfficientNetB5
tmp = tf.keras.applications.EfficientNetB5(include_top=False,weights='imagenet',input_tensor=new_input,pooling='max')
model = Sequential()
model.add(tmp)  # adding EfficientNetB5
model.add(Flatten())
...

预处理为灰度代码

data_generator = ImageDataGenerator(
        validation_split=0.2)

train_generator = data_generator.flow_from_directory(
        train_path,target_size=(img_height,img_width),batch_size=batch_size,color_mode="grayscale",###################################
        class_mode="categorical",subset="training")

解决方法

这很有趣。如果即使输入是灰度的训练仍然适用于 channels = 3,我会检查 train_generator 的批次形状(可能打印几个批次以了解它)。这是一个快速检查批次形状的代码片段。 (plotImages() 在 Tensorflow 文档中可用)

imgs,labels = next(train_generator)
print('Batch shape: ',imgs.shape)
plotImages(imgs,labels)