如何在pytorch中使用索引张量索引中间维度?

问题描述

如何索引具有n个维度的张量t和m个index张量,从而保留t的最后一个维度?对于尺寸m之前的所有尺寸,index张量的形状等于张量t。换句话说,我想索引张量的中间维度,同时保留选定索引的以下所有维度。

例如,假设我们有两个张量:

t = torch.randn([3,5,2]) * 10
index = torch.tensor([[1,3],[0,4],[3,2]]).long()

带有t:

tensor([[[ 15.2165,-7.9702],[  0.6646,5.2844],[-22.0657,-5.9876],[ -9.7319,11.7384],[  4.3985,-6.7058]],[[-15.6854,-11.9362],[ 11.3054,3.3068],[ -4.7756,-7.4524],[  5.0977,-17.3831],[  3.9152,-11.5047]],[[ -5.4265,-22.6456],[  1.6639,10.1483],[ 13.2129,3.7850],[  3.8543,-4.3496],[ -8.7577,-12.9722]]])

然后我想要的输出将具有(3,2,2)的形状,并且是:

tensor([[[  0.6646,11.7384]],[[  3.8543,3.7850]]])

一个示例是我有一个形状为t的张量(40,10,6,2)一个形状为(40,3)的索引张量。这应该查询张量t的维度3,并且预期的输出形状将为(40,3,2)

如何在不使用循环的情况下以通用方式实现这一目标?

解决方法

在这种情况下,您可以执行以下操作:

t[torch.arange(t.shape[0]).unsqueeze(1),index,...]

完整代码:

import torch

t = torch.tensor([[[ 15.2165,-7.9702],[  0.6646,5.2844],[-22.0657,-5.9876],[ -9.7319,11.7384],[  4.3985,-6.7058]],[[-15.6854,-11.9362],[ 11.3054,3.3068],[ -4.7756,-7.4524],[  5.0977,-17.3831],[  3.9152,-11.5047]],[[ -5.4265,-22.6456],[  1.6639,10.1483],[ 13.2129,3.7850],[  3.8543,-4.3496],[ -8.7577,-12.9722]]])

index = torch.tensor([[1,3],[0,4],[3,2]]).long()

output = t[torch.arange(t.shape[0]).unsqueeze(1),...]

# tensor([[[  0.6646,#          [ -9.7319,11.7384]],# 
#         [[-15.6854,#          [  3.9152,# 
#         [[  3.8543,#          [ 13.2129,3.7850]]])