问题描述
当我尝试使用C ++ PyTorch API使用randperm
生成置换整数索引的列表时,所得张量的元素类型为cpuFloatType{10}
而不是整数类型:
int N_SAMPLES = 10;
torch::Tensor shuffled_indices = torch::randperm(N_SAMPLES);
cout << shuffled_indices << endl;
返回
9
3
8
6
2
5
4
7
1
0
[ cpuFloatType{10} ]
不能将其用于张量索引,因为元素类型是float而不是整数类型。当tryig使用my_tensor.index(shuffled_indices)
时
terminate called after throwing an instance of 'c10::IndexError'
what(): tensors used as indices must be long,byte or bool tensors
环境:
- Arch Linux上的python-pytorch版本1.6.0-2
- g ++(GCC)10.1.0
为什么会这样?
解决方法
这是因为您使用割炬创建的任何张量的默认类型始终为float
。否则,您必须使用TensorOptions
参数struct来指定它:
int N_SAMPLES = 10;
torch::Tensor shuffled_indices = torch::randperm(N_SAMPLES,torch::TensorOptions().dtype(at::kLong));
cout << shuffled_indices.dtype() << endl;
>>> long