JAX中的条件更新?

问题描述

在autograd / numpy中,我可以这样做:

q[q<0] = 0.0

我如何在JAX中做同样的事情?

我尝试了import numpy as onp并使用它来创建数组,但这似乎不起作用。

解决方法

JAX数组是不可变的,因此就地索引分配语句无法工作。而是,jax提供了jax.ops子模块,该子模块提供了创建数组更新版本的功能。

这是一个numpy索引分配和等效的JAX索引更新的示例:

import numpy as np
q = np.arange(-5,5)
q[q < 0] = 0
print(q)
# [0 0 0 0 0 0 1 2 3 4]

import jax.numpy as jnp
q = jnp.arange(-5,5)
q = q.at[q < 0].set(0)  # NB: this does not modify the original array,# but rather returns a modified copy.
print(q)
# [0 0 0 0 0 0 1 2 3 4]

请注意,在逐个操作模式下,JAX版本会创建该数组的多个副本。但是,当在JIT编译中使用XLA时,XLA通常会融合此类操作并避免复制数据。

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...