切片jax.numpy数组时性能下降 JAX没有@jitPyTorch

问题描述

在尝试对大型阵列进行SVD​​压缩时,我遇到了一些Jax无法理解的行为。这是示例代码

@jit 
def jax_compress(L):
    U,S,_ = jsc.linalg.svd(L,full_matrices = False,lapack_driver = 'gesvd',check_finite=False,overwrite_a=True)

    maxS=jnp.max(S)
    chi = jnp.sum(S/maxS>1E-1)

    return chi,jnp.asarray(U)

考虑到此代码段,Jax / jit的性能大大超过了SciPy,但最终我想减小U的维数,方法是将U包裹在函数中:

def jax_process(A):

    chi,U = jax_compress(A)
    
    return U[:,0:chi]

此步骤在计算时间方面的成本令人难以置信,比同等的SciPy还要高,如下面的比较所示:

benchmark of jax and scipy

sc_compresssc_process是上述jax代码的SciPy等效项。如您所见,在SciPy中切片数组几乎不花费任何费用,但是将其应用于hit函数输出时则非常昂贵。有人对此行为有见识吗?

解决方法

我对 JAX 和 PyTorch 之间的切片速度进行了类似的比较。 dynamic_slice 比常规切片快得多,但仍然比 Torch 中的等价物慢得多。由于我是 JAX 的新手,我不确定原因是什么,但这可能与复制与引用有关,因为 JAX 数组是不可变的。

JAX(没有@jit)

key = random.PRNGKey(0)
j = random.normal(key,(32,2,1024,3))
%timeit j[...,100:600,:].block_until_ready()
%timeit dynamic_slice(j,[0,100,0],[32,500,3]).block_until_ready()
2.78 ms ± 198 µs per loop (mean ± std. dev. of 7 runs,100 loops each)
993 µs ± 12.6 µs per loop (mean ± std. dev. of 7 runs,1000 loops each)

PyTorch

t = torch.randn((32,3)).cuda()

%%timeit 
t[...,:]
torch.cuda.synchronize()
7.63 µs ± 22.7 ns per loop (mean ± std. dev. of 7 runs,100000 loops each)
,

我不是Jax专家,我不确定它是如何工作的,但是我运行了该代码片段并进行了查看。

我非常确定jax_compress中的Jax函数(或jit装饰器的效果)是惰性计算的,因此只有当您“看”输出矩阵时,它们才执行完整的计算。计算结束并实际要求具体数字(很像python生成器做的事情,以及功能语言如Haskell)。

我认为,您最后要进行的数组切片基本上是这种“询问具体矩阵”的形式。

您可以通过在访问数组元素后单独计时jax_compress函数来检查此问题:

ti = time.time()
X,U = jax_compress(A)
# very fast
print(f"Compession takes {time.time() - ti} seconds when not peeking")

ti = time.time()
X,U = jax_compress(A)
# much slower
print(U[0][0])
print(f"Compession takes {time.time() - ti} seconds when peeking")

一种解决方案可能是使用lax.dynamic_slicelax.dynamic_update_slice,我相信在jax.numpy.lax_numpy中有一个Jax实现。但是,根据您的硬件,我的直觉是您不会发现太多的提速,因为SVD的科学实现还是无论如何(对于单个CPU机器)都是围绕着高度优化和编译的Fortran代码的包装。

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...