Tensorflow:如何使用 Ragged Tensor 作为普通张量的索引?

问题描述

我有一个 2D RaggedTensor,它由我想要的全张量每一行的索引组成,例如:

[
    [0,4],[1,2,3],[5]
]

进入

[
    [200,305,400,20,105],[200,315,401,167],7,402,]

给予

[
    [200,20],[315,[105]
]

我怎样才能以最有效的方式实现这一点(最好只使用 tf 函数)?我相信像 gather_nd 这样的东西可以使用 RaggedTensors,但我不知道它是如何工作的。

解决方法

您可以将 tf.gatherbatch_dims 关键字参数一起使用:

>>> tf.gather(tensor,indices,batch_dims=1)
<tf.RaggedTensor [[200,20],[315,401,[105]]>