Onnxruntime 错误 9:未实现 LeakyRelu

问题描述

我已将一个简单的 pytorch 模型转换为 onnx 格式,但未能通过 onnxruntime 在单独的文件中加载和评估模型。它给出了错误消息:

NotImplemented: [ONNXRuntimeError] : 9 : NOT_IMPLEMENTED : Could not find an implementation for the node LeakyRelu_1:LeakyRelu(6)

但文档清楚地列出了leakyrelu 运算符:https://github.com/microsoft/onnxruntime/blob/master/docs/OperatorKernels.md

转换脚本超级简单。

import torch.onnx
import torch

torch.set_default_dtype(torch.float64)

Model = torch.load("Test.pth")
inputs = torch.randn(1,2)
torch.onnx.export(Model,inputs,"Test.onnx",export_params=True)

加载转换模型的脚本也很简单,

import onnx
import onnxruntime as rt
import numpy as np


onnx_model = onnx.load("Test.onnx")
onnx.checker.check_model(onnx_model)

#print(onnx_model.graph)
print("[Graph Input] name: {},shape: {}".format(onnx_model.graph.input[0].name,[dim.dim_value for dim in onnx_model.graph.input[0].type.tensor_type.shape.dim]))
print("[Graph Output] name: {},shape: {}".format(onnx_model.graph.output[0].name,[dim.dim_value for dim in onnx_model.graph.output[0].type.tensor_type.shape.dim]))


print(onnx.helper.printable_graph(onnx_model.graph))

sess = rt.InferenceSession("Test.onnx")

打印出来的图是

graph torch-jit-export (
  %input.1[DOUBLE,1x2]
) initializers (
  %0.weight[DOUBLE,228x2]
  %0.bias[DOUBLE,228]
  %2.weight[DOUBLE,70x228]
  %2.bias[DOUBLE,70]
  %4.weight[DOUBLE,1x70]
  %4.bias[DOUBLE,1]
) {
  %7 = Gemm[alpha = 1,beta = 1,transB = 1](%input.1,%0.weight,%0.bias)
  %8 = LeakyRelu[alpha = 0.00999999977648258](%7)
  %9 = Gemm[alpha = 1,transB = 1](%8,%2.weight,%2.bias)
  %10 = LeakyRelu[alpha = 0.00999999977648258](%9)
  %11 = Gemm[alpha = 1,transB = 1](%10,%4.weight,%4.bias)
  return %11
}

我想知道是否已经实现了leakyrelu,或者我刚刚在转换中遗漏了一些东西。感谢您的帮助!

解决方法

根据文档 https://github.com/microsoft/onnxruntime/blob/master/docs/OperatorKernels.md,LeakyRelu 仅针对类型 float(32 位)实现,而您拥有 double(64 位)。

您可以尝试在 PyTorch 代码中的 LeakyRely 之前转换为 32 位浮点数。并且可能会在 ONNX 运行时 Github 上创建一个问题,以添加对 LeakyRelu 的双重支持。

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...