问题描述
我有一个 2D RaggedTensor,它由我想要的全张量每一行的索引组成,例如:
[
[0,4],[1,2,3],[5]
]
进入
[
[200,305,400,20,105],[200,315,401,167],7,402,]
给予
[
[200,20],[315,[105]
]
我怎样才能以最有效的方式实现这一点(最好只使用 tf
函数)?我相信像 gather_nd
这样的东西可以使用 RaggedTensors,但我不知道它是如何工作的。
解决方法
您可以将 tf.gather
与 batch_dims
关键字参数一起使用:
>>> tf.gather(tensor,indices,batch_dims=1)
<tf.RaggedTensor [[200,20],[315,401,[105]]>