为什么在JAX和numpy中此功能较慢?

问题描述

我具有以下numpy函数,如下所示,我正在尝试使用JAX进行优化,但出于某种原因,它的速度较慢。

有人可以指出我可以做些什么来提高性能吗?我怀疑这与Cg_new的列表理解有关,但是将其分开并不会在JAX中产生任何进一步的性能提升。

import numpy as np 

def testFunction_numpy(C,Mi,C_new,Mi_new):
    Wg_new = np.zeros((len(Mi_new[:,0]),len(Mi[0])))
    Cg_new = np.zeros((1,len(Mi[0])))
    invertCsensor_new = np.linalg.inv(C_new)

    Wg_new = np.dot(invertCsensor_new,Mi_new)
    Cg_new = [np.dot(((-0.5*(Mi_new[:,m].conj().T))),(Wg_new[:,m])) for m in range(0,len(Mi[0]))] 

    return C_new,Mi_new,Wg_new,Cg_new

C = np.random.rand(483,483)
Mi = np.random.rand(483,8)
C_new = np.random.rand(198,198)
Mi_new = np.random.rand(198,8)

%timeit testFunction_numpy(C,Mi_new)
#1000 loops,best of 3: 1.73 ms per loop

相当于JAX:

import jax.numpy as jnp
import numpy as np
import jax

def testFunction_JAX(C,Mi_new):
    Wg_new = jnp.zeros((len(Mi_new[:,len(Mi[0])))
    Cg_new = jnp.zeros((1,len(Mi[0])))
    invertCsensor_new = jnp.linalg.inv(C_new)

    Wg_new = jnp.dot(invertCsensor_new,Mi_new)
    Cg_new = [jnp.dot(((-0.5*(Mi_new[:,8)

C = jnp.asarray(C)
Mi = jnp.asarray(Mi)
C_new = jnp.asarray(C_new)
Mi_new = jnp.asarray(Mi_new)

jitter = jax.jit(testFunction_JAX) 

%timeit jitter(C,Mi_new)
#1 loop,best of 3: 4.96 ms per loop

解决方法

当JAX jit编译遇到Python控制流(包括列表推导)时,它将有效地拉平循环并逐步执行整个操作序列。这可能会导致jit编译时间变慢以及代码不理想。幸运的是,您的函数中的列表理解很容易用本地numpy广播表示。此外,您还可以进行其他两项改进:

  • 在计算它们之前,无需转发声明Wg_newCg_new
  • 在计算dot(inv(A),B)时,使用np.linalg.solve而不是显式计算逆函数会更加高效和精确。

对numpy和JAX版本进行了这三项改进,结果如下:

def testFunction_numpy_v2(C,Mi,C_new,Mi_new):
    Wg_new = np.linalg.solve(C_new,Mi_new)
    Cg_new = -0.5 * (Mi_new.conj() * Wg_new).sum(0)
    return C_new,Mi_new,Wg_new,Cg_new

@jax.jit
def testFunction_JAX_v2(C,Mi_new):
    Wg_new = jnp.linalg.solve(C_new,Cg_new

%timeit testFunction_numpy_v2(C,Mi_new)
# 1000 loops,best of 3: 1.11 ms per loop
%timeit testFunction_JAX_v2(C_jax,Mi_jax,C_new_jax,Mi_new_jax)
# 1000 loops,best of 3: 1.35 ms per loop

由于改进了实现,这两个函数的速度都比以前快了一点。但是,您会注意到,JAX在这里仍然比numpy慢。这在某种程度上是可以预料的,因为对于这种简单程度的功能,JAX和numpy都有效地生成了在CPU体系结构上执行的相同简短系列的BLAS和LAPACK调用。 numpy的引用实现根本没有太多改进的空间,而且使用如此小的数组,JAX的开销显而易见。

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...