使用 vmap 时,Jax 不支持不可散列的静态参数

问题描述

这与this question有关。经过一些工作,我设法将其更改为最后一个错误代码现在看起来像这样。

import jax.numpy as jnp
from jax import grad,jit,value_and_grad
from jax import vmap,pmap
from jax import random
import jax
from jax import lax
from jax import custom_jvp


def p_tau(z,tau,alpha=1.5):
    return jnp.clip((alpha - 1) * z - tau,0) ** (1 / (alpha - 1))


def get_tau(tau,tau_max,tau_min,z_value):
    return lax.cond(z_value < 1,lambda _: (tau,tau_min),lambda _: (tau_max,tau),operand=None
                    )


def body(kwargs,x):
    tau_min = kwargs['tau_min']
    tau_max = kwargs['tau_max']
    z = kwargs['z']
    alpha = kwargs['alpha']

    tau = (tau_min + tau_max) / 2
    z_value = p_tau(z,alpha).sum()
    taus = get_tau(tau,z_value)
    tau_max,tau_min = taus[0],taus[1]
    return {'tau_min': tau_min,'tau_max': tau_max,'z': z,'alpha': alpha},None

@jax.partial(jax.jit,static_argnums=(2,))
def map_row(z_input,alpha,T):
    z = (alpha - 1) * z_input

    tau_min,tau_max = jnp.min(z) - 1,jnp.max(z) - z.shape[0] ** (1 - alpha)
    result,_ = lax.scan(body,{'tau_min': tau_min,xs=None,length=T)
    tau = (result['tau_max'] + result['tau_min']) / 2
    result = p_tau(z,alpha)
    return result / result.sum()

@jax.partial(jax.jit,static_argnums=(1,3,))
def _entmax(input,axis=-1,alpha=1.5,T=20):
    result = vmap(jax.partial(map_row,T),axis)(input)
    return result

@jax.partial(custom_jvp,nondiff_argnums=(1,2,))
def entmax(input,T=10):
    return _entmax(input,axis,T)

@jax.partial(jax.jit,static_argnums=(0,))    
def _entmax_jvp_impl(axis,T,primals,tangents):
    input = primals[0]
    Y = entmax(input,T)
    gppr = Y  ** (2 - alpha)
    grad_output = tangents[0]
    dX = grad_output * gppr
    q = dX.sum(axis=axis) / gppr.sum(axis=axis)
    q = jnp.expand_dims(q,axis=axis)
    dX -= q * gppr
    return Y,dX


@entmax.defjvp
def entmax_jvp(axis,tangents):
    return _entmax_jvp_impl(axis,tangents)

import numpy as np
input = jnp.array(np.random.randn(64,10)).block_until_ready()
weight = jnp.array(np.random.randn(64,10)).block_until_ready()

def toy(input,weight):
    return (weight*entmax(input,1.5,20)).sum()

jax.jit(value_and_grad(toy))(input,weight)

这导致(我希望)是最终错误,即

Non-hashable static arguments are not supported,as this can lead to unexpected cache-misses. Static argument (index 2) of type <class 'jax.interpreters.batching.BatchTracer'> for function map_row is non-hashable.

这很奇怪,因为我想我已经标记了每一个地方 axis 看起来是静态的,但它仍然告诉我它被跟踪了。

解决方法

当您编写带有位置参数的 partial 函数时,首先传递这些参数。所以这个:

jax.partial(map_row,alpha,T)

本质上等同于:

lambda z_input: map_row(alpha,T,z_input)

请注意参数的错误顺序——这就是导致您出错的原因:您将 z_input(一个不可散列的跟踪器)传递给一个预期为静态的参数。

您可以通过将上面的 partial 语句替换为:

lambda z: map_row(z,T)

然后您的代码将正确运行。

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...