使用 tensorflow 提取 ELMo 特征并将它们转换为 numpy TF 2.xTF 1.x

问题描述

所以我有兴趣使用 ELMo 模型提取句子嵌入。

我一开始就尝试过:

import tensorflow as tf
import tensorflow_hub as hub
import numpy as np

elmo_model = hub.Module("https://tfhub.dev/google/elmo/2",trainable=True)

x = ["Hi my friend"]

embeddings = elmo_model(x,signature="default",as_dict=True)["elmo"]


print(embeddings.shape)
print(embeddings.numpy())

它运行良好,直到最后一行,我无法将其转换为 numpy 数组。

搜索了一下,发现如果把下面这行代码放在代码的开头,问题肯定能解决

tf.enable_eager_execution()

然而,我把它放在代码的开头,我意识到我无法编译

elmo_model = hub.Module("https://tfhub.dev/google/elmo/2",trainable=True)

我收到此错误

Exporting/importing Meta graphs is not supported when eager execution is enabled. No graph exists when eager execution is enabled.

我该如何解决我的问题?我的目标是获取句子特征并在 NumPy 数组中使用它们。

提前致谢

解决方法

TF 2.x

TF2 行为更接近经典的 Python 行为,因为它默认为急切执行。但是,您应该使用 hub.load 在 TF2 中加载您的模型。

elmo = hub.load("https://tfhub.dev/google/elmo/2").signature["default"]
x = ["Hi my friend"]
embeddings = elmo(tf.constant(x))["elmo"]

然后,您可以访问结果并使用 numpy 方法将它们转换为 numpy 数组。

>>> embeddings.numpy()
array([[[-0.7205108,-0.27990735,-0.7735629,...,-0.24703965,-0.8358178,-0.1974785 ],[ 0.18500198,-0.12270843,-0.35163105,0.14234722,0.08479916,-0.11709933],[-0.49985904,-0.88964033,-0.30124515,0.15846594,0.05210422,0.25386307]]],dtype=float32)

TF 1.x

如果使用 TF 1.x,您应该在 tf.Session 中运行该操作。 TensorFlow 不使用 Eager Execution,需要先构建图,然后在会话中评估结果。

elmo_model = hub.Module("https://tfhub.dev/google/elmo/2",trainable=True)
x = ["Hi my friend"]
embeddings_op = elmo_model(x,signature="default",as_dict=True)["elmo"]
# required to load the weights into the graph
init_op = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init_op)
    embeddings = sess.run(embeddings_op)

在这种情况下,结果将已经是一个 numpy 数组:

>>> embeddings
array([[[-0.72051036,-0.27990723,-0.773563,-0.24703972,-0.83581805,-0.19747877],[ 0.18500218,-0.12270836,-0.35163072,0.08479934,[-0.49985906,-0.8896401,-0.3012453,0.15846589,0.05210405,0.2538631 ]]],dtype=float32)

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...