问题描述
在autograd / numpy中,我可以这样做:
q[q<0] = 0.0
我如何在JAX中做同样的事情?
我尝试了import numpy as onp
并使用它来创建数组,但这似乎不起作用。
解决方法
JAX数组是不可变的,因此就地索引分配语句无法工作。而是,jax提供了jax.ops
子模块,该子模块提供了创建数组更新版本的功能。
这是一个numpy索引分配和等效的JAX索引更新的示例:
import numpy as np
q = np.arange(-5,5)
q[q < 0] = 0
print(q)
# [0 0 0 0 0 0 1 2 3 4]
import jax.numpy as jnp
q = jnp.arange(-5,5)
q = q.at[q < 0].set(0) # NB: this does not modify the original array,# but rather returns a modified copy.
print(q)
# [0 0 0 0 0 0 1 2 3 4]
请注意,在逐个操作模式下,JAX版本会创建该数组的多个副本。但是,当在JIT编译中使用XLA时,XLA通常会融合此类操作并避免复制数据。