问题描述
如果您阅读了 jax source code,您会遇到一个叫做 xla_client
的东西。经常这样导入
from . import xla_client
这意味着 xla_client
是一个 python 模块,但我找不到任何具有该名称的文件或对该名称变量的引用。
我假设它与 https://pypi.org/project/jaxlib/ 相关,但这个包只是链接回 jax 源代码。
有人能帮我解答吗?
解决方法
您所指的文件存储在 https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/xla/python
让我进一步阐述:xla_client
部分是一个名为 xla_extension.so
的专门编译的 C++ 文件的包装器,例如参见
from . import xla_extension as _xla
以及在 _xla
中对 xla_config
的大量引用。这个文件的来源是https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/python/xla.cc,我们知道这是因为它在https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/python/BUILD
pybind_extension(
name = "xla_extension",srcs = [
"xla.cc",],...