How do I link all of PyTorch into a pybind11 .so?

Problem description

I have a pybind11 C++ project that uses the PyTorch C++ API:

#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include <math.h>
#include <torch/torch.h>

...

void my_function()
{
...
   torch::Tensor dynamic_parameters = torch::full({1},/*value=*/0.5,torch::dtype(torch::kFloat64).requires_grad(true));
   torch::optim::SGD optimizer({dynamic_parameters},/*lr=*/0.01);
...
}

PYBIND11_MODULE(reson8,m)
{
    m.def("my_function",&my_function,"");
}

I use distutils to compile it into a .so that can be imported from Python:

from distutils.core import setup,Extension

def configuration(parent_package='',top_path=None):
      import numpy
      from numpy.distutils.misc_util import Configuration
      from numpy.distutils.misc_util import get_info

      #Necessary for the half-float d-type.
      info = get_info('npymath')

      config = Configuration('',parent_package,top_path)
      config.add_extension('reson8',['reson8.cpp'],extra_info=info,include_dirs=["/home/ian/anaconda3/lib/python3.7/site-packages/pybind11/include","/home/ian/anaconda3/lib/python3.8/site-packages/pybind11/include","/home/ian/dev/hedgey/Engine/lib/libtorch/include","/home/ian/dev/hedgey/Engine/lib/libtorch/include/torch/csrc/api/include"])

      return config


if __name__ == "__main__":
      from numpy.distutils.core import setup
      setup(configuration=configuration)

It compiles without errors, but running "import reson8" in Python gives me this error:

ImportError: undefined symbol: _ZTVN5torch5optim9OptimizerE
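The mangled name in this error can be decoded by hand. Under the Itanium C++ ABI, the `_ZTV` prefix marks a vtable and `N...E` wraps a nested name built from length-prefixed parts, so the missing symbol is the vtable for torch::optim::Optimizer. That symbol is normally provided by the compiled libtorch libraries, not by the headers. The sketch below handles only this narrow subset of the mangling scheme. `demangle_simple` is a hypothetical helper written for illustration; a tool such as c++filt covers the full grammar.

```python
import re

# Only a few "special name" prefixes are handled; this is not a full demangler.
_PREFIXES = {
    "_ZTV": "vtable for ",
    "_ZTI": "typeinfo for ",
    "_ZTS": "typeinfo name for ",
}
_LENGTH = re.compile(r"\d+")


def demangle_simple(symbol):
    """Decode '<prefix>N<len><name>...E' into 'label a::b::c'."""
    for prefix, label in _PREFIXES.items():
        if symbol.startswith(prefix):
            rest = symbol[len(prefix):]
            break
    else:
        raise ValueError("unsupported prefix in %r" % symbol)

    if not (rest.startswith("N") and rest.endswith("E")):
        raise ValueError("expected a nested name N...E in %r" % symbol)

    body = rest[1:-1]
    parts = []
    pos = 0
    while pos < len(body):
        match = _LENGTH.match(body, pos)
        if match is None:
            raise ValueError("expected a length at offset %d" % pos)
        length = int(match.group())
        start = match.end()
        if start + length > len(body):
            raise ValueError("name runs past the end of the symbol")
        parts.append(body[start:start + length])
        pos = start + length
    return label + "::".join(parts)


print(demangle_simple("_ZTVN5torch5optim9OptimizerE"))
# vtable for torch::optim::Optimizer
```

An undefined vtable at import time usually means the library that defines the class was never listed as a dependency of the extension, so the dynamic loader has nowhere to find it.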

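I'm not sure whether PyTorch is being linked into my .so at all (although the .so is 10 MB, which seems quite large if it doesn't include PyTorch, but maybe all pybind11 .so files are large).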
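File size alone does not say whether PyTorch is linked. One way to check, sketched here under the assumption of a Linux system with `ldd` available, is to list the shared libraries the .so depends on and look for any libtorch entry. The helper names (`run_ldd`, `needed_libraries`, `links_against_torch`) are hypothetical, and the parsing assumes the usual `ldd` layout of one library per line with the name as the first token.

```python
import subprocess


def run_ldd(path):
    """Return the raw text printed by `ldd` for the given shared object."""
    result = subprocess.run(
        ["ldd", path], capture_output=True, text=True, check=True
    )
    return result.stdout


def needed_libraries(ldd_output):
    """Extract the first token of each non-empty line of `ldd` output."""
    names = []
    for line in ldd_output.splitlines():
        tokens = line.split()
        if tokens:
            names.append(tokens[0])
    return names


def links_against_torch(ldd_output):
    """True if any dependency's file name starts with 'libtorch'."""
    for name in needed_libraries(ldd_output):
        if name.rsplit("/", 1)[-1].startswith("libtorch"):
            return True
    return False


# Example call; the file name is a placeholder for the built extension:
# print(links_against_torch(run_ldd("reson8.so")))
```

If this reports False, the extension was built against the headers only, which matches an undefined-symbol error at import time.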

How can I fix this?

Solution

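I eventually figured out that I needed to use the Anaconda version of libtorch instead of my own copy, and to build the extension with Torch's CppExtension. Here is my working setup.py: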

from distutils.core import setup,Extension
from torch.utils.cpp_extension import BuildExtension,CppExtension

def configuration(parent_package='',top_path=None):
      import numpy
      from numpy.distutils.misc_util import Configuration
      from numpy.distutils.misc_util import get_info

      #Necessary for the half-float d-type.
      info = get_info('npymath')

      config = Configuration('',parent_package,top_path)

      config.ext_modules.append(CppExtension(
                name='reson8',sources=['reson8.cpp'],extra_info=info,extra_compile_args=['-g','-D_GLIBCXX_USE_CXX11_ABI=0'],extra_ldflags=['-ltorch_python'],include_dirs=["/home/ian/anaconda3/lib/python3.7/site-packages/pybind11/include","/home/ian/anaconda3/lib/python3.8/site-packages/pybind11/include","/home/ian/anaconda3/lib"
                                          ]
                                ))

      return config


if __name__ == "__main__":
      from numpy.distutils.core import setup
      setup(configuration=configuration)
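For comparison, here is a minimal sketch of the setuptools pattern described in the PyTorch C++ extension documentation. It is not the setup above: it assumes the npymath info and the hard-coded pybind11 include paths are not needed, which may not hold for this project. As far as I understand it, CppExtension adds the Torch include directories and links the Torch libraries from the installed torch package, and BuildExtension sets the matching C++ ABI flag, so those do not have to be passed by hand.

```python
# setup.py (sketch; assumes torch and setuptools are installed in the
# same environment that will import the extension)
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CppExtension

setup(
    name="reson8",
    ext_modules=[
        CppExtension(
            # Must match the name given to PYBIND11_MODULE in reson8.cpp.
            name="reson8",
            sources=["reson8.cpp"],
            extra_compile_args=["-g"],
        )
    ],
    # BuildExtension supplies the compiler and ABI settings Torch expects.
    cmdclass={"build_ext": BuildExtension},
)
```

Building with "python setup.py build_ext --inplace" should then produce the .so next to the source. Because the extension links against the shared Torch libraries, it may be necessary to "import torch" before "import reson8" so that those libraries are already loaded.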