TensorFlow GradientTape crashes with InvalidArgumentError when working with sparse tensors

Problem description

I ran into strange behaviour when trying to evaluate the derivative of a result obtained from sparse tensor operations. If I blow all sparse inputs up to dense tensors before operating on them, the code below works as expected (the first part of the code), but when I perform the same operations on the sparse tensors, it crashes with an InvalidArgumentError. In addition, I get while_loop warnings, as shown below. In the real problem there are of course more operations and more, larger tensors involved, and I essentially have to assemble the entries of c in sparse mode. Can anyone make (more) sense of this behaviour?

import tensorflow as tf
import numpy as np
a=tf.SparseTensor(indices=[[0,0],[1,1]],values=np.array([1,1],dtype=np.float32),dense_shape=(2,2))
b=tf.SparseTensor(indices=[[0,0],[1,1]],values=np.array([-1,-1],dtype=np.float32),dense_shape=(2,2))
#dense mode...
f1=tf.Variable([1,1],dtype=np.float32)
with tf.GradientTape() as gtape:
    c=tf.sparse.to_dense(a)*f1[0]+tf.sparse.to_dense(b)*f1[1]
print(gtape.jacobian(c,f1)) #... works fine
#sparse mode...
f2=tf.Variable([1,1],dtype=np.float32)
with tf.GradientTape() as gtape:
    c=tf.sparse.add(a*f2[0],b*f2[1],0)
    c=tf.sparse.to_dense(c)
print(gtape.jacobian(c,f2)) #... InvalidArgumentError

#WARNING:tensorflow:Using a while_loop for converting SparseAddGrad
#WARNING:tensorflow:Using a while_loop for converting SparseTensorDenseAdd
#WARNING:tensorflow:Using a while_loop for converting SparseTensorDenseAdd
#---------------------------------------------------------------------------
#InvalidArgumentError                      Traceback (most recent call last)
#<ipython-input-10-d449761ef6b2> in <module>
#     12     c=tf.sparse.add(a*f2[0],b*f2[1],0)
#     13     c=tf.sparse.to_dense(c)
#---> 14 print(gtape.jacobian(c,f2)) #InvalidArgumentError
#c:\program files\python37\lib\site-packages\tensorflow\python\eager\backprop.py in jacobian(self,target,sources,unconnected_gradients,parallel_iterations,experimental_use_pfor)
#   1187       try:
#   1188         output = pfor_ops.pfor(loop_fn,target_size,
#-> 1189                                parallel_iterations=parallel_iterations)
#   1190       except ValueError as err:
#   1191         six.reraise(
#c:\program files\python37\lib\site-packages\tensorflow\python\ops\parallel_for\control_flow_ops.py in pfor(loop_fn,iters,fallback_to_while_loop,parallel_iterations)
#    203       def_function.run_functions_eagerly(False)
#    204     f = def_function.function(f)
#--> 205   outputs = f()
#    206   if functions_run_eagerly is not None:
#    207     def_function.run_functions_eagerly(functions_run_eagerly)
#c:\program files\python37\lib\site-packages\tensorflow\python\eager\def_function.py in __call__(self,*args,**kwds)
#    826     tracing_count = self.experimental_get_tracing_count()
#    827     with trace.Trace(self._name) as tm:
#--> 828       result = self._call(*args,**kwds)
#    829       compiler = "xla" if self._experimental_compile else "nonXla"
#    830       new_tracing_count = self.experimental_get_tracing_count()
#c:\program files\python37\lib\site-packages\tensorflow\python\eager\def_function.py in _call(self,**kwds)
#    893       # If we did not create any variables the trace we have is good enough.
#    894       return self._concrete_stateful_fn._call_flat(
#--> 895           filtered_flat_args,self._concrete_stateful_fn.captured_inputs)  # pylint: disable=protected-access
#    896 
#    897     def fn_with_cond(inner_args,inner_kwds,inner_filtered_flat_args):
#c:\program files\python37\lib\site-packages\tensorflow\python\eager\function.py in _call_flat(self,args,captured_inputs,cancellation_manager)
#   1917       # No tape is watching; skip to running the function.
#   1918       return self._build_call_outputs(self._inference_function.call(
#-> 1919           ctx,cancellation_manager=cancellation_manager))
#   1920     forward_backward = self._select_forward_and_backward_functions(
#   1921         args,
#c:\program files\python37\lib\site-packages\tensorflow\python\eager\function.py in call(self,ctx,cancellation_manager)
#    558               inputs=args,
#    559               attrs=attrs,
#--> 560               ctx=ctx)
#    561         else:
#    562           outputs = execute.execute_with_cancellation(
#c:\program files\python37\lib\site-packages\tensorflow\python\eager\execute.py in quick_execute(op_name,num_outputs,inputs,attrs,name)
#     58     ctx.ensure_initialized()
#     59     tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle,device_name,op_name,
#---> 60                                         inputs,num_outputs)
#     61   except core._NotOkStatusException as e:
#     62     if name is not None:
#InvalidArgumentError:  Only tensors with ranks between 1 and 5 are currently supported.  Tensor rank: 0
#    [[{{node gradient_tape/SparseTensorDenseAdd_1/pfor/while/body/_56/gradient_tape/SparseTensorDenseAdd_1/pfor/while/SparseTensorDenseAdd}}]] [Op:__inference_f_6235]
#Function call stack:
#f

Workaround

No confirmed fix for this issue has been found yet.
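The traceback does offer a hint: the failure is raised inside pfor_ops.pfor, i.e. in the vectorized (pfor) path that GradientTape.jacobian uses by default, and the while_loop warnings show that several sparse ops have no pfor converter. Below is a minimal, untested sketch of one direction to explore, namely disabling that path through the documented experimental_use_pfor=False argument of jacobian; whether this actually avoids the InvalidArgumentError for these sparse gradient ops has not been verified here.

import tensorflow as tf
import numpy as np

# Same toy sparse tensors as in the question
a = tf.SparseTensor(indices=[[0, 0], [1, 1]], values=np.array([1, 1], dtype=np.float32), dense_shape=(2, 2))
b = tf.SparseTensor(indices=[[0, 0], [1, 1]], values=np.array([-1, -1], dtype=np.float32), dense_shape=(2, 2))

f2 = tf.Variable([1, 1], dtype=np.float32)
with tf.GradientTape() as gtape:
    c = tf.sparse.add(a * f2[0], b * f2[1])
    c = tf.sparse.to_dense(c)

# experimental_use_pfor=False makes jacobian fall back to a plain while_loop
# instead of the pfor vectorization that the traceback points at; this is an
# untested suggestion, not a confirmed fix.
print(gtape.jacobian(c, f2, experimental_use_pfor=False))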
