问题描述
我已经构建了一个自定义 keras 模型,并且在它的前向传递过程中,它使用了另一个库中函数的输出。但是,此函数的参数必须是一个 numpy 数组。在 model.compile()
期间,我可以将 run_eagerly
参数设置为 True,然后我可以使用 EagerTensor 的 .numpy()
方法将输出从前向传递转换为 numpy,但这似乎计算效率不高,因为.numpy()
在我的网络中只需要一次。 如何将张量转换为仅用于一次计算的热切张量?这可能吗?我曾尝试使用 K.get_session()
和 hidden_layer_outputs = sess.run(hidden_layer_outputs)
,但这会引发“无法在 TensorFlow 图形函数中获取会话”错误。下面是一个说明我的问题的例子。
def third_party_library_fxn(np_arr):
"""Uses len to get the first dimension of the array.
:param np_arr: <class 'numpy.ndarray'>
"""
# First part of function is to get the 1st dimension of the array
first_dim = len(np_arr)
# This function does more comptuations on first_dim
# ....
# ....
# return results
import tensorflow as tf
from tensorflow.python.framework.ops import EagerTensor
from third_party_library import third_party_library_fxn
class CustomModel(tf.keras.Model):
def __init__(self,units,**kwargs):
self.dense = tf.keras.layers.Dense(units=units)
self.lambd = tf.keras.layers.Lambda(third_party_library_fxn)
def call(self,inputs):
hidden_layer_outputs = self.dense(inputs)
# Raises Error: This block executes if the model is NOT running eagerly (i.e.,during model.fit)
if not(isinstance(hidden_layer_outputs,EagerTensor)):
# You cannot calculate `len()` of `tf.Tensor`
outputs_from_third_party_library = self.lambd(hidden_layer_outputs)
# No Error: This block is executed if the model IS running eagerly
else:
outputs_from_third_party_library = self.lambd(hidden_layer_outputs.numpy())
引起错误的编译和拟合:
model = CustomModel(units=arbitrary_number)
# Compilation is NOT eager by default
model.compile(loss=arbitrary_loss,optimizer=arbitrary_optimizer,run_eagerly=False)
# Raises error
model.fit(arbitrary_tf_batch_data,epochs=arbitrary_epochs)
编辑 1:
编辑 2:
我尝试过的一种方法是将第三方库函数包装到一个 tf.keras.layers.Layer
类中,然后设置 dynamic=True
。这不能解决问题,因为 model.fit(...,run_eagerly=False)
失败并出现以下错误。同样的错误建议使用 tf.py_function
,但在之前的努力中,我发现将第三方库函数包装在 tf.py_function
中并没有解决问题(更多关于这方面的内容也尽快)。
ValueError:您的模型包含只能在 Eager Execution 中成功运行的层(使用 dynamic=True
构造的层)。您不能设置 run_eagerly=False
。
最好,
贾里德
解决方法
暂无找到可以解决该程序问题的有效方法,小编努力寻找整理中!
如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。
小编邮箱:dio#foxmail.com (将#修改为@)