问题描述
我是火炬新手。
当我学习如下火炬教程时,
我很好奇tensor.item() 在哪里定义。
import torch
a = torch.tensor([1])
print( a.item() ) # it works without problem.
为了找到我不知道的东西,
首先,我使用了 VScode。但我明白了。
enter image description here
不是这样的
enter image description here
第二,我在the torch Github
搜索了“def item()”
但是,我找不到。 T^T
您能告诉我 tensor.item() 在 the Github 的何处定义吗?
或
Class_torch.tensor's Menber functions(Method function) 在哪里定义??
解决方法
简答:
torch/csrc/autograd/python_variable_indexing.cpp:268
长答案:
我希望你喜欢 C++。 ;)
首先要知道 item()
不是(通常)大多数 Python 类中的方法。相反,为了方便起见,Python 将对 item()
的调用转换为其他底层方法,例如 __getitem__()
。知道:
class Tensor
在 torch/tensor.py:40 处定义。
Torch 的大部分底层计算密集型功能都是用 C 和 C++ 实现的,包括 Tensor
。 “class Tensor”基于 Torch._C.TensorBase
,它通过来自 torch/csrc/autograd/python_variable.cpp:812
THPVariableType
是给 Python 的映射,描述了 Python 对象上可用的 C++ 函数。它在 torch/csrc/autograd/python_variable 处定义。与您相关的部分是 tp_as_mapping
条目(第 752 行),它为实现映射协议的对象提供函数——基本上是类 Pyton 数组的对象(Python Documentation)
第 725 行的 THPVariable_as_mapping
结构提供了映射方法。第二个变量提供了用于通过索引获取项目的下标函数(Python documentation)
因此,C++函数TPHVariable_getitem
提供了Torch.Tensor.item()
的实现,其定义在torch/csrc/autograd/python_variable_indexing