为什么我们在Pytorch张量上调用.numpy之前先调用.detach?

问题描述

It has been firmly established that my_tensor.detach().numpy() is the correct way to get a numpy array from a torch tensor.

我试图更好地理解为什么。

accepted answer中,刚刚链接的问题中,Blupon指出:

除了实际值定义之外,您还需要将张量转换为另一个不需要梯度的张量。

在第一次讨论中,他链接到albanD州:

这是预期的行为,因为移至numpy会破坏图形,因此不会计算梯度。

如果您实际上不需要渐变,则可以显式.detach()需要grad的张量,以获得与不需要grad相同的张量。然后可以将另一个张量转换为numpy数组。

在他链接的第二次讨论中,apaszke写道:

变量不能转换为numpy,因为它们是围绕张量的包装器,用于保存操作历史记录,并且numpy没有此类对象。您可以使用.data属性来检索变量持有的张量。然后,这应该起作用:var.data.numpy()。

我已经研究了PyTorch自动分化库的内部工作原理,但我仍然对这些答案感到困惑。为什么它会破坏图形以移至numpy?是否因为在numpy数组上进行的任何操作都不会在autodiff图中进行跟踪?

什么是变量?它与张量有什么关系?

我觉得这里需要一个彻底的高质量Stack-Overflow答案,向尚未了解自动分化的PyTorch新用户解释原因。

特别是,我认为通过图形说明图形并显示此示例中断开连接的方式会有所帮助:

import torch

tensor1 = torch.tensor([1.0,2.0],requires_grad=True)

print(tensor1)
print(type(tensor1))

tensor1 = tensor1.numpy()

print(tensor1)
print(type(tensor1))

解决方法

我认为在这里最关键的要点是torch.tensornp.ndarray之间的差异
虽然两个对象都用于存储n维矩阵(又名"Tensors"),但是torch.tensors具有一个附加的“层”-用于存储导致​​相关n维矩阵的计算图。

因此,如果您只对在矩阵np.ndarraytorch.tensor上执行数学运算的高效简便方法感兴趣,可以互换使用。

但是,torch.tensor被设计为在gradient descent优化的上下文中使用,因此它们不仅拥有带有数值的张量,而且(更重要的是)拥有导致这些值。然后,使用此计算图(使用chain rule of derivatives)来计算损失函数的导数,而每个函数都是用来计算损失的。

如前所述,np.ndarray对象不具有此额外的“计算图”层,因此,在将torch.tensor转换为np.ndarray时,必须明确地使用detach()命令删除张量的计算图。


计算图
在您的comments中,这个概念似乎有点含糊。我将尝试用一个简单的例子来说明它。
考虑两个(向量)变量xw的简单函数:

x = torch.rand(4,requires_grad=True)
w = torch.rand(4,requires_grad=True)

y = x @ w  # inner-product of x and w
z = y ** 2  # square the inner product

如果我们只对z的值感兴趣,则不必担心任何图形,我们只需从输入x和{{1 }},先计算w,然后计算y

但是,如果我们不太在意z的值,而是想问一个问题“最小化的z是什么em> w给定的z”?
要回答这个问题,我们需要计算x w.r.t z导数
我们该怎么做?
使用chain rule,我们知道w。也就是说,要计算dz/dw = dz/dy * dy/dw相对z的梯度,我们需要将w的{​​{3}}返回到z,以计算 gradient 在我们跟踪backwardwz的步骤时的每一步操作。我们追溯的这个“路径”是w计算图,它告诉我们如何计算导致z的输入的z的导数:

z

我们现在可以检查z.backward() # ask pytorch to trace back the computation of z 的{​​{1}}的梯度:

z

请注意,这完全等于

w

w.grad # the resulting gradient of z w.r.t w tensor([0.8010,1.9746,1.5904,1.0408]) 2*y*x tensor([0.8010,1.0408],grad_fn=<MulBackward0>) 开始。

沿路径的每个张量存储其对计算的“贡献”:

dz/dy = 2*y

dy/dw = x

如您所见,z tensor(1.4061,grad_fn=<PowBackward0>) y tensor(1.1858,grad_fn=<DotBackward>) 不仅存储yz的“正向”值,还存储计算图-追溯从<x,w>(输出)到y**2(输入)的梯度时,计算导数(使用链式规则)所需的grad_fn

这些zw的重要组成部分,没有它们,就无法计算复杂函数的导数。但是,grad_fn个根本不具备此功能,并且也没有此信息。

请参阅back,以获取有关使用torch.tensors函数追溯导数的更多信息。


由于np.ndarraybackwrd()都有一个公用的“层”,用于存储n-d个数字数组,因此pytorch使用相同的存储器来节省内存:

this answer
以NumPy ndarray的形式返回np.ndarray张量。该张量和返回的ndarray 共享相同的基础存储空间。自张量的变化将反映在ndarray中,反之亦然。

另一个方向也以相同的方式起作用:

numpy() → numpy.ndarray
从numpy.ndarray创建张量。
返回的张量和ndarray 共享相同的内存。张量的修改将反映在ndarray中,反之亦然。

因此,从torch.tensor创建self时,反之亦然,两个对象 reference 都在内存中位于相同的基础存储中。由于np.array不存储/表示与数组关联的计算图,因此在共享numpy和割炬希望引用同一张量时,应使用torch.tensor明确删除此图


请注意,如果出于某种原因,如果您希望仅将pytorch用于数学运算而不进行反向传播,则可以使用torch.from_numpy(ndarray) → Tensor上下文管理器,在这种情况下,不会创建计算图,而{{1} }和np.ndarray可以互换使用。

detach()
,

我问,为什么它会破坏图形以移动到numpy?是因为在numpy数组上进行的任何操作都不会在autodiff图中被跟踪吗?

是的,新张量将不会通过#include<iostream> #include<string> using namespace std; int main(){ string str; cout << "Enter a string: "; getline(cin,str); int length = str.length(); string temp; int k = 0; for(int i = length-1; i>=0; i--){ temp[++k] = str[i]; } cout<<temp; return 0; 连接到旧张量,因此对新张量的任何操作都不会将梯度传回旧张量。

撰写grad_fn只是说:“我将基于numpy数组中该张量的值进行一些非跟踪计算。”

深入学习(d2l)教科书has a nice section describing the detach() method,尽管没有讨论为什么分离在转换为numpy数组之前有意义。


感谢jodag帮助回答了这个问题。如他所说,变量已过时,因此我们可以忽略该评论。

我认为到目前为止我能找到的最佳答案是在jodag's doc link中:

要停止张量跟踪历史记录,可以调用.detach()将其与计算历史记录分离,并防止跟踪将来的计算。

以及我在问题中引用的albanD言论:

如果您实际上不需要渐变,则可以显式.detach()需要grad的张量,以获得与不需要grad相同的张量。然后可以将另一个张量转换为numpy数组。

换句话说,my_tensor.detach().numpy()方法的意思是“我不要渐变”,并且不可能通过detach操作来跟踪渐变(毕竟,这就是PyTorch张量的用途! )

,

这是张量-> numpy数组连接的小展示:

#include <iostream>

int Main() {
  std::cout << "Hello World!";
  return 0;
}

输出:

import torch
tensor = torch.rand(2)
numpy_array = tensor.numpy()
print('Before edit:')
print(tensor)
print(numpy_array)

tensor[0] = 10

print()
print('After edit:')
print('Tensor:',tensor)
print('Numpy array:',numpy_array)

第一个元素的值由张量和numpy数组共享。将其张量更改为10时,也会在numpy数组中更改它。

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...