问题描述
在 Detectron2 训练了一个模块后,我尝试将模型导出到 TorchScript, 然后我得到以下错误:
无法导出 Python 函数调用“_ScaleGradient”。在导出之前删除对 Python 函数的调用 >。您是否忘记添加 @script 或 @script_method 注释?如果这是一个 > nn.ModuleList,则将其添加到 __constants__
我发现代码在detectron2/modeling/roi_heads/cascade_rcnn.py
class _ScaleGradient(Function):
@staticmethod
def forward(ctx,input,scale):
ctx.scale = scale
return input
@staticmethod
def backward(ctx,grad_output):
return grad_output * ctx.scale,None
所以我将@statcmethod annos 更改为@torch.jit.script_method,之后,我收到了“'ScriptMethodStub' object is not callable”错误。
我不熟悉 torchscript,如何解决这个问题?
提前致谢。
解决方法
在推理阶段似乎不需要那个 _ScaleGradient 方法,所以我只是将以下代码添加到 cacasde_rcnn.py
if self.training:
#call _ScaleGradient.apply
else:
#don't call _ScaleGradient.apply