将torch转换为onnx的脚本

因为工作部署需要,经常需要将自己的torch模型转化为onnx格式进行部署,写了一个转化脚本

import torch


def model2onnx(model, arg, onnx_name, dynamic=False):
    r'''
    model:torch.module已经加载好的模型
    arg:输入参数
    onnx_name:输出onnx模型的名称
    dynamic:是否动态输入
    '''

    model = model
    args = arg  # 正常运行网络时输入的tensor或者image
    f = onnx_name  # 输出onnx的文件
    export_params = True  # 如果指定,所有参数都会被导出。如果你只想导出一个未训练的模型,就将此参数设置为False。在这种情况下,
    # 导出的模型将首先把所有parameters作为参arguments,顺序由model.state_dict().values()指定

    verbose = False  # 如果指定为True,将会输出被导出的轨迹的调试描述
    training = False  # 导出训练模型下的模型。目前,ONNX只面向推断模型的导出,所以一般不需要将该项设置为True
    input_names = ['inputs']  # 按顺序分配名称到图中的输入节点,如果不设置认为随机分配序号
    output_names = ['outputs']  # 按顺序分配名称到图中的输出节点,如果不设置认为随机分配序号
    do_constant_folding = True  # 设置为True表示开始常数折叠优化,推荐为True
    if dynamic == False:
        dynamic_axes = None  # 只对于dynamic shape的模型而言,输入是词典
    else:
        '''
        实际应用的时候输入图片的尺寸是不固定的,而且可能一次输入多种图片一起处理。我们可以通过指定dynamic_axes参数来导出动态输入的模型。
        dynamic_axes的参数是一个字典类型,字典的key就是输入或者输出的名字,对应key的value可以是一个字典或者列表,
        指定了输入或者输出的index以及对应的名字。比如想要让输入的index为0的维度表示动态的batch_size那么就指定{0: 'batch_size'}。
        同样的方法可以指定宽高所在的维度输出成动态的。
        '''
        dynamic_axes = {
            'inputs': {0: 'batch_size', 2: 'in_width', 3: 'in_height'},
            'outputs': {0: 'batch_size', 2: 'out_width', 3: 'out_height'}
        }

    print("start export onnx!")
    try:
        torch.onnx.export(model=model, args=args, f=f, export_params=export_params, verbose=verbose, training=training,
                          input_names=input_names, output_names=output_names, do_constant_folding=do_constant_folding,
                          dynamic_axes=dynamic_axes)
        print("finish export onnx successfully!")
    except Exception as e:
        print("have some thing wrong in export onnx!")
        print(e)

相关文章

显卡天梯图2024最新版,显卡是电脑进行图形处理的重要设备,...
初始化电脑时出现问题怎么办,可以使用win系统的安装介质,连...
todesk远程开机怎么设置,两台电脑要在同一局域网内,然后需...
油猴谷歌插件怎么安装,可以通过谷歌应用商店进行安装,需要...
虚拟内存这个名词想必很多人都听说过,我们在使用电脑的时候...
win11本地账户怎么改名?win11很多操作都变了样,用户如果想要...