使用 Keras 微调通用句子编码器

问题描述

我正在尝试微调 Universal Sentence Encoder 并将新的编码器层用于其他用途。

import tensorflow as tf
from tensorflow.keras.models import Model,Sequential
from tensorflow.keras.layers import Dense,Dropout
import tensorflow_hub as hub

module_url = "universal-sentence-encoder"
model = Sequential([
    hub.KerasLayer(module_url,input_shape=[],dtype=tf.string,trainable=True,name="use"),Dropout(0.5,name="dropout"),Dense(256,activation="relu",name="dense"),Dense(len(y),activation="sigmoid",name="activation")
])

model.compile(optimizer="adam",loss="categorical_crossentropy",metrics=["accuracy"])
model.fit(X,y,batch_size=256,epochs=30,validation_split=0.25)

这奏效了。损失下降,准确性不错。现在我只想提取 Universal Sentence Encoder 层。但是,这就是我得到的。

enter image description here

  1. 你知道我该如何解决这个 nan 问题吗?我希望看到数值的编码。
  2. 是否只能按照 this post 的建议将 tuned_use 层保存为模型?理想情况下,我想像 tuned_use 一样保存 Universal Sentence Encoder 层,以便我可以像 hub.KerasLayer(tuned_use_location,dtype=tf.string) 一样打开和使用它。

解决方法

希望这会对某人有所帮助,我最终使用 universal-sentence-encoder-4 而不是 universal-sentence-encoder-large-5 解决了这个问题。我花了很多时间进行故障排除,但这很困难,因为输入数据没有问题并且模型训练成功。这可能是由于梯度爆炸问题,但无法将 gradient clippingLeaky ReLU 添加到原始架构中。

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...