问题描述
当我运行下面的代码时,它会显示错误。
ImportError: torch.utils.ffi is deprecated. Please use cpp extensions instead.
我一直在网上寻找解决方案。问题是下面的代码在旧版本的火炬(0.4.1)中工作。我想知道是否可以修改或替换此代码以在新版本的pytorch中工作。
from torch.utils.ffi import _wrap_function
from ._nms import lib as _lib,ffi as _ffi
__all__ = []
def _import_symbols(locals):
for symbol in dir(_lib):
fn = getattr(_lib,symbol)
if callable(fn):
locals[symbol] = _wrap_function(fn,_ffi)
else:
locals[symbol] = fn
__all__.append(symbol)
_import_symbols(locals())
解决方法
我正面临同样的问题,刚刚在以下位置看到了一些有用的信息:
- https://pytorch.org/tutorials/advanced/cpp_extension.html
- https://pytorch.org/docs/stable/cpp_extension.html
为避免 PyTorch 版本降级,您应该考虑使用以下库,同时在上述链接中查找更多详细信息:
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension,CppExtension