问题描述
为什么torch::Tensor::is_same
未能通过以下断言?使用C ++ PyTorch API将张量写入文件,然后再次读取到另一个张量,is_same
比较两个张量:
torch::Tensor x_sequence = torch::linspace(0,M_PI,1000);
torch::save(x_sequence,"x_sequence.dat");
torch::Tensor x_read;
torch::load(x_read,"x_sequence.dat");
assert(x_read.is_same(x_sequence));
结果是:
int main(int,char**): Assertion `x_read.is_same(x_sequence)' Failed.
使用
- Arch Linux上的python-pytorch版本1.6.0-2
- g ++(GCC)10.1.0
解决方法
torch::Tensor::is_same(const torch::Tensor& other)
被定义为here。请务必注意,Tensor
实际上是基础TensorImpl
类(实际上保存数据)上的指针。
因此,当您调用is_same
时,实际上检查的是您的指针是否相同,即您的2个张量是否指向相同的基础内存。这是一个很好理解的简单示例:
auto x = torch::randn({4,4});
auto copy = x;
auto clone = x.clone();
std::cout << x.is_same(copy) << " " << x.is_same(clone) << std::endl;
>>> 0 1
在这里,对clone
的调用强制pytorch将数据复制到另一个存储位置。因此,指针是不同的,is_same
返回false。
如果要实际比较这些值,则别无选择,只能计算两个张量之间的差并计算该差接近0的程度。