从函数有效地填充数组

问题描述

我想以一种可以利用jax.jit的方式从函数构造2D数组。

我通常使用numpy进行此操作的方式是创建一个空数组,然后就地填充该数组。

xx = jnp.empty((num_a,num_b))
yy = jnp.empty((num_a,num_b))
zz = jnp.empty((num_a,num_b))

for ii_a in range(num_a):
    for ii_b in range(num_b):
        a = aa[ii_a,ii_b]
        b = bb[ii_a,ii_b]

        xyz = self.get_coord(a,b)

        xx[ii_a,ii_b] = xyz[0]
        yy[ii_a,ii_b] = xyz[1]
        zz[ii_a,ii_b] = xyz[2]

为了在jax中完成这项工作,我尝试使用jax.opt.index_update

        xx = xx.at[ii_a,ii_b].set(xyz[0])
        yy = yy.at[ii_a,ii_b].set(xyz[1])
        zz = zz.at[ii_a,ii_b].set(xyz[2])

这运行时没有错误,但是在我尝试使用@jax.jit装饰器时速度非常慢(至少比纯python / numpy版本慢一个数量级)。

使用jax函数填充多维数组的最佳方法是什么?

解决方法

JAX有一个vmap transform专为这种应用程序设计。

只要您的get_coords函数与JAX兼容(即是一个没有副作用的纯函数),就可以在一行中完成此操作:

from jax import vmap
xx,yy,zz = vmap(vmap(get_coord))(aa,bb)
,

这可以通过使用jax.vmapjax.numpy.vectorize函数来有效实现。

使用vectorize的示例:

import jax.numpy as jnp

def get_coord(a,b):
    return jnp.array([a,b,a+b])

f0 = jnp.vectorize(get_coord,signature='(),()->(i)')
f1 = jnp.vectorize(f0,excluded=(1,),signature='()->(i,j)')

xyz = f1(a,b)

vectorize函数在幕后使用vmap,因此它应该完全等同于:

f0 = jax.vmap(get_coord,(None,0))
f1 = jax.vmap(f0,(0,None)) 

使用vectorize的优点是代码仍可以在标准numpy中运行。缺点是代码不够简洁,而且由于使用包装程序,可能会产生少量开销。

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...