问题描述
我有一个使用 PyTorch C++ API 的 pybind11 C++ 项目:
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include <math.h>
#include <torch/torch.h>
...
// Example function from the question; the `...` lines stand for code that was
// omitted from the original post.
void f()
{
...
// One-element float64 tensor filled with 0.5 and marked as requiring grad.
torch::Tensor dynamic_parameters = torch::full({1},/*value=*/0.5,torch::dtype(torch::kFloat64).requires_grad(true));
// Plain SGD over that tensor with learning rate 0.01. Using torch::optim::SGD
// pulls in torch::optim::Optimizer, whose vtable (_ZTVN5torch5optim9OptimizerE,
// the undefined symbol reported at import time) lives in the libtorch shared
// libraries -- so the extension must be LINKED against libtorch, not merely
// compiled against its headers.
torch::optim::SGD optimizer({dynamic_parameters},/*lr=*/0.01);
...
}
// Defines the Python extension module `reson8`; the name must match the
// extension name given in setup.py ('reson8') so that `import reson8` finds it.
// Exposes the C++ function `my_function` to Python under the same name.
// NOTE(review): only `f` is visible in this snippet; `my_function` is
// presumably defined in the elided code -- confirm.
PYBIND11_MODULE(reson8,m)
{
m.def("my_function",&my_function,"");
}
我使用 numpy.distutils 将其编译为可以在 Python 中导入的 .so 文件:
from distutils.core import setup,Extension
def configuration(parent_package='',top_path=None):
import numpy
from numpy.distutils.misc_util import Configuration
from numpy.distutils.misc_util import get_info
#Necessary for the half-float d-type.
info = get_info('npymath')
config = Configuration('',parent_package,top_path)
config.add_extension('reson8',['reson8.cpp'],extra_info=info,include_dirs=["/home/ian/anaconda3/lib/python3.7/site-packages/pybind11/include","/home/ian/anaconda3/lib/python3.8/site-packages/pybind11/include","/home/ian/dev/hedgey/Engine/lib/libtorch/include","/home/ian/dev/hedgey/Engine/lib/libtorch/include/torch/csrc/api/include"])
return config
if __name__ == "__main__":
from numpy.distutils.core import setup
setup(configuration=configuration)
编译过程没有报错,但在 Python 中运行 `import reson8` 时出现以下错误:
ImportError: undefined symbol: _ZTVN5torch5optim9OptimizerE
我不确定是不是 PyTorch 没有被链接进我的 .so(这个 .so 有 10 MB,如果不包含 PyTorch,这个体积就相当大了;不过也许所有 pybind11 生成的 .so 文件都这么大。)
我该如何解决这个问题?
解决方法
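我最终发现,需要使用 Anaconda 安装的 PyTorch 自带的 libtorch,而不是我自己另外放置的那份 libtorch,并改用 PyTorch 提供的 CppExtension。原因在于:未定义的符号 `_ZTVN5torch5optim9OptimizerE` 经 demangle 后是 `vtable for torch::optim::Optimizer`,它定义在 libtorch 的库文件中。原来的 setup.py 只添加了头文件路径,没有链接任何 libtorch 库,所以编译能通过,导入时却找不到该符号。CppExtension 会自动加入 PyTorch 的头文件目录和库目录,并链接 c10、torch 等库。下面是可以正常工作的 setup.py: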
from distutils.core import setup,Extension
from torch.utils.cpp_extension import BuildExtension,CppExtension
def configuration(parent_package='',top_path=None):
import numpy
from numpy.distutils.misc_util import Configuration
from numpy.distutils.misc_util import get_info
#Necessary for the half-float d-type.
info = get_info('npymath')
config = Configuration('',parent_package,top_path)
config.ext_modules.append(CppExtension(
name='reson8',sources=['reson8.cpp'],extra_info=info,extra_compile_args=['-g','-D_GLIBCXX_USE_CXX11_ABI=0'],extra_ldflags=['-ltorch_python'],include_dirs=["/home/ian/anaconda3/lib/python3.7/site-packages/pybind11/include","/home/ian/anaconda3/lib/python3.8/site-packages/pybind11/include","/home/ian/anaconda3/lib"
]
))
return config
if __name__ == "__main__":
from numpy.distutils.core import setup
setup(configuration=configuration)