通过pytorch中的视图将索引转换回

问题描述

我有以下代码,其中我定义了一个零的火炬张量,将一项更改为等于 1,通过三个重塑函数

然后,在所有的转换之后,我获得了 1 的索引。我想知道是否有可能以某种方式使用 max_idx 和有关 permutations/.view 的信息来获得初始 B1 张量中的 1 索引(应该等于 1234)。

A1 = np.zeros(10*18*40*28)
A1[1234] = 1
A1 = A1.reshape(10,18,40,28)
B1 = torch.Tensor(A1)
print('B1: ',B1.shape,torch.nonzero(B1))
C1 = B1.permute(0,2,3,1)
print('C1: ',C1.shape,torch.nonzero(C1))
D1 = C1.contiguous().view(C1.shape[0],C1.shape[1],C1.shape[2],6)
print('D1: ',D1.shape,torch.nonzero(D1))
E1 = D1.contiguous().view(D1.shape[0],-1,6)
print('E1: ',E1.shape,torch.nonzero(E1))

max_idx = torch.nonzero(E1)

我很想听听有关如何尝试这样做的任何提示:)

解决方法

对于每个维度,您可以检查您所引用的值位于哪个索引中。找到正确的索引后,减去下一维之前的索引数量,并为较小的子数组解决相同的问题。 或者,您可以仅使用执行完全相同操作的函数“numpy.unravel_index”。

import numpy as np
import torch


A1 = np.zeros(10*18*40*28)
idx = 1234
A1[idx] = 1
A1 = A1.reshape(10,18,40,28)
B1 = torch.Tensor(A1)
print('B1: ',B1.shape,torch.nonzero(B1))

idx_temp = idx+0
idxB1 = np.zeros((B1.dim(),),dtype = int)
for i in range(B1.dim()):
    idxB1[i]  = idx_temp//np.prod(B1.shape[i+1:])
    idx_temp -= np.prod(B1.shape[i+1:])*idxB1[i]

idxB1np = np.unravel_index(idx,B1.shape)

print(f'idxB1 = {idxB1}')
print(f'idxB1np = {idxB1np}')

输出:

B1:  torch.Size([10,28]) tensor([[0,1,4,2]])
idxB1 = [0 1 4 2]
idxB1np = (0,2)