Problem description
I trained an off-the-shelf Transformer().
Now I want to use the encoder to build a classifier. For that I only want to use the output of the first token (BERT-style CLS token result) and run it through a dense layer.
What I do:
tl.Serial(encoder, tl.Fn('pooler', lambda x: x[:, 0, :]), tl.Dense(7))
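The intended pooling can be sanity-checked outside of trax with plain NumPy (a minimal sketch with illustrative arrays, not the actual model): slicing out the first token of a (batch, seq_len, d_model) tensor yields (batch, d_model), which a dense layer then maps to 7 logits.

```python
import numpy as np

batch_size, seq_len, d_model, n_classes = 64, 50, 512, 7

# Illustrative encoder output: (batch, seq_len, d_model)
encoded = np.random.randn(batch_size, seq_len, d_model)

# BERT-style CLS pooling: keep only the first token of each sequence.
pooled = encoded[:, 0, :]          # shape (64, 512)

# A dense layer is just an affine map: (64, 512) @ (512, 7) + (7,)
w = np.random.randn(d_model, n_classes)
b = np.zeros(n_classes)
logits = pooled @ w + b            # shape (64, 7)

print(pooled.shape, logits.shape)  # (64, 512) (64, 7)
```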
Shapes:
The encoder gives me shape (64, 50, 512), with

64 = batch_size,
50 = seq_len,
512 = model_dim

The pooler gives me shape (64, 512), which is as expected.
The dense layer should take the 512 dimensions of each batch member and classify over 7 classes. But I guess trax/jax still expects it to have length seq_len (50):

TypeError: dot_general requires contracting dimensions to have the same shape, got [512] and [50].

What am I missing?
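The dot_general failure can be reproduced in isolation; it is the generic error for a matrix product whose inner (contracting) dimensions disagree. A minimal NumPy sketch (jax.numpy reports the mismatch as the dot_general TypeError above, plain NumPy as an equivalent ValueError):

```python
import numpy as np

x = np.zeros((64, 512))  # pooled activations: (batch, d_model)
w = np.zeros((50, 7))    # weights sized for seq_len (50) instead of d_model (512)

try:
    np.dot(x, w)         # inner dimensions 512 vs 50 do not match
except ValueError as err:
    print("shape mismatch:", err)
```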
Full traceback:
Traceback (most recent call last):
  File "mikado_classes.py", line 2054, in <module>
    app.run(main)
  File "/root/.local/lib/python3.7/site-packages/absl/app.py", line 300, in run
    _run_main(main, args)
  File "/root/.local/lib/python3.7/site-packages/absl/app.py", line 251, in _run_main
    sys.exit(main(argv))
  File "mikado_classes.py", line 1153, in main
    loop_neu.run(2)
  File "/root/.local/lib/python3.7/site-packages/trax/supervised/training.py", line 361, in run
    loss, optimizer_metrics = self._run_one_step(task_index, task_changed)
  File "/root/.local/lib/python3.7/site-packages/trax/supervised/training.py", line 483, in _run_one_step
    batch, rng, step=step, learning_rate=learning_rate
  File "/root/.local/lib/python3.7/site-packages/trax/optimizers/trainer.py", line 134, in one_step
    (weights, self._slots), step, self._opt_params, batch, state, rng)
  File "/root/.local/lib/python3.7/site-packages/trax/optimizers/trainer.py", line 173, in single_device_update_fn
    batch, weights, rng)
  File "/root/.local/lib/python3.7/site-packages/trax/layers/base.py", line 549, in pure_fn
    self._caller, signature(x), trace) from None
jax._src.traceback_util.FilteredStackTrace: trax.layers.base.LayerError: Exception passing through layer Serial (in pure_fn):
  layer created in file [...]/trax/supervised/training.py, line 865
  layer input shapes: (ShapeDtype{shape:(64,50), dtype:int32}, ShapeDtype{shape:(64,1), dtype:int32})

  File [...]/trax/layers/combinators.py, line 88, in forward
    outputs, s = layer.pure_fn(inputs, w, s, use_cache=True)

LayerError: Exception passing through layer Serial (in pure_fn):
  layer created in file [...]/mikado_classes.py, line 1134
  layer input shapes: (ShapeDtype{shape:(64, [...]

  File [...]/trax/layers/combinators.py, line 88, in forward
    outputs, s = layer.pure_fn(inputs, w, s, use_cache=True)

LayerError: Exception passing through layer Dense_7 (in pure_fn):
  layer created in file [...]/mikado_classes.py, line 1133
  layer input shapes: ShapeDtype{shape:(64,512), dtype:float32}
  File [...]/trax/layers/assert_shape.py, line 122, in forward_wrapper
    y = forward(self, x, *args, **kwargs)
  File [...]/trax/layers/core.py, line 95, in forward
    return jnp.dot(x, w) + b  # Affine map.
  File [...]/_src/numpy/lax_numpy.py, line 3498, in dot
    return lax.dot_general(a, b, (contract_dims, batch_dims), precision)
  File [...]/_src/lax/lax.py, line 674, in dot_general
    preferred_element_type=preferred_element_type)
  File [...]/site-packages/jax/core.py, line 282, in bind
    out = top_trace.process_primitive(self, tracers, params)
  File [...]/jax/interpreters/ad.py, line 285, in process_primitive
    primal_out, tangent_out = jvp(primals_in, tangents_in, **params)
  File [...]/jax/interpreters/ad.py, line 458, in standard_jvp
    val_out = primitive.bind(*primals, **params)
  File [...]/site-packages/jax/core.py, line [...], in bind
    out = top_trace.process_primitive(self, tracers, params)
  File [...]/jax/interpreters/partial_eval.py, line 140, in process_primitive
    return self.default_process_primitive(primitive, tracers, params)
  File [...]/jax/interpreters/partial_eval.py, line 147, in default_process_primitive
    return primitive.bind(*consts, **params)
  File [...]/jax/interpreters/partial_eval.py, line 1058, in process_primitive
    out_avals = primitive.abstract_eval(*avals, **params)
  File [...]/_src/lax/lax.py, line 1992, in standard_abstract_eval
    shapes, dtypes = shape_rule(*args, **kwargs), dtype_rule(*args, **kwargs)
  File [...]/_src/lax/lax.py, line 3090, in _dot_general_shape_rule
    raise TypeError(msg.format(lhs_contracting_shape, rhs_contracting_shape))
TypeError: dot_general requires contracting dimensions to have the same shape, got [512] and [50].

The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred, unmodified.

--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  File "mikado_classes.py", [...], rng)
  File "/root/.local/lib/python3.7/site-packages/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/root/.local/lib/python3.7/site-packages/jax/api.py", line 398, in f_jitted
    return cpp_jitted_f(context, [...])
  File "/root/.local/lib/python3.7/site-packages/jax/api.py", line 295, in cache_miss
    donated_invars=donated_invars)
  File "/root/.local/lib/python3.7/site-packages/jax/core.py", line 1275, in bind
    return call_bind(self, fun, **params)
  File "/root/.local/lib/python3.7/site-packages/jax/core.py", line 1266, in call_bind
    outs = primitive.process(top_trace, params)
  File "/root/.local/lib/python3.7/site-packages/jax/core.py", line 1278, in process
    return trace.process_call(self, fun, tracers, params)
  File "/root/.local/lib/python3.7/site-packages/jax/core.py", line 631, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/root/.local/lib/python3.7/site-packages/jax/interpreters/xla.py", line 581, in _xla_call_impl
    *unsafe_map(arg_spec, args))
  File "/root/.local/lib/python3.7/site-packages/jax/linear_util.py", line 260, in memoized_fun
    ans = call(fun, *args)
  File "/root/.local/lib/python3.7/site-packages/jax/interpreters/xla.py", line 656, in _xla_callable
    jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, abstract_args)
  File "/root/.local/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 1216, in trace_to_jaxpr_final
    jaxpr, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
  File "/root/.local/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 1196, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
  File "/root/.local/lib/python3.7/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/root/.local/lib/python3.7/site-packages/trax/optimizers/trainer.py", line 810, in value_and_grad_f
    ans, vjp_py, aux = _vjp(f_partial, *dyn_args, has_aux=True)
  File "/root/.local/lib/python3.7/site-packages/jax/api.py", line 1918, in _vjp
    out_primal, out_vjp, aux = ad.vjp(flat_fun, primals_flat, has_aux=True)
  File "/root/.local/lib/python3.7/site-packages/jax/interpreters/ad.py", line 116, in vjp
    out_primals, pvals, jaxpr, consts, aux = linearize(traceable, *primals, has_aux=True)
  File "/root/.local/lib/python3.7/site-packages/jax/interpreters/ad.py", line 101, in linearize
    jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
  File "/root/.local/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 506, in trace_to_jaxpr
    jaxpr, (out_pvals, env) = fun.call_wrapped(pvals)
  File "/root/.local/lib/python3.7/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/root/.local/lib/python3.7/site-packages/trax/layers/base.py", line 549, in pure_fn
    self._caller, signature(x), trace) from None
trax.layers.base.LayerError: Exception passing through layer Serial (in pure_fn):
  layer created in file [...]/trax/supervised/training.py, line 865
  [...]
TypeError: dot_general requires contracting dimensions to have the same shape, got [512] and [50].
Solution

The error is not in the architecture. The problem is that my input had the wrong shape. The targets should have shape (batch_size,), but I sent (batch_size, 1); that mismatch also showed up as:

TypeError: only size-1 arrays can be converted to Python scalars

So the target array should be, for example:

[1, 5, 99, 2, 1, 3, 8]

but I had built

[[1], [5], [99], [2], [1], [3], [8]]
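The fix amounts to dropping the trailing singleton axis from the targets; a minimal NumPy sketch using the illustrative values above:

```python
import numpy as np

targets_wrong = np.array([[1], [5], [99], [2], [1], [3], [8]])  # shape (7, 1)

# Drop the trailing singleton axis so each target is a bare class id.
targets = targets_wrong.squeeze(axis=-1)                        # shape (7,)

print(targets_wrong.shape, "->", targets.shape)  # (7, 1) -> (7,)
print(targets.tolist())                          # [1, 5, 99, 2, 1, 3, 8]
```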