问题描述
我一直在尝试将pytorch模型转换为coreML格式,但是目前不支持其中的一层replication_pad2d。因此,我试图使用注册运算符装饰器@register_torch_op
来实现它,以重新实现coremltools.converters的层,但是,我一直在努力理解输入类型以能够当前实现该功能。我知道了,这是从pytorch粗略翻译过来的实现,但是不起作用
from coremltools.converters.mil import Builder as mb
from coremltools.converters.mil import register_torch_op
from coremltools.converters.mil.frontend.torch.ops import _get_inputs
@register_torch_op
def replication_pad2d(context,node):
inputs = _get_inputs(context,node)
x = inputs[0]
a = len(x)
L_list,R_list = [],[]
U_list,D_list = [],[]
for i in range(a):#i:0,1
l = x[:,:,(a-i):(a-i+1)]
L_list.append(l)
r = x[:,(i-a-1):(i-a)]
R_list.append(r)
L_list.append(x)
x = mb.concat(L_list+R_list[::-1],axis=3,name=node.name)
for i in range(a):
u = x[:,(a-i):(a-i+1),:]
U_list.append(u)
d = x[:,(i-a-1):(i-a),:]
D_list.append(d)
U_list.append(x)
x = mb.concat(U_list+D_list[::-1],name=node.name)
context.add(x)
但出现以下错误
<ipython-input-12-cf14ed84cb93> in replication_pad2d(context,node)
59 inputs = _get_inputs(context,node)
60 x = inputs[0]
---> 61 a = len(x)
62 L_list,[]
63 U_list,[]
TypeError: object of type 'Var' has no len()
如果有人可以帮助我更好地理解这一点,特别是输入类型的节点和上下文,那将是很棒的
解决方法
我认为您可以将现有的填充层用作:
from coremltools.converters.mil import Builder as mb
from coremltools.converters.mil import register_torch_op
from coremltools.converters.mil.frontend.torch.ops import _get_inputs
@register_torch_op(torch_alias=["replication_pad2d"])
def HackedReplication_pad2d(context,node):
inputs = _get_inputs(context,node)
x = inputs[0]
pad = inputs[1].val
x_pad = mb.pad(x=x,pad=[pad[2],pad[3],pad[0],pad[1]],mode='replicate')
context.add(x_pad,node.name)
填充操作的文档不是那么好,因此填充参数的排序是一个猜谜游戏。