多维张量的前K个索引

问题描述

我有一个2D张量,我想获得前k个值的索引。我知道pytorch's topk函数。 pytorch的topk函数的问题是,它在某个维度上计算topk值。我想在两个维度上都获得topk值。

例如以下张量

a = torch.tensor([[4,9,7,4,0],[8,1,3,[9,8,8],[0,4]])

pytorch的topk函数将为我提供以下内容

values,indices = torch.topk(a,3)

print(indices)
# tensor([[1,2,#        [0,1],4],#        [1,3],4]])

但是我想得到以下内容

tensor([[0,[2,[3,1]])

这是2D张量中的9的索引。

有什么方法可以使用pytorch做到这一点?

解决方法

v,i = torch.topk(a.flatten(),3)
print (np.array(np.unravel_index(i.numpy(),a.shape)).T)

输出:

[[3 1]
 [2 0]
 [0 1]]
  1. 平整并找到前k个
  2. 使用unravel_index将一维索引转换为二维索引
,

您可以flatten原始张量,应用topk,然后使用以下类似的方法将结果标量索引转换回多维索引:

def descalarization(idx,shape):
    res = []
    N = np.prod(shape)
    for n in shape:
        N //= n
        res.append(idx // N)
        idx %= N
    return tuple(res)

示例:

torch.tensor([descalarization(k,a.size()) for k in torch.topk(a.flatten(),5).indices])
# Returns 
# tensor([[3,1],#         [2,0],#         [0,#         [3,4],4]])
,

您可以进行一些矢量运算来根据需要进行过滤。在这种情况下,请不要使用topk。

print(a)
tensor([[4,9,7,4,[8,1,3,[9,8,8],[0,4]])

values,indices = torch.max(a,1)   # get max values,indices
temp= torch.zeros_like(values)     # temporary
temp[values==9]=1                  # fill temp where values are 9 (wished value)
seq=torch.arange(values.shape[0])  # create a helper sequence
new_seq=seq[temp>0]                # filter sequence where values are 9
new_temp=indices[new_seq]          # filter indices with sequence where values are 9
final = torch.stack([new_seq,new_temp],dim=1)  # stack both to get result

print(final)
tensor([[0,[2,[3,1]])