PyTorch C ++ API中的`randperm`不应该返回默认类型为int的张量吗?

问题描述

当我尝试使用C ++ PyTorch API使用randperm生成置换整数索引的列表时,所得张量的元素类型为cpuFloatType{10}而不是整数类型:

int N_SAMPLES = 10;               
torch::Tensor shuffled_indices = torch::randperm(N_SAMPLES); 
cout << shuffled_indices << endl;

返回

 9
 3
 8
 6
 2
 5
 4
 7
 1
 0
[ cpuFloatType{10} ]

不能将其用于张量索引,因为元素类型是float而不是整数类型。当tryig使用my_tensor.index(shuffled_indices)

terminate called after throwing an instance of 'c10::IndexError'
  what():  tensors used as indices must be long,byte or bool tensors

环境:

  • Arch Linux上的python-pytorch版本1.6.0-2
  • g ++(GCC)10.1.0

为什么会这样?

解决方法

这是因为您使用割炬创建的任何张量的默认类型始终为float。否则,您必须使用TensorOptions参数struct来指定它:

int N_SAMPLES = 10;               
torch::Tensor shuffled_indices = torch::randperm(N_SAMPLES,torch::TensorOptions().dtype(at::kLong));
cout << shuffled_indices.dtype() << endl;
>>> long