问题描述
我想使用 BoolTensor 索引在 Pytorch 中对多维张量进行切片。我期望索引张量,索引为真的部分被保留,而索引为假的部分被切掉。
我的代码就像
import torch
a = torch.zeros((5,50,5,50))
tr_indices = torch.zeros((50),dtype=torch.bool)
tr_indices[1:50:2] = 1
val_indices = ~tr_indices
print(a[:,tr_indices].shape)
print(a[:,tr_indices,:,val_indices].shape)
我希望 a[:,val_indices]
的形状为 [5,25,25]
,但它返回 [25,5]
。结果是
torch.Size([5,50])
torch.Size([25,5])
我很困惑。谁能解释一下原因?
解决方法
PyTorch 继承了其高级索引行为 from Numpy。像这样切片两次应该可以达到您想要的输出:
a[:,tr_indices][...,val_indices]