Problem description
I am working with the latest version of T5X (FLAX code).
While it runs fine on Colab (a standard notebook with a single GPU), I am having trouble running the code on one of my machines. Specifically, my setup is Ubuntu 18.04, CUDA 11, and cuDNN 7, with two GPUs (TITAN Xp, 12 GB each).
JAX recognizes the GPUs, and the code runs correctly up to the first pmap.
The error is:
RuntimeError: Unimplemented: Requested AllReduce not implemented on GPU; replica_count: 2; operand_count: 131; IsCrossReplicaAllReduce: 1; Nccl support: 1; first operand array element-type: BF16
Training with float32 succeeds; the error appears only when using bfloat16 (the BF16 element type named in the message) on multiple GPUs. Can you help me?
Here is the full traceback:
Traceback (most recent call last):
  File "ft_t5_small_super_glue.py", line 84, in <module>
    train(model_dir='t5x_data', config=fine_tuning_cfg)
  File "/workspace/linear_t5x/src/t5x_utils/training.py", line 525, in train
    jnp.array(0, dtype=jnp.int32), 1)
  File "/opt/conda/lib/python3.6/site-packages/jax/api.py", line 1564, in f_pmapped
    global_arg_shapes=tuple(global_arg_shapes_flat))
  File "/opt/conda/lib/python3.6/site-packages/jax/core.py", line 1262, in bind
    return call_bind(self, fun, *args, **params)
  File "/opt/conda/lib/python3.6/site-packages/jax/core.py", line 1226, in call_bind
    outs = primitive.process(top_trace, tracers, params)
  File "/opt/conda/lib/python3.6/site-packages/jax/core.py", line 1265, in process
    return trace.process_map(self, ...)
  File "...", line 598, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/opt/conda/lib/python3.6/site-packages/jax/interpreters/pxla.py", line 635, in xla_pmap_impl
    *abstract_args)
  File "/opt/conda/lib/python3.6/site-packages/jax/linear_util.py", line 251, in memoized_fun
    ans = call(fun, *args)
  File "/opt/conda/lib/python3.6/site-packages/jax/interpreters/pxla.py", line 892, in parallel_callable
    compiled = xla.backend_compile(backend, built, compile_options)
  File "/opt/conda/lib/python3.6/site-packages/jax/interpreters/xla.py", line 349, in backend_compile
    return backend.compile(built_c, compile_options=options)
Workaround
No complete fix for this issue has been found yet; as the questioner notes, training in float32 avoids the error.
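If the XLA build in use lacks a bfloat16 AllReduce kernel for GPU, one common pattern is to upcast to float32 just around the cross-device reduction and cast back afterwards, so NCCL only ever sees float32 operands. The sketch below is illustrative and not taken from the T5X code; the function and array names are hypothetical, and where exactly to apply the cast in a real training step depends on how the gradients are reduced.

```python
import jax
import jax.numpy as jnp

# Hypothetical workaround sketch: upcast bfloat16 values to float32 before
# the cross-device AllReduce (lax.pmean), then downcast afterwards, so the
# GPU backend never has to reduce bf16 operands directly.
def mean_grads(g):
    orig_dtype = g.dtype
    g32 = g.astype(jnp.float32)                    # upcast before AllReduce
    g32 = jax.lax.pmean(g32, axis_name="batch")    # reduction runs in float32
    return g32.astype(orig_dtype)                  # restore original dtype

n = jax.local_device_count()
grads = jnp.ones((n, 4), dtype=jnp.bfloat16)       # one shard per device
out = jax.pmap(mean_grads, axis_name="batch")(grads)
print(out.dtype, out.shape)
```

In a real training loop this cast would be applied to the whole gradient pytree (e.g. with `jax.tree_util.tree_map`) inside the pmapped step function, at the point where the per-device gradients are averaged.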