BF16 和多个 GPU 的 JAX/FLAX 运行时错误

问题描述

我正在开发 T5X(FLAX 代码)的最新版本。 虽然在 Colab(带有一个 GPU 的标准笔记本)上运行它没有问题,但我在我的一台机器上运行代码时遇到了一些困难。具体来说,我的配置是 Ubuntu 18.04、CUDA 11、CuDNN 7,带有两个 GPU(TITAN Xp,每个 12GB)。 JAX 可以识别 GPU,并且代码可以正确运行到第一个 pmap

错误如下: RuntimeError: Unimplemented: Requested AllReduce not implemented on GPU; replica_count: 2; operand_count: 131; IsCrossReplicaAllReduce: 1; Nccl support: 1; first operand array element-type: BF16

通过使用 float32 训练成功。出现与在多 GPU 上使用 jnp.float16 相关的错误。你能帮我吗?

这里是完整的回溯:

Traceback (most recent call last):
  File "ft_t5_small_super_glue.py",line 84,in <module>
    train(model_dir='t5x_data',config=fine_tuning_cfg)
  File "/workspace/linear_t5x/src/t5x_utils/training.py",line 525,in train
    jnp.array(0,dtype=jnp.int32),1)
  File "/opt/conda/lib/python3.6/site-packages/jax/api.py",line 1564,in f_pmapped
    global_arg_shapes=tuple(global_arg_shapes_flat))
  File "/opt/conda/lib/python3.6/site-packages/jax/core.py",line 1262,in bind
    return call_bind(self,fun,*args,**params)
  File "/opt/conda/lib/python3.6/site-packages/jax/core.py",line 1226,in call_bind
    outs = primitive.process(top_trace,tracers,params)
  File "/opt/conda/lib/python3.6/site-packages/jax/core.py",line 1265,in process
    return trace.process_map(self,line 598,in process_call
    return primitive.impl(f,*tracers,**params)
  File "/opt/conda/lib/python3.6/site-packages/jax/interpreters/pxla.py",line 635,in xla_pmap_impl
    *abstract_args)
  File "/opt/conda/lib/python3.6/site-packages/jax/linear_util.py",line 251,in memoized_fun
    ans = call(fun,*args)
  File "/opt/conda/lib/python3.6/site-packages/jax/interpreters/pxla.py",line 892,in parallel_callable
    compiled = xla.backend_compile(backend,built,compile_options)
  File "/opt/conda/lib/python3.6/site-packages/jax/interpreters/xla.py",line 349,in backend_compile
    return backend.compile(built_c,compile_options=options)

解决方法

暂无找到可以解决该程序问题的有效方法,小编努力寻找整理中!

如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。

小编邮箱:dio#foxmail.com (将#修改为@)

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...