在 Pytorch 中,如何使用 BoolTensor 掩码将张量切片跨越多个暗淡?

问题描述

我想使用 BoolTensor 索引在 Pytorch 中对多维张量进行切片。我期望索引张量,索引为真的部分被保留,而索引为假的部分被切掉。

我的代码就像

import torch
a = torch.zeros((5,50,5,50))

tr_indices = torch.zeros((50),dtype=torch.bool)
tr_indices[1:50:2] = 1
val_indices = ~tr_indices

print(a[:,tr_indices].shape)
print(a[:,tr_indices,:,val_indices].shape)

我希望 a[:,val_indices] 的形状为 [5,25,25],但它返回 [25,5]。结果是

torch.Size([5,50])
torch.Size([25,5])

我很困惑。谁能解释一下原因?

解决方法

PyTorch 继承了其高级索引行为 from Numpy。像这样切片两次应该可以达到您想要的输出:

a[:,tr_indices][...,val_indices]