问题描述
我对如何快速恢复由排列打乱的数组感到困惑。
示例 1:
-
[x,y,z]
由P: [2,1]
洗牌,我们将获得[z,x,y]
- 对应的逆应该是
P^-1: [1,2,0]
示例 2:
-
[a,b,c,d,e,f]
被P: [5,1,4,3]
洗牌,然后我们会得到[f,a,d]
- 对应的逆应该是
P^-1: [2,3,5,0]
我基于矩阵乘法编写了以下代码(置换矩阵的转置是它的逆),但是当我在模型训练中使用这种方法时,它太慢了。有没有更快的实现?
import torch
n = 10
x = torch.Tensor(list(range(n)))
print('Original array',x)
random_perm_indices = torch.randperm(n).long()
perm_matrix = torch.eye(n)[random_perm_indices].t()
x = x[random_perm_indices]
print('Shuffled',x)
restore_indices = torch.Tensor(list(range(n))).view(n,1)
restore_indices = perm_matrix.mm(restore_indices).view(n).long()
x = x[restore_indices]
print('Restored',x)
解决方法
我在 PyTorch Forum 中得到了解决方案。
>>> import torch
>>> torch.__version__
'1.7.1'
>>> p1 = torch.tensor ([2,1])
>>> torch.argsort (p1)
tensor([1,2,0])
>>> p2 = torch.tensor ([5,1,4,3])
>>> torch.argsort (p2)
tensor([2,3,5,0])