保存没有自定义对象的完整tf.keras模型?

问题描述

让我们假设我们有一些模型,其中包括自定义损失和度量标准,这些模型在培训期间很重要。是否可以保存没有自定义对象的完整模型(权重+ graphdef / pb文件)?

推断过程中不需要自定义损失和指标,因此...

tf.keras.models.load_model("some_model",custom_objects={...})

...只会使推理代码更加复杂,因为需要包括自定义目标代码进行推理(尽管未使用)。

但是,tf.keras.callbacks.ModelCheckpoint(甚至使用include_optimizer=False)以及调用model.save()总是保存模型定义包括自定义对象。

因此,只需在模型中加载...

tf.keras.models.load_model("some_model")

...将始终失败并抱怨缺少自定义对象。

是否可以以某种方式保存整个模型而无需自定义损失/指标?要获得易于加载的网络“推断”版本?

还是唯一的解决方案,将所有内容冻结为TFLite模型?

当然,人们可以简单地使用model.save_weights(),但是随后需要包括实际代码以供以后推理,这是不希望的。

解决方法

如果目的是防止加载损失和指标,则可以在compile中使用参数load_model

model = tf.keras.models.load_model("some_model",compile=False)

由于未编译模型,因此应跳过损耗和指标/优化器的要求。当然,您现在不能训练模型,但是使用model.predict()

进行推理应该可以正常工作