如何以torch_op的形式在coremltools转换器中实现Replication_pad2d层

问题描述

我一直在尝试将pytorch模型转换为coreML格式,但是目前不支持其中的一层replication_pad2d。因此,我试图使用注册运算符装饰器@register_torch_op来实现它,以重新实现coremltools.converters的层,但是,我一直在努力理解输入类型以能够当前实现该功能。我知道了,这是从pytorch粗略翻译过来的实现,但是不起作用

from coremltools.converters.mil import Builder as mb
from coremltools.converters.mil import register_torch_op
from coremltools.converters.mil.frontend.torch.ops import _get_inputs

@register_torch_op
def replication_pad2d(context,node):
  inputs = _get_inputs(context,node)
  x = inputs[0]
  a = len(x)
  L_list,R_list = [],[]
  U_list,D_list = [],[]
  for i in range(a):#i:0,1
    l = x[:,:,(a-i):(a-i+1)]
    L_list.append(l)
    r = x[:,(i-a-1):(i-a)]
    R_list.append(r)
  L_list.append(x)
  x = mb.concat(L_list+R_list[::-1],axis=3,name=node.name)
  for i in range(a):
    u = x[:,(a-i):(a-i+1),:]
    U_list.append(u)
    d = x[:,(i-a-1):(i-a),:]
    D_list.append(d)
  U_list.append(x)
  x = mb.concat(U_list+D_list[::-1],name=node.name)
  context.add(x)

但出现以下错误

<ipython-input-12-cf14ed84cb93> in replication_pad2d(context,node)
     59   inputs = _get_inputs(context,node)
     60   x = inputs[0]
---> 61   a = len(x)
     62   L_list,[]
     63   U_list,[]

TypeError: object of type 'Var' has no len()

如果有人可以帮助我更好地理解这一点,特别是输入类型的节点和上下文,那将是很棒的

解决方法

我认为您可以将现有的填充层用作:

from coremltools.converters.mil import Builder as mb
from coremltools.converters.mil import register_torch_op
from coremltools.converters.mil.frontend.torch.ops import _get_inputs

@register_torch_op(torch_alias=["replication_pad2d"])
def HackedReplication_pad2d(context,node):
    inputs = _get_inputs(context,node)
    x = inputs[0]
    pad = inputs[1].val
    x_pad = mb.pad(x=x,pad=[pad[2],pad[3],pad[0],pad[1]],mode='replicate')
    context.add(x_pad,node.name)

填充操作的文档不是那么好,因此填充参数的排序是一个猜谜游戏。

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...