与此 Python 函数等效的 JaxNumpy 兼容是什么?

问题描述

如何以兼容 JAX 的方式(例如,使用 jax.numpy)实现以下内容

def actions(state: tuple[int,...]) -> list[tuple[int,...]]:
    l = []
    iterables = [range(1,i+1) for i in state]
    ns = list(range(len(iterables)))
    for i,iterable in enumerate(iterables):
        for value in iterable:
            action = tuple(value if n == i else 0 for n in ns)
            l.append(action)
    return l

>>> state = (3,1,2)
>>> actions(state)
[(1,0),(2,(3,(0,1),2)]

解决方法

Jax 与 numpy 一样,无法有效地对 Python 容器类型(如列表和元组)进行操作,因此实际上并没有任何兼容 JAX 的方法来创建具有您在上面指定的确切签名的函数。

但是如果您认为返回值是一个二维数组,那么您可以根据 jnp.vstack 执行类似的操作:

from typing import Tuple
import jax.numpy as jnp
from jax import jit,partial

@partial(jit,static_argnums=0)
def actions(state: Tuple[int,...]) -> jnp.ndarray:
  return jnp.vstack([
    jnp.zeros((val,len(state)),int).at[:,i].set(jnp.arange(1,val + 1))
    for i,val in enumerate(state)])
>>> state = (3,1,2)
>>> actions(state)
DeviceArray([[1,0],[2,[3,[0,1],2]],dtype=int32)

注意,因为输出数组的大小取决于state的内容,所以state必须是一个静态量,所以输入元组是一个不错的选择。

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...