问题描述
让我们假设我们有一些模型,其中包括自定义损失和度量标准,这些模型在培训期间很重要。是否可以保存没有自定义对象的完整模型(权重+ graphdef / pb文件)?
推断过程中不需要自定义损失和指标,因此...
tf.keras.models.load_model("some_model",custom_objects={...})
...只会使推理代码更加复杂,因为需要包括自定义目标代码进行推理(尽管未使用)。
但是,tf.keras.callbacks.ModelCheckpoint
(甚至使用include_optimizer=False
)以及调用model.save()
总是保存模型定义包括自定义对象。
因此,只需在模型中加载...
tf.keras.models.load_model("some_model")
...将始终失败并抱怨缺少自定义对象。
是否可以以某种方式保存整个模型而无需自定义损失/指标?要获得易于加载的网络“推断”版本?
当然,人们可以简单地使用model.save_weights()
,但是随后需要包括实际代码以供以后推理,这是不希望的。
解决方法
如果目的是防止加载损失和指标,则可以在compile
中使用参数load_model
:
model = tf.keras.models.load_model("some_model",compile=False)
由于未编译模型,因此应跳过损耗和指标/优化器的要求。当然,您现在不能训练模型,但是使用model.predict()