Re-implementing a BERT-style pooler throws a shape error, as if the length dimension were still required

Problem description

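I trained an off-the-shelf Transformer().

Now I want to use its encoder to build a classifier. For this, I only want to use the output of the first token (the BERT-style CLS-token result) and run it through a dense layer.

What I'm doing: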

tl.Serial(encoder, tl.Fn('pooler', lambda x: x[:, 0, :]), tl.Dense(7))

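Shapes: the encoder gives me shape (64, 50, 512), where 64 = batch_size, 50 = seq_len, 512 = model_dim.

The pooler gives me shape (64, 512), as expected.

The dense layer should take the 512 features of each batch member and classify them into 7 classes. But I suspect trax/jax still expects the input to have length seq_len (50).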
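To make the expected shapes concrete, here is a minimal sketch that uses NumPy as a stand-in for jax.numpy (both follow the same indexing and `dot` rules). The random arrays are placeholders for the real encoder output and Dense weights, and the weight layout (input_dim, n_units) is assumed to match the `jnp.dot(x, w) + b` affine map that appears in the traceback. Slicing with `[:, 0, :]` keeps only the first token, so a 512 -> 7 affine map applies cleanly:

```python
import numpy as np

batch_size, seq_len, model_dim, n_classes = 64, 50, 512, 7

# Placeholder for the encoder output: (batch_size, seq_len, model_dim).
encoder_out = np.random.rand(batch_size, seq_len, model_dim).astype(np.float32)

# BERT-style pooling: keep only the activations of the first token.
pooled = encoder_out[:, 0, :]

# A Dense(7)-style affine map: weights (model_dim, n_classes), bias (n_classes,).
w = np.random.rand(model_dim, n_classes).astype(np.float32)
b = np.zeros(n_classes, dtype=np.float32)
logits = np.dot(pooled, w) + b

print(pooled.shape)  # (64, 512)
print(logits.shape)  # (64, 7)
```

If the pooled shape is (64, 512) and the weights are sized for 512 inputs, the product itself is well-formed; a mismatch has to come from somewhere else.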

TypeError: dot_general requires contracting dimensions to have the same shape, got [512] and [50].
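This message says the two operands of the matrix product disagree on the axis being summed over. For 2-D inputs, `dot(x, w)` contracts the last axis of `x` (512 here) with the first axis of `w`, so the `[50]` on the right-hand side is the contracted axis of the weight matrix. A small NumPy sketch of the same mismatch (NumPy raises a `ValueError` where JAX raises a `TypeError`; the weight shapes below are illustrative assumptions, and `contracting_dims` is a hypothetical helper):

```python
import numpy as np

x = np.ones((64, 512), dtype=np.float32)    # pooled features, as in the question
w_ok = np.ones((512, 7), dtype=np.float32)  # weights sized for a 512-wide input
w_bad = np.ones((50, 7), dtype=np.float32)  # weights sized for a 50-wide input


def contracting_dims(a, b):
    """Return the two axis sizes np.dot contracts for 2-D a and b."""
    return a.shape[-1], b.shape[0]


print(contracting_dims(x, w_ok))   # (512, 512): compatible
print(np.dot(x, w_ok).shape)       # (64, 7)

print(contracting_dims(x, w_bad))  # (512, 50): incompatible
try:
    np.dot(x, w_bad)
    mismatch_raised = False
except ValueError as err:
    mismatch_raised = True
    print("ValueError:", err)
```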

What am I missing?

Full traceback:

Traceback (most recent call last):
  File "mikado_classes.py", line 2054, in <module>
    app.run(main)
  File "/root/.local/lib/python3.7/site-packages/absl/app.py", line 300, in run
    _run_main(main, args)
  File "/root/.local/lib/python3.7/site-packages/absl/app.py", line 251, in _run_main
    sys.exit(main(argv))
  File "mikado_classes.py", line 1153, in main
    loop_neu.run(2)
  File "/root/.local/lib/python3.7/site-packages/trax/supervised/training.py", line 361, in run
    loss, optimizer_metrics = self._run_one_step(task_index, task_changed)
  File "/root/.local/lib/python3.7/site-packages/trax/supervised/training.py", line 483, in _run_one_step
    batch, rng, step=step, learning_rate=learning_rate
  File "/root/.local/lib/python3.7/site-packages/trax/optimizers/trainer.py", line 134, in one_step
    (weights, self._slots), step, self._opt_params, batch, state, rng)
  File "/root/.local/lib/python3.7/site-packages/trax/optimizers/trainer.py", line 173, in single_device_update_fn
    batch, weights, rng)
  File "/root/.local/lib/python3.7/site-packages/trax/layers/base.py", line 549, in pure_fn
    self._caller, signature(x), trace) from None
jax._src.traceback_util.FilteredStackTrace: trax.layers.base.LayerError: Exception passing through layer Serial (in pure_fn):
  layer created in file [...]/trax/supervised/training.py, line 865
  layer input shapes: (ShapeDtype{shape:(64, 50), dtype:int32}, ShapeDtype{shape:(64, 1), dtype:int32})

  File [...]/trax/layers/combinators.py, line 88, in forward
    outputs, s = layer.pure_fn(inputs, w, s, use_cache=True)

LayerError: Exception passing through layer Serial (in pure_fn):
  layer created in file [...]/mikado_classes.py, line 1134
  layer input shapes: (ShapeDtype{shape:(64, use_cache=True)

LayerError: Exception passing through layer Dense_7 (in pure_fn):
  layer created in file [...]/mikado_classes.py, line 1133
  layer input shapes: ShapeDtype{shape:(64, 512), dtype:float32}

  File [...]/trax/layers/assert_shape.py, line 122, in forward_wrapper
    y = forward(self, x, *args, **kwargs)

  File [...]/trax/layers/core.py, line 95, in forward
    return jnp.dot(x, w) + b  # Affine map.

  File [...]/_src/numpy/lax_numpy.py, line 3498, in dot
    return lax.dot_general(a, b, (contract_dims, batch_dims), precision)

  File [...]/_src/lax/lax.py, line 674, in dot_general
    preferred_element_type=preferred_element_type)

  File [...]/site-packages/jax/core.py, line 282, in bind
    out = top_trace.process_primitive(self, tracers, params)

  File [...]/jax/interpreters/ad.py, line 285, in process_primitive
    primal_out, tangent_out = jvp(primals_in, tangents_in, **params)

  File [...]/jax/interpreters/ad.py, line 458, in standard_jvp
    val_out = primitive.bind(*primals, **params)

  File [...]/site-packages/jax/core.py, params)

  File [...]/jax/interpreters/partial_eval.py, line 140, in process_primitive
    return self.default_process_primitive(primitive, line 147, in default_process_primitive
    return primitive.bind(*consts, line 1058, in process_primitive
    out_avals = primitive.abstract_eval(*avals, **params)

  File [...]/_src/lax/lax.py, line 1992, in standard_abstract_eval
    shapes, dtypes = shape_rule(*args, **kwargs), dtype_rule(*args, **kwargs)

  File [...]/_src/lax/lax.py, line 3090, in _dot_general_shape_rule
    raise TypeError(msg.format(lhs_contracting_shape, rhs_contracting_shape))

TypeError: dot_general requires contracting dimensions to have the same shape, got [512] and [50].

The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "mikado_classes.py", rng)
  File "/root/.local/lib/python3.7/site-packages/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/root/.local/lib/python3.7/site-packages/jax/api.py", line 398, in f_jitted
    return cpp_jitted_f(context, line 295, in cache_miss
    donated_invars=donated_invars)
  File "/root/.local/lib/python3.7/site-packages/jax/core.py", line 1275, in bind
    return call_bind(self, fun, **params)
  File "/root/.local/lib/python3.7/site-packages/jax/core.py", line 1266, in call_bind
    outs = primitive.process(top_trace, params)
  File "/root/.local/lib/python3.7/site-packages/jax/core.py", line 1278, in process
    return trace.process_call(self, line 631, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/root/.local/lib/python3.7/site-packages/jax/interpreters/xla.py", line 581, in _xla_call_impl
    *unsafe_map(arg_spec, args))
  File "/root/.local/lib/python3.7/site-packages/jax/linear_util.py", line 260, in memoized_fun
    ans = call(fun, *args)
  File "/root/.local/lib/python3.7/site-packages/jax/interpreters/xla.py", line 656, in _xla_callable
    jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, abstract_args)
  File "/root/.local/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 1216, in trace_to_jaxpr_final
    jaxpr, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
  File "/root/.local/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 1196, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
  File "/root/.local/lib/python3.7/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/root/.local/lib/python3.7/site-packages/trax/optimizers/trainer.py", line 810, in value_and_grad_f
    ans, vjp_py, aux = _vjp(f_partial, *dyn_args, has_aux=True)
  File "/root/.local/lib/python3.7/site-packages/jax/api.py", line 1918, in _vjp
    out_primal, out_vjp, aux = ad.vjp(flat_fun, primals_flat, has_aux=True)
  File "/root/.local/lib/python3.7/site-packages/jax/interpreters/ad.py", line 116, in vjp
    out_primals, pvals, jaxpr, consts, aux = linearize(traceable, *primals, line 101, in linearize
    jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
  File "/root/.local/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 506, in trace_to_jaxpr
    jaxpr, (out_pvals, env) = fun.call_wrapped(pvals)
  File "/root/.local/lib/python3.7/site-packages/jax/linear_util.py", **kwargs))
  File "/root/.local/lib/python3.7/site-packages/trax/layers/base.py", trace) from None
trax.layers.base.LayerError: Exception passing through layer Serial (in pure_fn):
  layer created in file [...]/trax/supervised/training.py, got [512] and [50].

Solution

The error was not in the architecture. The problem was that my inputs had the wrong shape.

The targets should have shape (batch_size,), but I was sending (batch_size, 1). So the target array should be, for example:

[1, 5, 99, 2, 1, 3, 8]

but I produced

[[1], [5], [99], [2], [1], [3], [8]]
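A minimal sketch of this fix on the data side, assuming the training stream yields `(inputs, targets)` or `(inputs, targets, weights)` tuples of NumPy arrays. `flatten_targets` is a hypothetical helper, not a Trax API; it drops the trailing length-1 axis so the targets go from (batch_size, 1) to (batch_size,):

```python
import numpy as np


def flatten_targets(stream):
    """Yield batches with targets reshaped from (batch_size, 1) to (batch_size,).

    Hypothetical helper: assumes the second element of each batch tuple holds
    the class labels; all other elements are passed through unchanged.
    """
    for batch in stream:
        inputs, targets, *rest = batch
        targets = np.asarray(targets)
        if targets.ndim == 2 and targets.shape[1] == 1:
            targets = targets.reshape(-1)  # drop the trailing length-1 axis
        yield (inputs, targets, *rest)


# Toy batch of 7 examples; seq_len 50 is a placeholder.
inputs = np.zeros((7, 50), dtype=np.int32)
bad_targets = np.array([[1], [5], [99], [2], [1], [3], [8]], dtype=np.int32)

fixed_inputs, fixed_targets = next(flatten_targets([(inputs, bad_targets)]))
print(bad_targets.shape)       # (7, 1)
print(fixed_targets.shape)     # (7,)
print(fixed_targets.tolist())  # [1, 5, 99, 2, 1, 3, 8]
```

Reshaping in the input pipeline keeps the model definition unchanged; the same effect could be had by producing flat label arrays when the dataset is built.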
