Problem description
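I ran into strange behavior when trying to evaluate derivatives of results obtained from sparse tensor operations. If I convert all sparse inputs to dense tensors before operating on them, the code below works as expected (first part of the code), but when I perform the same operations on the sparse tensors, it crashes with an InvalidArgumentError. I also get while_loop warnings, shown below. The real problem involves many more operations and more, larger tensors, and I essentially have to collect the entries of c in sparse mode. Can anyone make (more) sense of this behavior?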
import tensorflow as tf
import numpy as np
a=tf.SparseTensor(indices=[[0,0],[1,1]],values=np.array([1,1],dtype=np.float32),dense_shape=(2,2))
b=tf.SparseTensor(indices=[[0,0],[1,1]],values=np.array([-1,-1],dtype=np.float32),dense_shape=(2,2))
#dense mode...
f1=tf.Variable([1,1],dtype=np.float32)
with tf.GradientTape() as gtape:
    c=tf.sparse.to_dense(a)*f1[0]+tf.sparse.to_dense(b)*f1[1]
print(gtape.jacobian(c,f1)) #... works fine
#sparse mode...
f2=tf.Variable([1,1],dtype=np.float32)
with tf.GradientTape() as gtape:
    c=tf.sparse.add(a*f2[0],b*f2[1],0)
    c=tf.sparse.to_dense(c)
print(gtape.jacobian(c,f2)) #... InvalidArgumentError
#WARNING:tensorflow:Using a while_loop for converting SparseAddGrad
#WARNING:tensorflow:Using a while_loop for converting SparseTensorDenseAdd
#WARNING:tensorflow:Using a while_loop for converting SparseTensorDenseAdd
#---------------------------------------------------------------------------
#InvalidArgumentError Traceback (most recent call last)
#<ipython-input-10-d449761ef6b2> in <module>
# 12 c=tf.sparse.add(a*f2[0],b*f2[1],0)
# 13 c=tf.sparse.to_dense(c)
#---> 14 print(gtape.jacobian(c,f2)) #InvalidArgumentError
#c:\program files\python37\lib\site-packages\tensorflow\python\eager\backprop.py in jacobian(self,target,sources,unconnected_gradients,parallel_iterations,experimental_use_pfor)
# 1187 try:
# 1188 output = pfor_ops.pfor(loop_fn,target_size,
#-> 1189 parallel_iterations=parallel_iterations)
# 1190 except ValueError as err:
# 1191 six.reraise(
#c:\program files\python37\lib\site-packages\tensorflow\python\ops\parallel_for\control_flow_ops.py in pfor(loop_fn,iters,fallback_to_while_loop,parallel_iterations)
# 203 def_function.run_functions_eagerly(False)
# 204 f = def_function.function(f)
#--> 205 outputs = f()
# 206 if functions_run_eagerly is not None:
# 207 def_function.run_functions_eagerly(functions_run_eagerly)
#c:\program files\python37\lib\site-packages\tensorflow\python\eager\def_function.py in __call__(self,*args,**kwds)
# 826 tracing_count = self.experimental_get_tracing_count()
# 827 with trace.Trace(self._name) as tm:
#--> 828 result = self._call(*args,**kwds)
# 829 compiler = "xla" if self._experimental_compile else "nonXla"
# 830 new_tracing_count = self.experimental_get_tracing_count()
#c:\program files\python37\lib\site-packages\tensorflow\python\eager\def_function.py in _call(self,*args,**kwds)
# 893 # If we did not create any variables the trace we have is good enough.
# 894 return self._concrete_stateful_fn._call_flat(
#--> 895 filtered_flat_args,self._concrete_stateful_fn.captured_inputs) # pylint: disable=protected-access
# 896
# 897 def fn_with_cond(inner_args,inner_kwds,inner_filtered_flat_args):
#c:\program files\python37\lib\site-packages\tensorflow\python\eager\function.py in _call_flat(self,args,captured_inputs,cancellation_manager)
# 1917 # No tape is watching; skip to running the function.
# 1918 return self._build_call_outputs(self._inference_function.call(
#-> 1919 ctx,cancellation_manager=cancellation_manager))
# 1920 forward_backward = self._select_forward_and_backward_functions(
# 1921 args,
#c:\program files\python37\lib\site-packages\tensorflow\python\eager\function.py in call(self,ctx,cancellation_manager)
# 558 inputs=args,
# 559 attrs=attrs,
#--> 560 ctx=ctx)
# 561 else:
# 562 outputs = execute.execute_with_cancellation(
#c:\program files\python37\lib\site-packages\tensorflow\python\eager\execute.py in quick_execute(op_name,num_outputs,inputs,attrs,name)
# 58 ctx.ensure_initialized()
# 59 tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle,device_name,op_name,
#---> 60 inputs,num_outputs)
# 61 except core._NotOkStatusException as e:
# 62 if name is not None:
#InvalidArgumentError: Only tensors with ranks between 1 and 5 are currently supported. Tensor rank: 0
# [[{{node gradient_tape/SparseTensorDenseAdd_1/pfor/while/body/_56/gradient_tape/SparseTensorDenseAdd_1/pfor/while/SparseTensorDenseAdd}}]] [Op:__inference_f_6235]
#Function call stack:
#f
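As a reference point for either variant: c is linear in f, so the Jacobian returned by gtape.jacobian(c, f) should have shape c.shape + f.shape = (2, 2, 2), with the slice for f[0] equal to dense(a) and the slice for f[1] equal to dense(b). The NumPy-only sketch below checks this by central finite differences. It assumes b holds -1 on the diagonal, as its values suggest; the helper names c_of and numerical_jacobian are illustrative, not part of TensorFlow.

```python
import numpy as np

# Dense equivalents of the sparse inputs a and b from the question
# (assumption: b holds -1 at indices [0,0] and [1,1]).
A = np.array([[1.0, 0.0], [0.0, 1.0]])
B = np.array([[-1.0, 0.0], [0.0, -1.0]])


def c_of(f):
    """Hypothetical helper: c = dense(a) * f[0] + dense(b) * f[1]."""
    return A * f[0] + B * f[1]


def numerical_jacobian(fn, f, eps=1e-6):
    """Central-difference Jacobian of fn at a 1-D point f.

    The result has shape fn(f).shape + (len(f),), the same layout
    (target.shape + source.shape) that GradientTape.jacobian returns.
    """
    f = np.asarray(f, dtype=np.float64)
    jac = np.zeros(fn(f).shape + f.shape)
    for k in range(f.size):
        step = np.zeros_like(f)
        step[k] = eps
        # Last axis indexes the component of f being perturbed.
        jac[..., k] = (fn(f + step) - fn(f - step)) / (2 * eps)
    return jac


J = numerical_jacobian(c_of, [1.0, 1.0])
print(J.shape)    # (2, 2, 2)
print(J[..., 0])  # slice for f[0], approximately A
print(J[..., 1])  # slice for f[1], approximately B
```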
Solution
No working solution has been found yet.
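One possible workaround, offered as an untested sketch rather than a confirmed fix. The failing node is SparseTensorDenseAdd, which the forward code never calls, so it is presumably created by a gradient function. The only places where a sparse tensor meets a dense operand in the forward pass are a*f2[0] and b*f2[1], and those dense operands are scalars, which would be consistent with "Tensor rank: 0" in the message. The sketch therefore multiplies only the stored values (a plain rank-1 multiply) and rebuilds the SparseTensor, keeping tf.sparse.add and tf.sparse.to_dense as in the question. The helper scale_values is hypothetical, the diagnosis above is an assumption, and the while_loop warning for SparseAddGrad may still appear.

```python
import numpy as np
import tensorflow as tf

a = tf.SparseTensor(indices=[[0, 0], [1, 1]],
                    values=np.array([1, 1], dtype=np.float32),
                    dense_shape=(2, 2))
b = tf.SparseTensor(indices=[[0, 0], [1, 1]],
                    values=np.array([-1, -1], dtype=np.float32),
                    dense_shape=(2, 2))


def scale_values(sp, factor):
    """Hypothetical helper: multiply only the stored values of sp by factor.

    The multiply acts on the rank-1 values tensor, so no
    SparseTensor-times-dense-tensor op appears in the forward pass.
    """
    return tf.SparseTensor(indices=sp.indices,
                           values=sp.values * factor,
                           dense_shape=sp.dense_shape)


f2 = tf.Variable([1, 1], dtype=np.float32)
with tf.GradientTape() as gtape:
    c = tf.sparse.add(scale_values(a, f2[0]), scale_values(b, f2[1]), 0)
    c = tf.sparse.to_dense(c)

# Expected shape: c.shape + f2.shape == (2, 2, 2).
jac = gtape.jacobian(c, f2)
print(jac.numpy())
```

If this runs, jac[..., 0] should match tf.sparse.to_dense(a) and jac[..., 1] should match tf.sparse.to_dense(b), the same result as the dense-mode code.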