问题描述
我想独立移动二维张量的列或行,如:
a = tf.constant([[1,2,3],[4,5,6]])
shift = tf.constant([2,-1])
b = shift_fn(a,shift)
这给了我
b = [[0,1],[5,6,0]]
我发现tf.roll()
可以做类似的事情,但是会包装元素。如何使用它填充零?
解决方法
一个不太好的解决方案是先使用tf.pad
填充张量,然后在tf.roll
内使用tf.map_fn
独立移动每一行(或列)的填充张量。最后,您可以对结果进行适当的分割。例如:
a = tf.constant([[1,2,3],[4,5,6]])
shift = tf.constant([2,-1])
cols = a.shape[1]
paddings = tf.constant([[0,0],[cols,cols]])
padded_a = tf.pad(a,paddings)
tf.map_fn(
lambda x: tf.roll(x[0],x[1],axis=-1),elems=(padded_a,shift),# The following argument is required here
fn_output_signature=a.dtype
)[:,cols:cols*2]
"""
<tf.Tensor: shape=(2,3),dtype=int32,numpy=
array([[0,1],[5,6,0]],dtype=int32)>
"""
或者,为了帮助节省一些内存,可以在传递给fn
的{{1}}函数内部完成填充和切片。