向量化Pytorch张量索引操作

问题描述

我正在尝试对PyTorch中的操作进行矢量化处理,但是我不确定该怎么做。这是现在使用for循环的代码。 'm'是具有int键和1d张量作为值的字典。 输出掩码为2d。 L是层数,此循环可能是必需的。因此,我希望主要替换2个内部循环。我当时想以某种方式使用torch.gather,但没有成功

for l in range(L):
    mask = torch.zeros((m[l].shape[0],m[l-1].shape[0]))
    for i in range(m[l].shape[0]):
        for j in range(m[l-1].shape[0]):
            mask[i,j] = R[m[l-1][j],m[l][i]]
    masks.append(mask)

我将不胜感激!预先感谢。

解决方法

我想我自己找到了答案。如此处所述,您可以使用numpy高级索引:https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.indexing.html

然后将其归结为以下内容

for l in range(L):
    mask = R[m[l-1],m[l][:,np.newaxis]]
    masks.append(mask)

np.newaxis确保对每一行重复列索引。

相关问答

错误1:Request method ‘DELETE‘ not supported 错误还原:...
错误1:启动docker镜像时报错:Error response from daemon:...
错误1:private field ‘xxx‘ is never assigned 按Alt...
报错如下,通过源不能下载,最后警告pip需升级版本 Requirem...