在 TensorFlow 中批量实现锯齿形展平 NxN 张量

问题描述

问题可以用zigzag scanning来描述。但是,我想知道是否有使用 TensorFlow 建议的 tf.tensor_scatter_nd_update 之类的实现的 TensorFlow 版本。

BxNxN 张量,其中 B 代表批次。

解决方法

我找到了一种使用 1x1 转换的解决方法。使用 numpy 生成一个常数置换卷积核(tf 不支持 Eager tensor assignment...),然后 在对其应用 tf.nn.conv2d 之前将张量(BxNxN)重塑为 Bx1x1x(NxN)。最后做一些reshape acrobat 把它弄平。