trax tl.Relu 和 tl.ShiftRight 层嵌套在 Serial Combinator 输出预期产出

问题描述

我正在尝试构建一个注意力模型,但认情况下 Relu 和 ShiftRight 层嵌套在串行组合器中。 这进一步使我在训练中出错。

layer_block = tl.Serial(
    tl.Relu(),tl.Layernorm(),)

x = np.array([[-2,-1,1,2],[-20,-10,10,20]]).astype(np.float32) 

layer_block.init(shapes.signature(x)) y = layer_block(x)

print(f'layer_block: {layer_block}')

输出

layer_block: Serial[
  Serial[
    Relu
  ]
  Layernorm
]

预期产出

layer_block: Serial[
  Relu
  Layernorm
]

tl.ShiftRight() 出现同样的问题

以上代码摘自官方文档Example 5

提前致谢

解决方法

我找不到上述问题的确切解决方案,但您可以使用 tl.Fn() 创建自定义函数并在其中添加 ReluShiftRight 函数代码。

def _zero_pad(x,pad,axis):
    """Helper for jnp.pad with 0s for single-axis case."""
    pad_widths = [(0,0)] * len(x.shape)
    pad_widths[axis] = pad  # Padding on axis.
    
    return jnp.pad(x,pad_widths,mode='constant')


def f(x):
    if mode == 'predict':
        return x
    padded = _zero_pad(x,(n_positions,0),1)
    return padded[:,:-n_positions]

# set ShiftRight parameters as global 
n_positions = 1
mode='train'

layer_block = tl.Serial(
    tl.Fn('Relu',lambda x: jnp.where(x <= 0,jnp.zeros_like(x),x)),tl.LayerNorm(),tl.Fn(f'ShiftRight({n_positions})',f)
)


x = np.array([[-2,-1,1,2],[-20,-10,10,20]]).astype(np.float32)
layer_block.init(shapes.signature(x))
y = layer_block(x)


print(f'layer_block: {layer_block}')

输出

layer_block: Serial[
  Relu
  LayerNorm
  ShiftRight(1)
]