问题描述
我想以一种可以利用jax.jit
的方式从函数构造2D数组。
我通常使用numpy
进行此操作的方式是创建一个空数组,然后就地填充该数组。
xx = jnp.empty((num_a,num_b))
yy = jnp.empty((num_a,num_b))
zz = jnp.empty((num_a,num_b))
for ii_a in range(num_a):
for ii_b in range(num_b):
a = aa[ii_a,ii_b]
b = bb[ii_a,ii_b]
xyz = self.get_coord(a,b)
xx[ii_a,ii_b] = xyz[0]
yy[ii_a,ii_b] = xyz[1]
zz[ii_a,ii_b] = xyz[2]
为了在jax
中完成这项工作,我尝试使用jax.opt.index_update
。
xx = xx.at[ii_a,ii_b].set(xyz[0])
yy = yy.at[ii_a,ii_b].set(xyz[1])
zz = zz.at[ii_a,ii_b].set(xyz[2])
这运行时没有错误,但是在我尝试使用@jax.jit
装饰器时速度非常慢(至少比纯python / numpy版本慢一个数量级)。
解决方法
JAX有一个vmap
transform专为这种应用程序设计。
只要您的get_coords
函数与JAX兼容(即是一个没有副作用的纯函数),就可以在一行中完成此操作:
from jax import vmap
xx,yy,zz = vmap(vmap(get_coord))(aa,bb)
,
这可以通过使用jax.vmap
或jax.numpy.vectorize
函数来有效实现。
使用vectorize
的示例:
import jax.numpy as jnp
def get_coord(a,b):
return jnp.array([a,b,a+b])
f0 = jnp.vectorize(get_coord,signature='(),()->(i)')
f1 = jnp.vectorize(f0,excluded=(1,),signature='()->(i,j)')
xyz = f1(a,b)
vectorize
函数在幕后使用vmap
,因此它应该完全等同于:
f0 = jax.vmap(get_coord,(None,0))
f1 = jax.vmap(f0,(0,None))
使用vectorize
的优点是代码仍可以在标准numpy中运行。缺点是代码不够简洁,而且由于使用包装程序,可能会产生少量开销。