为什么 tf.cond 中的两个分支都被执行?为什么 tf.while_loop 完成循环,即使条件仍然为真?

问题描述

我现在使用 keras 有一段时间了,但通常我不需要使用自定义层或执行一些更复杂的流控制,所以我正在努力理解一些东西。

我正在建模一个顶部有自定义层的神经网络。这个自定义调用一个函数 (search_sigma),在这函数中我执行 tf.while_loop,在 tf.while_loop 中我执行 tf.cond

我不明白为什么这些条件不起作用。

  • tf.while_loop 停止,即使条件 (l1) 仍然为真
  • tf.cond executes f1f2(可调用对象 true_fnfalse_fn

有人能帮我理解我遗漏了什么吗?

我已经尝试为真正的张量更改 tf.cond 和 tf.while_loop 条件,只是想看看会发生什么。行为(完全相同的错误)保持不变。

我也尝试在不实现类的情况下编写此代码(仅使用函数)。什么都没有改变。

我试图通过查看 tensorflow 文档、其他堆栈溢出问题以及讨论 tf.while_loop 和 tf.cond 的网站来寻找解决方案。

我在代码正文中留下了一些 print() 以尝试跟踪正在发生的事情。

class find_sigma:
    
    def __init__ (self,t_inputs,inputs,expected_perp=10. ):       
        self.sigma,self.cluster = t_inputs
        self.inputs = inputs
        self.expected_perp = expected_perp
        self.min_sigma=tf.constant([0.01],tf.float32)
        self.max_sigma=tf.constant([50.],tf.float32)
 

    def search_sigma(self):

        
        def cond(s,sigma_not_found): return sigma_not_found


        def body(s,sigma_not_found):   

            print('loop')
            pi = K.exp( - K.sum( (K.expand_dims(self.inputs,axis=1) - self.cluster)**2,axis=2  )/(2*s**2) )        
            pi = pi / K.sum(pi)
            MACHINE_EPSILON = np.finfo(np.double).eps
            pi = K.maximum(pi,MACHINE_EPSILON)
            H = - K.sum ( pi*(K.log(pi)/K.log(2.)),axis=0 )
            perp = 2**H

            print('0')

            l1 = tf.logical_and (tf.less(perp,self.expected_perp),tf.less(0.01,self.max_sigma-s))
            l2 = tf.logical_and (tf.less(  self.expected_perp,perp),s-self.min_sigma) )
    
            def f1():
                print('f1')
                self.min_sigma = s 
                s2 = (s+self.max_sigma)/2 
                return  [s2,tf.constant([True])]
                

            def f2(l2): 
                tf.cond( l2,true_fn=f3,false_fn = f4)

            def f3(): 
                print('f3')
                self.max_sigma = s 
                s2 = (s+self.min_sigma)/2
                return [s2,tf.constant([True])]

            def f4(): 
                print('f4')
                return [s,tf.constant([False])]
            
            output = tf.cond( l1,f1,f4 ) #colocar f2 no lugar de f4

            s,sigma_not_found = output
            
            print('sigma_not_found = ',sigma_not_found)
            return [s,sigma_not_found]

        print('01')

        sigma_not_found = tf.constant([True])

        new_sigma,sigma_not_found=sigma_not_found = tf.while_loop(
            cond,body,loop_vars=[self.sigma,sigma_not_found]
        )

        print('saiu')
        
        print(new_sigma)

        return new_sigma

调用上面代码的那段代码是:

self.sigma = tf.map_fn(fn=lambda t: find_sigma(t,inputs).search_sigma(),elems=(self.sigma,self.clusters),dtype=tf.float32)

'inputs' 是一个 (None,10) 大小的张量

'self.sigma' 是一个 (10,) 大小的张量

'self.clusters' 是一个 (N,10) 大小的张量

解决方法

首先,你的第一个问题非常出色!大量信息!

tf.while_loop 非常令人困惑,这也是 tf 转向急切执行的原因之一。你不需要再这样做了。

无论如何,回到你的 2 个问题。两者的答案是相同的,您永远不会执行您的图表,您只是在构建它。在构建执行图时,tensorflow 需要跟踪您的 Python 代码,这就是您认为 tf.conf 正在运行 f1 和 f2 的原因。它是“某种运行”,因为它需要深入了解将添加到图形中的张量/操作。

同样适用于您关于 tf.while_loop 的问题。它永远不会执行那个。

我建议进行一些小的更改,这可能会帮助您理解我在说什么并解决您的问题。从 body 方法中删除该 tf.while_loop。创建另一个方法,假设 run() 并将循环移动到那里。有点像这样

def run(self):
   out = tf.while_loop(cond,body,loop_vars)

然后,调用 run()。它将强制执行图形。