与从文件读取的相同张量进行比较时,为什么C ++ PyTorch API中的“ is_same”会失败?

问题描述

为什么torch::Tensor::is_same未能通过以下断言?使用C ++ PyTorch API将张量写入文件,然后再次读取到另一个张量,is_same比较两个张量:

torch::Tensor x_sequence = torch::linspace(0,M_PI,1000);    
torch::save(x_sequence,"x_sequence.dat");
torch::Tensor x_read;
torch::load(x_read,"x_sequence.dat");
assert(x_read.is_same(x_sequence));  

结果是:

int main(int,char**): Assertion `x_read.is_same(x_sequence)' Failed.

使用

  • Arch Linux上的python-pytorch版本1.6.0-2
  • g ++(GCC)10.1.0

解决方法

torch::Tensor::is_same(const torch::Tensor& other)被定义为here。请务必注意,Tensor实际上是基础TensorImpl类(实际上保存数据)上的指针。

因此,当您调用is_same时,实际上检查的是您的指针是否相同,即您的2个张量是否指向相同的基础内存。这是一个很好理解的简单示例:

auto x = torch::randn({4,4});
auto copy = x;
auto clone = x.clone();
std::cout << x.is_same(copy) << " " << x.is_same(clone) << std::endl;
>>> 0 1

在这里,对clone的调用强制pytorch将数据复制到另一个存储位置。因此,指针是不同的,is_same返回false。

如果要实际比较这些值,则别无选择,只能计算两个张量之间的差并计算该差接近0的程度。