问题描述
我需要在类方法中计算 tf.Variable 的梯度,但之后要在另一个方法中使用这些梯度来更新变量。不使用 @tf.function 装饰器时可以正常运行,但加上 @tf.function 后会出现 TypeError: An op outside of the function building code is being passed a "Graph" tensor 错误。我一直在查找这个错误的含义以及解决方法,但一无所获。
如果您好奇,顺便说明一下:我这样做是因为我的变量分布在许多不同的方程中。与其构造一个把所有变量联系起来的单一方程,不如把它们分开处理:先分别计算每个方程在当前时刻的梯度,再逐步应用更新。这样更简单,计算成本也更低。我知道这两种做法在数学上并不等价。
下面是我的代码(一个最小示例),之后附有运行结果和错误信息。请注意,如果在同一个方法 .iterate() 中计算梯度并更新变量,则不会出错。
import tensorflow as tf
class Example:
    """Minimal example: fit ``x * y`` to ``target`` with plain gradient descent.

    ``iterate`` computes and applies the update inside a single
    ``tf.function`` and works. ``compute_update`` / ``apply_update`` split
    the same work across two ``tf.function`` calls and reproduce the
    reported ``TypeError`` (see the note on ``compute_update``).
    """

    def __init__(self, x, y, target, lr=0.01):
        # x, y and target are the tf.Variable objects created by the caller.
        self.x = x
        self.y = y
        self.target = target
        self.lr = lr
        # The two variables that get updated.
        self.variables = [self.x, self.y]

    @tf.function
    def iterate(self):
        """Compute the gradients of the squared error and apply one step."""
        with tf.GradientTape(persistent=False) as tape:
            loss = (self.target - self.x * self.y)**2
        self.gradients = tape.gradient(loss, self.variables)
        for g, v in zip(self.gradients, self.variables):
            v.assign_add(-self.lr * g)

    @tf.function
    def compute_update(self):
        """Compute the gradients and store them on ``self.gradients``.

        NOTE: inside a ``tf.function`` the value returned by
        ``tape.gradient`` is a symbolic graph tensor. Storing it on ``self``
        leaks it out of this function's graph. ``apply_update`` then fails
        with "An op outside of the function building code is being passed a
        'Graph' tensor". That is the error this example demonstrates. To
        avoid it, return the gradients and pass them to the update method
        as an argument.
        """
        with tf.GradientTape(persistent=False) as tape:
            loss = (self.target - self.x * self.y)**2
        self.gradients = tape.gradient(loss, self.variables)

    @tf.function
    def apply_update(self):
        """Apply one gradient-descent step using ``self.gradients``."""
        for g, v in zip(self.gradients, self.variables):
            v.assign_add(-self.lr * g)
x = tf.Variable(1.)
y = tf.Variable(3.)
target = tf.Variable(5.)
# The constructor takes (x, y, target); passing only (x, target) would fail
# and leave y unused.
example = Example(x, y, target)

# Compute and apply updates in a single tf.function method.
# One step with lr=0.01 moves x 1.0 -> 1.12 and y 3.0 -> 3.04.
example.iterate()
print('')
print(example.variables)
print('')

# Compute and apply updates in separate tf.function methods.
# Under @tf.function, apply_update raises the "Graph" tensor TypeError.
example.compute_update()
example.apply_update()
print('')
print(example.variables)
print('')
输出:
$ python temp_bug.py
[<tf.Variable 'Variable:0' shape=() dtype=float32,numpy=1.12>,<tf.Variable 'Variable:0' shape=() dtype=float32,numpy=3.04>]
Traceback (most recent call last):
File "temp_bug.py",line 47,in <module>
example.apply_update()
File "/home/mroos/.local/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py",line 580,in __call__
result = self._call(*args,**kwds)
File "/home/mroos/.local/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py",line 650,in _call
return self._concrete_stateful_fn._filtered_call(canon_args,canon_kwds) # pylint: disable=protected-access
File "/home/mroos/.local/lib/python3.6/site-packages/tensorflow/python/eager/function.py",line 1665,in _filtered_call
self.captured_inputs)
File "/home/mroos/.local/lib/python3.6/site-packages/tensorflow/python/eager/function.py",line 1746,in _call_flat
ctx,args,cancellation_manager=cancellation_manager))
File "/home/mroos/.local/lib/python3.6/site-packages/tensorflow/python/eager/function.py",line 598,in call
ctx=ctx)
File "/home/mroos/.local/lib/python3.6/site-packages/tensorflow/python/eager/execute.py",line 75,in quick_execute
raise e
File "/home/mroos/.local/lib/python3.6/site-packages/tensorflow/python/eager/execute.py",line 60,in quick_execute
inputs,attrs,num_outputs)
TypeError: An op outside of the function building code is being passed
a "Graph" tensor. It is possible to have Graph tensors
leak out of the function building context by including a
tf.init_scope in your function building code.
For example,the following function will fail:
@tf.function
def has_init_scope():
my_constant = tf.constant(1.)
with tf.init_scope():
added = my_constant * 2
The graph tensor has name: gradient_tape/mul/Mul:0
解决方法
下面是针对您问题的快速修复:让 compute_update 返回梯度,再把梯度作为参数传给 apply_update,而不是通过 self.gradients 在两个 tf.function 之间传递。
class Example:
    """Fit ``x * y`` to ``target`` by gradient descent.

    The update can run fused in one ``tf.function`` (``iterate``) or split
    across two (``compute_update`` followed by ``apply_update``).
    """

    def __init__(self, x, y, target, lr=0.01):
        # Wrap the given values as float32 variables owned by this object.
        self.x = tf.Variable(x, dtype=tf.float32)
        self.y = tf.Variable(y, dtype=tf.float32)
        self.target = tf.Variable(target, dtype=tf.float32)
        # Plain Python float: a tf.function reads it when it traces, not on
        # every call.
        self.lr = lr
        self.variables = [self.x, self.y]

    @tf.function
    def iterate(self):
        """Compute the gradients and apply one update step in a single call."""
        with tf.GradientTape() as gt:
            residual = self.target - self.x * self.y
            loss = residual**2
        # Keep the gradients in a local: assigning them to an attribute on
        # self from inside a tf.function is risky.
        step_grads = gt.gradient(loss, self.variables)
        for g_i, var in zip(step_grads, self.variables):
            var.assign_sub(self.lr * g_i)

    @tf.function
    def compute_update(self):
        """Return the gradients of the squared error w.r.t. ``self.variables``.

        The gradients are returned rather than stored on ``self``, so the
        caller receives the function's outputs.
        """
        with tf.GradientTape() as gt:
            residual = self.target - self.x * self.y
            loss = residual**2
        return gt.gradient(loss, self.variables)

    @tf.function
    def apply_update(self, grad):
        """Apply one gradient-descent step using the gradients in ``grad``.

        ``grad`` is paired element-wise with ``self.variables``, e.g. the
        list returned by ``compute_update``.
        """
        for g_i, var in zip(grad, self.variables):
            var.assign_sub(self.lr * g_i)
example = Example(1, 3, 5)

# Fused step: gradients are computed and applied inside one tf.function.
example.iterate()
print(example.variables)

# Split step: compute_update returns the gradients; apply_update takes them
# as an argument instead of reading them from self.
gradients = example.compute_update()
example.apply_update(gradients)
print(example.variables)
这与 tf.function 的机制有关。当您把 self 作为对象传给由 tf.function 包装的函数时,self 下的每个属性(如 self.lr、self.variables 等)都会在追踪(trace)时被当作常量固化到计算图中。只有 tf.Variable 例外:通过 assign、assign_add 等方法对它的修改会在每次调用时生效。例如,如果您这样写:
# Variant of Example.iterate (shown on its own) that prints self.lr, to show
# that plain Python attributes are read once, when the function is traced.
@tf.function
def iterate(self):
    """Apply one update step and print the learning rate seen by the graph."""
    with tf.GradientTape() as tape:
        loss = (self.target - self.x * self.y)**2
    grads = tape.gradient(loss, self.variables)
    tf.print(self.lr)  # print the value of self.lr captured in the graph
    for g, v in zip(grads, self.variables):
        v.assign_add(-self.lr * g)

example.iterate()  # prints 0.01
example.lr = 0.03
example.iterate()  # prints 0.01 again, not 0.03: self.lr was baked in at trace time
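这就是为什么在 tf.function 内修改 self.gradients 是危险的:在函数内部,tape.gradient 返回的是属于该函数计算图的符号张量(Graph tensor)。把它存到 self 上,它就会泄漏到函数之外。随后 apply_update 在另一次调用中使用它时,就会触发上面的 TypeError(错误信息中的 gradient_tape/mul/Mul:0 正是这个张量)。更多信息请参阅:https://www.tensorflow.org/guide/function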