Tensorflow:如何在不包装的情况下使用tf.roll?

问题描述

我想独立移动二维张量的列或行,如:

a = tf.constant([[1,2,3],[4,5,6]])
shift = tf.constant([2,-1])
b = shift_fn(a,shift)

这给了我

b = [[0,1],[5,6,0]]

我发现tf.roll()可以做类似的事情,但是会包装元素。如何使用它填充零?

解决方法

一个不太好的解决方案是先使用tf.pad填充张量,然后在tf.roll内使用tf.map_fn独立移动每一行(或列)的填充张量。最后,您可以对结果进行适当的分割。例如:

a = tf.constant([[1,2,3],[4,5,6]])
shift = tf.constant([2,-1])

cols = a.shape[1]
paddings = tf.constant([[0,0],[cols,cols]])
padded_a = tf.pad(a,paddings)

tf.map_fn(
    lambda x: tf.roll(x[0],x[1],axis=-1),elems=(padded_a,shift),# The following argument is required here
    fn_output_signature=a.dtype
)[:,cols:cols*2]

"""
<tf.Tensor: shape=(2,3),dtype=int32,numpy=
array([[0,1],[5,6,0]],dtype=int32)>
"""

或者,为了帮助节省一些内存,可以在传递给fn的{​​{1}}函数内部完成填充和切片。

相关问答

依赖报错 idea导入项目后依赖报错,解决方案:https://blog....
错误1:代码生成器依赖和mybatis依赖冲突 启动项目时报错如下...
错误1:gradle项目控制台输出为乱码 # 解决方案:https://bl...
错误还原:在查询的过程中,传入的workType为0时,该条件不起...
报错如下,gcc版本太低 ^ server.c:5346:31: 错误:‘struct...