将detectron2模型导出到torchscript时出现“无法导出Python函数调用'_ScaleGradient'”

问题描述

在 Detectron2 训练了一个模块后,我尝试将模型导出到 TorchScript, 然后我得到以下错误

无法导出 Python 函数调用“_ScaleGradient”。在导出之前删除对 Python 函数调用 >。您是否忘记添加 @script 或 @script_method 注释?如果这是一个 > nn.ModuleList,则将其添加到 __constants__

我发现代码在detectron2/modeling/roi_heads/cascade_rcnn.py

class _ScaleGradient(Function):
    @staticmethod
    def forward(ctx,input,scale):
        ctx.scale = scale
        return input

    @staticmethod
    def backward(ctx,grad_output):
        return grad_output * ctx.scale,None

所以我将@statcmethod annos 更改为@torch.jit.script_method,之后,我收到了“'ScriptMethodStub' object is not callable”错误

我不熟悉 torchscript,如何解决这个问题?

提前致谢。

解决方法

在推理阶段似乎不需要那个 _ScaleGradient 方法,所以我只是将以下代码添加到 cacasde_rcnn.py

if self.training:
    #call _ScaleGradient.apply
else:
    #don't call _ScaleGradient.apply

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...