问题描述
如何以兼容 JAX 的方式(例如,使用 jax.numpy
)实现以下内容?
def actions(state: tuple[int,...]) -> list[tuple[int,...]]:
l = []
iterables = [range(1,i+1) for i in state]
ns = list(range(len(iterables)))
for i,iterable in enumerate(iterables):
for value in iterable:
action = tuple(value if n == i else 0 for n in ns)
l.append(action)
return l
>>> state = (3,1,2)
>>> actions(state)
[(1,0),(2,(3,(0,1),2)]
解决方法
Jax 与 numpy 一样,无法有效地对 Python 容器类型(如列表和元组)进行操作,因此实际上并没有任何兼容 JAX 的方法来创建具有您在上面指定的确切签名的函数。
但是如果您认为返回值是一个二维数组,那么您可以根据 jnp.vstack
执行类似的操作:
from typing import Tuple
import jax.numpy as jnp
from jax import jit,partial
@partial(jit,static_argnums=0)
def actions(state: Tuple[int,...]) -> jnp.ndarray:
return jnp.vstack([
jnp.zeros((val,len(state)),int).at[:,i].set(jnp.arange(1,val + 1))
for i,val in enumerate(state)])
>>> state = (3,1,2)
>>> actions(state)
DeviceArray([[1,0],[2,[3,[0,1],2]],dtype=int32)
注意,因为输出数组的大小取决于state
的内容,所以state
必须是一个静态量,所以输入元组是一个不错的选择。