使用 C++ 解析 ONNX 模型使用 C++ 从 onnx 模型中提取层、输入和输出形状

问题描述

我正在尝试从 onnx 模型中提取数据,例如输入层、输出层及其形状。我知道有 python 接口可以做到这一点。我想做一些类似于这个 code 但在 C++ 中的事情。我还粘贴了链接中的代码。我在 python 中尝试过它,它对我有用。我想知道是否有 C++ API 来做同样的事情。

import onnx

model = onnx.load(r"model.onnx")

# The model is represented as a protobuf structure and it can be accessed
# using the standard python-for-protobuf methods

# iterate through inputs of the graph
for input in model.graph.input:
    print (input.name,end=": ")
    # get type of input tensor
    tensor_type = input.type.tensor_type
    # check if it has a shape:
    if (tensor_type.HasField("shape")):
        # iterate through dimensions of the shape:
        for d in tensor_type.shape.dim:
            # the dimension may have a definite (integer) value or a symbolic identifier or neither:
            if (d.HasField("dim_value")):
                print (d.dim_value,end=",")  # kNown dimension
            elif (d.HasField("dim_param")):
                print (d.dim_param,")  # unkNown dimension with symbolic name
            else:
                print ("?",")  # unkNown dimension with no name
    else:
        print ("unkNown rank",end="")
    print()

我也是 C++ 新手,请帮助我。

解决方法

ONNX 格式本质上是一个 protobuf,因此可以在任何 protoc 编译器支持的语言中打开。

如果是 C++

  1. 获取 onnx proto 文件 (onnx repo)
  2. 使用 protoc --cpp_out=. onnx.proto3 命令编译它。它将生成 onnx.proto3.pb.cconnx.proto3.pb.h 文件
  3. 链接 protobuf 库(可能是 protobuf-lite)、生成的 cpp 文件和以下代码:
#include <fstream>
#include <cassert>

#include "onnx.proto3.pb.h"

void print_dim(const ::onnx::TensorShapeProto_Dimension &dim)
{
  switch (dim.value_case())
  {
  case onnx::TensorShapeProto_Dimension::ValueCase::kDimParam:
    std::cout << dim.dim_param();
    break;
  case onnx::TensorShapeProto_Dimension::ValueCase::kDimValue:
    std::cout << dim.dim_value();
    break;
  default:
    assert(false && "should never happen");
  }
}

void print_io_info(const ::google::protobuf::RepeatedPtrField< ::onnx::ValueInfoProto > &info)
{
  for (auto input_data: info)
  {
    auto shape = input_data.type().tensor_type().shape();
    std::cout << "  " << input_data.name() << ":";
    std::cout << "[";
    if (shape.dim_size() != 0)
    {
      int size = shape.dim_size();
      for (int i = 0; i < size - 1; ++i)
      {
        print_dim(shape.dim(i));
        std::cout << ",";
      }
      print_dim(shape.dim(size - 1));
    }
    std::cout << "]\n";
  }
}

int main(int argc,char **argv)
{
  std::ifstream input("mobilenet.onnx",std::ios::ate | std::ios::binary); // open file and move current position in file to the end

  std::streamsize size = input.tellg(); // get current position in file
  input.seekg(0,std::ios::beg); // move to start of file

  std::vector<char> buffer(size);
  input.read(buffer.data(),size); // read raw data

  onnx::ModelProto model;
  model.ParseFromArray(buffer.data(),size); // parse protobuf

  auto graph = model.graph();

  std::cout << "graph inputs:\n";
  print_io_info(graph.input());

  std::cout << "graph outputs:\n";
  print_io_info(graph.output());
  return 0;
}

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...