JAX 仅在 jit

问题描述

我正在使用 JAX,我想执行类似的操作

@jax.jit
def fun(x,index):
    x[:index] = other_fun(x[:index])
    return x

这不能在 jit 下执行。有没有办法用 jax.opsjax.lax 做到这一点? 我想过使用 jax.ops.index_update(x,idx,y),但我无法找到一种计算 y方法,而不会再次遇到同样的问题。

解决方法

您的实施似乎存在两个问题。首先,切片产生动态形状的数组(不允许在即时代码中)。其次,与 numpy 数组不同,JAX 数组是不可变的(即数组的内容不能改变)。

您可以通过组合 static_argnumsjax.lax.dynamic_update_slice 来克服这两个问题。下面是一个例子:

def other_fun(x):
    return x + 1

@jax.partial(jax.jit,static_argnums=(1,))
def fun(x,index):
    update = other_fun(x[:index])
    return jax.lax.dynamic_update_slice(x,update,(0,))

x = jnp.arange(5)
print(fun(x,3))  # prints [1 2 3 3 4]

本质上,上面的例子使用 static_argnums 来指示函数应该为不同的 index 值重新编译,jax.lax.dynamic_update_slice 创建一个 x 的副本,并在 { {1}}。

,

@rvinas 的 previous answer 使用 dynamic_slice 如果您的索引是静态的,则效果很好,但您也可以使用 jnp.where 使用动态索引来完成此操作。例如:

import jax
import jax.numpy as jnp

def other_fun(x):
    return x + 1

@jax.jit
def fun(x,index):
  mask = jnp.arange(x.shape[0]) < index
  return jnp.where(mask,other_fun(x),x)

x = jnp.arange(5)
print(fun(x,3))
# [1 2 3 3 4]

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...