如何使用 PyTorch 快速反转排列?

问题描述

我对如何快速恢复由排列打乱的数组感到困惑。

示例 1:

  • [x,y,z]P: [2,1] 洗牌,我们将获得 [z,x,y]
  • 对应的逆应该是P^-1: [1,2,0]

示例 2:

  • [a,b,c,d,e,f]P: [5,1,4,3] 洗牌,然后我们会得到 [f,a,d]
  • 对应的逆应该是P^-1: [2,3,5,0]

我基于矩阵乘法编写了以下代码(置换矩阵的转置是它的逆),但是当我在模型训练中使用这种方法时,它太慢了。有没有更快的实现?

import torch

n = 10
x = torch.Tensor(list(range(n)))
print('Original array',x)

random_perm_indices = torch.randperm(n).long()
perm_matrix = torch.eye(n)[random_perm_indices].t()
x = x[random_perm_indices]
print('Shuffled',x)

restore_indices = torch.Tensor(list(range(n))).view(n,1)
restore_indices = perm_matrix.mm(restore_indices).view(n).long()
x = x[restore_indices]
print('Restored',x)

解决方法

我在 PyTorch Forum 中得到了解决方案。

>>> import torch
>>> torch.__version__
'1.7.1'
>>> p1 = torch.tensor ([2,1])
>>> torch.argsort (p1)
tensor([1,2,0])
>>> p2 = torch.tensor ([5,1,4,3])
>>> torch.argsort (p2)
tensor([2,3,5,0])