scipy stats zmap函数的替代方法

问题描述

zmap函数的scipy stats模块是否有替代方案?我目前正在使用它来获取两个非常大的数组的zmap分数,这需要花费一些时间。

是否有任何库或替代品可以提高其性能?甚至获得zmap函数功能的另一个方法

您的想法和意见将不胜感激!

这是我下面的最小可重复代码

from scipy import stats
import numpy as np

FeatureData = np.random.rand(483,1)
goodData = np.random.rand(4640,483)
Featurenorm= stats.zmap(FeatureData,goodData)

这是scipy stats.zmap在后台执行的操作:

def zmap(scores,compare,axis=0,ddof=0):
    scores,compare = map(np.asanyarray,[scores,compare])
    mns = compare.mean(axis=axis,keepdims=True)
    sstd = compare.std(axis=axis,ddof=ddof,keepdims=True)
    return (scores - mns) / sstd

关于如何针对我的用例进行优化的任何想法?我可以使用numba或JAX之类的库来进一步增强它吗?

解决方法

幸运的是,zmap代码非常简单。但是,numpy的开销来自必须实例化中间数组的事实。如果您使用数字编译器(例如numbajax中的数字编译器),它可以融合这些操作并以较少的开销进行计算。

不幸的是,numba不支持meanstd的可选参数,因此让我们看一下JAX。作为参考,以下是在Google Colab CPU运行时上计算的scipy和该函数的原始numpy版本的基准:

import numpy as np
from scipy import stats

FeatureData = np.random.rand(483,1)
goodData = np.random.rand(4640,483)

%timeit stats.zmap(FeatureData,goodData)
# 100 loops,best of 3: 13.9 ms per loop

def np_zmap(scores,compare,axis=0,ddof=0):
    scores,compare = map(np.asanyarray,[scores,compare])
    mns = compare.mean(axis=axis,keepdims=True)
    sstd = compare.std(axis=axis,ddof=ddof,keepdims=True)
    return (scores - mns) / sstd

%timeit np_zmap(FeatureData,best of 3: 13.8 ms per loop

这是在JAX中执行的等效代码,包括急切模式和JIT编译:

import jax.numpy as jnp
from jax import jit

def jnp_zmap(scores,compare = map(jnp.asarray,keepdims=True)
    return (scores - mns) / sstd

jit_jnp_zmap = jit(jnp_zmap)

FeatureData = jnp.array(FeatureData)
goodData = jnp.array(goodData)
%timeit jnp_zmap(FeatureData,goodData).block_until_ready()
# 100 loops,best of 3: 8.59 ms per loop

jit_jnp_zmap(FeatureData,goodData)  # trigger compilation
%timeit jit_jnp_zmap(FeatureData,best of 3: 2.78 ms per loop

JIT编译的版本比scipy或numpy代码快大约5倍。在Colab T4 GPU运行时上,编译版本的系数是10:

%timeit jit_jnp_zmap(FeatureData,goodData).block_until_ready()
1000 loops,best of 3: 286 µs per loop

如果这种操作是分析的瓶颈,那么像JAX这样的编译器可能是个不错的选择。

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...