带有 Tensorflow Hub 的 Keras 模型在保存/恢复时不会给出相同的结果

问题描述

我有一个使用 Tensorflow hub 层的 Keras 模型。然而,该模型在原始模型和恢复模型之间没有给出相同的预测。

我的 Keras 模型:

hub_layer = hub.KerasLayer("https://tfhub.dev/google/remote_sensing/eurosat-resnet50/1",tags=['train'],input_shape=(64,64,3))

original_model = Sequential()
original_model.add(hub_layer)
original_model add(Dense(32,activation='relu'))
original_model.add(Dense(1,activation='sigmoid'))

original_model.compile(optimizer='adam',loss='binary_crossentropy',metrics=['accuracy'])

history = original_model.fit(train_generator,epochs=100)

img_batch = ... # Image batch of shape (32,3)
original_model.predict(img_batch)

原始模型输出

Out[1] : array([[0.803754  ],[0.2758078 ],...
               [0.26074764],[0.6190501 ]]

当模型被保存和恢复时,预测是不一样的:

orignial_model.save("model.hd5")

restored_model = tf.keras.models.load_model("model.hd5",custom_objects={'KerasLayer': hub.KerasLayer})

restored_model.predict(img_batch) # The image batch used is exactly the same as before

输出恢复模型:

Out[2] : array([[0.9999999 ],[1.        ],...
                [1.        ],[1.        ]]

结果与原始模型不一样。

我尝试了相同的实验,但没有使用 Tensorflow 层,问题没有出现。所以我猜问题出在这个 Tensorflow hub 层。

我还尝试比较了两种型号的配置,它们是相同的:

original_model.get_config() == restored_model.get_config() # Return true

我还比较了两个模型的权重,它们是相同的。

版本:

  • 张量流:2.0.0
  • Keras:2.3.1
  • Tensorflow 中心:0.8.0
  • Python:3.7.10

解决方法

我找不到问题的根源,但我找到了替代解决方案。

此解决方案包括从模型中删除 tensorflow 中心层。由这一层完成的转换可以像这样在外面完成:

hub_layer = hub.KerasLayer("https://tfhub.dev/google/remote_sensing/eurosat-resnet50/1",tags=['train'],input_shape=(64,64,3))

original_model = Sequential()
original_model.add(Input(2048))
original_model add(Dense(32,activation='relu'))
original_model.add(Dense(1,activation='sigmoid'))

original_model.compile(optimizer='adam',loss='binary_crossentropy',metrics=['accuracy'])

X_transformed = hub_layer(X)

history = original_model.fit(X_transformed,y,epochs=100) 

通过这样做,原始模型和恢复模型之间的模型预测是相同的。

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...