如何在文本的卷积层中使用 tensorflow hub 层嵌入?

问题描述

我是 TensorFlow 集线器的新手,我正在尝试在我的 Conv1D 网络中使用集线器嵌入层进行文本分类

在顺序模型中使用集线器嵌入层没有任何问题:

hub_layer = hub.KerasLayer("https://tfhub.dev/google/nnlm-en-dim50/2",input_shape=[],dtype=tf.string,trainable=False)

model = tf.keras.Sequential()
model.add(hub_layer)
model.add(tf.keras.layers.Dense(128))
model.add(tf.keras.layers.Activation('relu'))
model.add(tf.keras.layers.Dense(5))
model.add(tf.keras.layers.Activation('softmax'))

model.compile(loss="categorical_crossentropy",optimizer="adam",metrics=["accuracy"])
model.summary()

但是,我无法在 Conv1D 模型中使用:

一个模型:

int_sequences_input = Input(shape=(max_length,))
embedded_sequences = hub_layer(int_sequences_input)
x = layers.Conv1D(128,5,activation="relu")(embedded_sequences)
x = layers.MaxPooling1D(5)(x)
x = layers.Conv1D(128,activation="relu")(x)
x = layers.GlobalMaxPooling1D()(x)
x = layers.Dense(128,activation="relu")(x)
x = layers.Dropout(0.5)(x)
preds = layers.Dense(len(class_names),activation="softmax")(x)
model = keras.Model(int_sequences_input,preds)
model.summary()

或:

第二个模型:

model = tf.keras.Sequential()
model.add(hub_layer)
model.add(tf.keras.layers.Conv1D(128,7,activation='relu'))
model.add(tf.keras.layers.GlobalMaxPooling1D())
model.add(tf.keras.layers.Dense(64,activation='relu'))
model.add(tf.keras.layers.Dense(num_classes,activation='softmax'))

model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy'])
model.summary()

当我收到值错误时:

ValueError: Input 0 of layer conv1d_11 is incompatible with the layer: expected ndim=3,found ndim=2. Full shape received: [None,50]

我想知道是否有任何解决方案? 我查看了 thisthis,但没有一个解决我的问题。

解决方法

生成的嵌入维度为:(num_examples,embedding_dimension),它与 1D 卷积不兼容,因为它需要 3D 输入。

尝试在 hub 层之后重塑,像这样:

model.add(hub_layer)
model.add(tf.keras.layers.Reshape((1,50)))
model.add(tf.keras.layers.Conv1D(16,3,activation='relu',padding = 'same'))
...

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...