tf.tensor_scatter_nd_add是否可能

问题描述

以下使用tf.tensor_scatter_nd_add的简单示例给我带来了麻烦。

B = tf.tensor_scatter_nd_add(A,indices,updates)

张量A为(1,4,4)

A = [[[1. 1. 1. 1.],[1. 1. 1. 1.],[1. 1. 1. 1.]]]

所需结果是张量B:

B = [[[1. 1. 1. 1.],[1. 2. 3. 1.],[1. 4. 5. 1.],[1. 1. 1. 1.]]]

即我想将此较小的张量添加到张量A的4个内部元素中

updates = [[[1,2],[3,4]]]

Tensorflow 2.1.0。我尝试了多种构建索引的方法调用tensor_scatter_nd_add将返回错误,指出内部尺寸不匹配。

更新张量是否需要与A相同的形状?

解决方法

Planaria,

尝试传递索引并按以下方式更新:以形状(n)更新,以形状(n,3)更新索引,其中n是已更改项目的数量。 索引应指向您要更改的单个单元格:

A = tf.ones((1,4,),dtype=tf.dtypes.float32)
updates =  tf.constant([1.,2.,3.,4])
indices = tf.constant([[0,1,1],[0,2],2,2]])
tf.tensor_scatter_nd_add(A,indices,updates)

<tf.Tensor: shape=(1,4),dtype=float32,numpy=
array([[[1.,1.,1.],[1.,4.,5.,1.]]],dtype=float32)>