问题描述
我正在将我们的代码从 tensorflow 1 迁移到 tensorflow 2。其中一层是嵌入层,加载如下:
import tensorflow_hub as hub
model_url = "https://tfhub.dev/google/universal-sentence-encoder-multilingual/1"
self.use_embed = hub.Module(model_url,trainable=False)
在 Tensorflow 2 中这将变成
import tensorflow_hub as hub
model_url = "https://tfhub.dev/google/universal-sentence-encoder-multilingual/3"
self.use_embed = hub.load(model_url)
hub.Module API 仅适用于 TF1。对于 TF2,切换到普通 SavedModels 和 hub.load()。
但是,load()
方法不支持 trainable
参数吗?
这个参数发生了什么变化,我如何在 Tensorflow 2 中应用它?
解决方法
Model Compatibility Guide 提到 hub.load()
和 hub.KerasLayer()
的参数名称不同:
使用 hub.load:
m = hub.load(handle)
输出 = m(inputs,training=is_training)
或 hub.KerasLayer:
m = hub.KerasLayer(handle,trainable=True)
输出 = m(输入)