使用 tf.cond() 检查自定义层的调用方法中的条件

问题描述

我正在 tensorflow 2.x 中实现一个自定义层。我的要求是,程序应该在返回输出之前检查条件。

class SimpleRNN_cell(tf.keras.layers.Layer):
    def __init__(self,M1,M2,fi=tf.nn.tanh,disp_name=True):
        super(SimpleRNN_cell,self).__init__()
        pass        
    def call(self,X,hidden_state,return_state=True):
        y = tf.constant(5)
        if return_state == True:
            return y,self.h
        else:
            return y

我的问题是:我应该继续使用当前代码(假设 tape.gradient(Loss,self.trainable_weights) 可以正常工作)还是应该使用 tf.cond()。 此外,如果可能,请说明在何处使用 tf.cond() 以及在何处不使用。我没有找到太多关于这个主题内容

解决方法

tf.cond 仅在基于可微计算图中的数据执行条件评估时才相关。 (https://www.tensorflow.org/api_docs/python/tf/cond) 这在默认图形模式的 TF 1.0 中尤为必要。对于急切模式,GradientTape 系统还允许使用诸如 if ...: (https://www.tensorflow.org/guide/autodiff#control_flow)

之类的 Python 构造进行条件数据流

然而,仅仅根据配置参数提供不同的行为,不依赖于计算图的数据并且在模型运行时固定,使用简单的 python if 语句是正确的。

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...