问题描述
如何索引具有n个维度的张量t
和m个t
的最后一个维度?对于尺寸m之前的所有尺寸,index
张量的形状等于张量t
。换句话说,我想索引张量的中间维度,同时保留选定索引的以下所有维度。
例如,假设我们有两个张量:
t = torch.randn([3,5,2]) * 10
index = torch.tensor([[1,3],[0,4],[3,2]]).long()
带有t:
tensor([[[ 15.2165,-7.9702],[ 0.6646,5.2844],[-22.0657,-5.9876],[ -9.7319,11.7384],[ 4.3985,-6.7058]],[[-15.6854,-11.9362],[ 11.3054,3.3068],[ -4.7756,-7.4524],[ 5.0977,-17.3831],[ 3.9152,-11.5047]],[[ -5.4265,-22.6456],[ 1.6639,10.1483],[ 13.2129,3.7850],[ 3.8543,-4.3496],[ -8.7577,-12.9722]]])
然后我想要的输出将具有(3,2,2)
的形状,并且是:
tensor([[[ 0.6646,11.7384]],[[ 3.8543,3.7850]]])
另一个示例是我有一个形状为t
的张量(40,10,6,2)
和一个形状为(40,3)
的索引张量。这应该查询张量t
的维度3,并且预期的输出形状将为(40,3,2)
。
如何在不使用循环的情况下以通用方式实现这一目标?
解决方法
在这种情况下,您可以执行以下操作:
t[torch.arange(t.shape[0]).unsqueeze(1),index,...]
完整代码:
import torch
t = torch.tensor([[[ 15.2165,-7.9702],[ 0.6646,5.2844],[-22.0657,-5.9876],[ -9.7319,11.7384],[ 4.3985,-6.7058]],[[-15.6854,-11.9362],[ 11.3054,3.3068],[ -4.7756,-7.4524],[ 5.0977,-17.3831],[ 3.9152,-11.5047]],[[ -5.4265,-22.6456],[ 1.6639,10.1483],[ 13.2129,3.7850],[ 3.8543,-4.3496],[ -8.7577,-12.9722]]])
index = torch.tensor([[1,3],[0,4],[3,2]]).long()
output = t[torch.arange(t.shape[0]).unsqueeze(1),...]
# tensor([[[ 0.6646,# [ -9.7319,11.7384]],#
# [[-15.6854,# [ 3.9152,#
# [[ 3.8543,# [ 13.2129,3.7850]]])