因为工作部署需要,经常需要将自己的torch模型转化为onnx格式进行部署,写了一个转化脚本
import torch
def model2onnx(model, arg, onnx_name, dynamic=False):
r'''
model:torch.module已经加载好的模型
arg:输入参数
onnx_name:输出onnx模型的名称
dynamic:是否动态输入
'''
model = model
args = arg
f = onnx_name
export_params = True
verbose = False
training = False
input_names = ['inputs']
output_names = ['outputs']
do_constant_folding = True
if dynamic == False:
dynamic_axes = None
else:
'''
实际应用的时候输入图片的尺寸是不固定的,而且可能一次输入多种图片一起处理。我们可以通过指定dynamic_axes参数来导出动态输入的模型。
dynamic_axes的参数是一个字典类型,字典的key就是输入或者输出的名字,对应key的value可以是一个字典或者列表,
指定了输入或者输出的index以及对应的名字。比如想要让输入的index为0的维度表示动态的batch_size那么就指定{0: 'batch_size'}。
同样的方法可以指定宽高所在的维度输出成动态的。
'''
dynamic_axes = {
'inputs': {0: 'batch_size', 2: 'in_width', 3: 'in_height'},
'outputs': {0: 'batch_size', 2: 'out_width', 3: 'out_height'}
}
print("start export onnx!")
try:
torch.onnx.export(model=model, args=args, f=f, export_params=export_params, verbose=verbose, training=training,
input_names=input_names, output_names=output_names, do_constant_folding=do_constant_folding,
dynamic_axes=dynamic_axes)
print("finish export onnx successfully!")
except Exception as e:
print("have some thing wrong in export onnx!")
print(e)