问题描述
我现在使用 keras 有一段时间了,但通常我不需要使用自定义层或执行一些更复杂的流控制,所以我正在努力理解一些东西。
我正在建模一个顶部有自定义层的神经网络。这个自定义层调用另一个函数 (search_sigma
),在这个函数中我执行 tf.while_loop
,在 tf.while_loop
中我执行 tf.cond
。
我不明白为什么这些条件不起作用。
-
tf.while_loop
停止,即使条件 (l1
) 仍然为真 -
tf.cond executes
f1
和f2
(可调用对象true_fn
和false_fn
)
有人能帮我理解我遗漏了什么吗?
我已经尝试为真正的张量更改 tf.cond 和 tf.while_loop 条件,只是想看看会发生什么。行为(完全相同的错误)保持不变。
我也尝试在不实现类的情况下编写此代码(仅使用函数)。什么都没有改变。
我试图通过查看 tensorflow 文档、其他堆栈溢出问题以及讨论 tf.while_loop 和 tf.cond 的网站来寻找解决方案。
我在代码正文中留下了一些 print()
以尝试跟踪正在发生的事情。
class find_sigma:
def __init__ (self,t_inputs,inputs,expected_perp=10. ):
self.sigma,self.cluster = t_inputs
self.inputs = inputs
self.expected_perp = expected_perp
self.min_sigma=tf.constant([0.01],tf.float32)
self.max_sigma=tf.constant([50.],tf.float32)
def search_sigma(self):
def cond(s,sigma_not_found): return sigma_not_found
def body(s,sigma_not_found):
print('loop')
pi = K.exp( - K.sum( (K.expand_dims(self.inputs,axis=1) - self.cluster)**2,axis=2 )/(2*s**2) )
pi = pi / K.sum(pi)
MACHINE_EPSILON = np.finfo(np.double).eps
pi = K.maximum(pi,MACHINE_EPSILON)
H = - K.sum ( pi*(K.log(pi)/K.log(2.)),axis=0 )
perp = 2**H
print('0')
l1 = tf.logical_and (tf.less(perp,self.expected_perp),tf.less(0.01,self.max_sigma-s))
l2 = tf.logical_and (tf.less( self.expected_perp,perp),s-self.min_sigma) )
def f1():
print('f1')
self.min_sigma = s
s2 = (s+self.max_sigma)/2
return [s2,tf.constant([True])]
def f2(l2):
tf.cond( l2,true_fn=f3,false_fn = f4)
def f3():
print('f3')
self.max_sigma = s
s2 = (s+self.min_sigma)/2
return [s2,tf.constant([True])]
def f4():
print('f4')
return [s,tf.constant([False])]
output = tf.cond( l1,f1,f4 ) #colocar f2 no lugar de f4
s,sigma_not_found = output
print('sigma_not_found = ',sigma_not_found)
return [s,sigma_not_found]
print('01')
sigma_not_found = tf.constant([True])
new_sigma,sigma_not_found=sigma_not_found = tf.while_loop(
cond,body,loop_vars=[self.sigma,sigma_not_found]
)
print('saiu')
print(new_sigma)
return new_sigma
self.sigma = tf.map_fn(fn=lambda t: find_sigma(t,inputs).search_sigma(),elems=(self.sigma,self.clusters),dtype=tf.float32)
'inputs' 是一个 (None,10)
大小的张量
'self.sigma' 是一个 (10,)
大小的张量
'self.clusters' 是一个 (N,10)
大小的张量
解决方法
首先,你的第一个问题非常出色!大量信息!
tf.while_loop 非常令人困惑,这也是 tf 转向急切执行的原因之一。你不需要再这样做了。
无论如何,回到你的 2 个问题。两者的答案是相同的,您永远不会执行您的图表,您只是在构建它。在构建执行图时,tensorflow 需要跟踪您的 Python 代码,这就是您认为 tf.conf 正在运行 f1 和 f2 的原因。它是“某种运行”,因为它需要深入了解将添加到图形中的张量/操作。
同样适用于您关于 tf.while_loop 的问题。它永远不会执行那个。
我建议进行一些小的更改,这可能会帮助您理解我在说什么并解决您的问题。从 body 方法中删除该 tf.while_loop。创建另一个方法,假设 run() 并将循环移动到那里。有点像这样
def run(self):
out = tf.while_loop(cond,body,loop_vars)
然后,调用 run()。它将强制执行图形。