Jax中的vmap ops.index_update

问题描述

我在下面有以下代码,它使用一个简单的for循环。我只是想知道是否有一种vmap的方法?这是原始代码

import numpy as np 
import jax.numpy as jnp
import jax.scipy.signal as jscp
from scipy import signal
import jax

data = np.random.rand(192,334)

a = [1,-1.086740193996892,0.649914553946275,-0.124948974636730]
b = [0.054778173164082,0.164334519492245,0.054778173164082]
impulse = signal.lfilter(b,a,[1] + [0]*99) 
impulse_20 = impulse[:20]
impulse_20 = jnp.asarray(impulse_20)

@jax.jit
def filter_jax(y):
    for ind in range(0,len(y)):
      y = jax.ops.index_update(y,jax.ops.index[:,ind],jscp.convolve(impulse_20,y[:,ind])[:-19])
    return y

jnpData = jnp.asarray(data)

%timeit filter_jax(jnpData).block_until_ready()

这是我尝试使用vmap的尝试:

def paraUpdate(y,ind):
    return jax.ops.index_update(y,ind])[:-19])

@jax.jit
def filter_jax2(y):
  ranger = range(0,len(y))
  return jax.vmap(paraUpdate,y)(ranger)

但是我收到以下错误

TypeError:vmap in_axes必须是int,None或(嵌套的)容器 这些类型的叶子,但得到了 通过跟踪到

我有点困惑,因为范围是int类型,所以我不太确定发生了什么。

最后,我试图尽可能优化此小片段,以使时间最短。

解决方法

jax.vmap可以表示其中跨输入的多个轴独立应用单个操作的功能。您的功能有所不同:您将一个操作迭代地应用于单个输入。

幸运的是,JAX提供了lax.scan可以处理这种情况。该实现将如下所示:

from jax import lax

def paraUpdate(y,ind):
    return jax.ops.index_update(y,jax.ops.index[:,ind],jscp.convolve(impulse_20,y[:,ind])[:-19]),ind

@jax.jit
def filter_jax2(y):
  ranger = jnp.arange(len(y))
  return lax.scan(paraUpdate,y,ranger)[0]

print(np.allclose(filter_jax(jnpData),filter_jax2(jnpData)))
# True

%timeit filter_jax(jnpData).block_until_ready()
# 10 loops,best of 3: 28.6 ms per loop

%timeit filter_jax2(jnpData).block_until_ready()
# 1000 loops,best of 3: 519 µs per loop

如果更改算法,以便将操作应用于数组中的列而不是前 N 列,则可以用{{ 1}}:

vmap

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...