Problem description
This is related to this question. I managed to get most of the code working, except for one strange thing.
import jax.numpy as jnp
from jax import grad, jit, value_and_grad
from jax import vmap, pmap
from jax import random
import jax
from jax import lax
from jax import custom_jvp
from functools import partial  # jax.partial was an alias of functools.partial


def p_tau(z, tau, alpha=1.5):
    return jnp.clip((alpha - 1) * z - tau, 0) ** (1 / (alpha - 1))


def get_tau(tau, tau_max, tau_min, z_value):
    # Bisection step: returns the new (tau_max, tau_min) pair.
    return lax.cond(z_value < 1,
                    lambda _: (tau, tau_min),
                    lambda _: (tau_max, tau),
                    operand=None)


def body(kwargs, x):
    tau_min = kwargs['tau_min']
    tau_max = kwargs['tau_max']
    z = kwargs['z']
    alpha = kwargs['alpha']
    tau = (tau_min + tau_max) / 2
    z_value = p_tau(z, tau, alpha).sum()
    taus = get_tau(tau, tau_max, tau_min, z_value)
    tau_max, tau_min = taus[0], taus[1]
    return {'tau_min': tau_min, 'tau_max': tau_max, 'z': z, 'alpha': alpha}, None


@partial(jax.jit, static_argnums=(2,))
def map_row(z_input, alpha, T):
    z = (alpha - 1) * z_input
    tau_min, tau_max = jnp.min(z) - 1, jnp.max(z) - z.shape[0] ** (1 - alpha)
    # T bisection steps; scan's length must be a concrete int.
    result, _ = lax.scan(body,
                         {'tau_min': tau_min, 'tau_max': tau_max, 'z': z, 'alpha': alpha},
                         xs=None, length=T)
    tau = (result['tau_max'] + result['tau_min']) / 2
    result = p_tau(z, tau, alpha)
    return result / result.sum()


@partial(jax.jit, static_argnums=(1, 3))
def _entmax(input, axis=-1, alpha=1.5, T=20):
    result = vmap(partial(map_row, alpha=alpha, T=T), axis)(input)
    return result


@partial(custom_jvp, nondiff_argnums=(1, 2, 3))
def entmax(input, axis=-1, alpha=1.5, T=10):
    return _entmax(input, axis, alpha, T)


@partial(jax.jit, static_argnums=(0,))
def _entmax_jvp_impl(axis, alpha, T, primals, tangents):
    input = primals[0]
    Y = entmax(input, axis, alpha, T)
    gppr = Y ** (2 - alpha)
    grad_output = tangents[0]
    dX = grad_output * gppr
    q = dX.sum(axis=axis) / gppr.sum(axis=axis)
    q = jnp.expand_dims(q, axis=axis)
    dX -= q * gppr
    return Y, dX


@entmax.defjvp
def entmax_jvp(axis, alpha, T, primals, tangents):
    return _entmax_jvp_impl(axis, alpha, T, primals, tangents)


import numpy as np
input = jnp.array(np.random.randn(64, 10)).block_until_ready()
weight = jnp.array(np.random.randn(64, 10)).block_until_ready()


def toy(input, weight):
    return (weight * entmax(input, axis=0, alpha=1.5, T=20)).sum()


jax.jit(value_and_grad(toy))(input, weight)
This is caused by this line of code:
tuple index out of range
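The error persists even if I replace the function body with an identity function, which is very strange behavior. However, keeping this argument static is very important to me, because it helps unroll the loop.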
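On why T has to stay static: lax.scan needs its length argument as a concrete Python int at trace time, so T must reach scan as a constant rather than as a traced value. Below is a minimal sketch, assuming a recent JAX release that supports static_argnames; bisect_sqrt2 is a hypothetical toy that runs T bisection steps the same way body does for tau, and it is not part of the code above.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax


# Hypothetical example: T is static, so it is a Python int during tracing.
@partial(jax.jit, static_argnames=("T",))
def bisect_sqrt2(lo, hi, T):
    """Run T bisection steps for the root of x**2 - 2 on [lo, hi]."""
    lo = jnp.asarray(lo, dtype=jnp.float32)
    hi = jnp.asarray(hi, dtype=jnp.float32)

    def step(carry, _):
        lo, hi = carry
        mid = (lo + hi) / 2
        # Keep the half-interval that still brackets the root.
        too_big = mid * mid > 2.0
        lo = jnp.where(too_big, lo, mid)
        hi = jnp.where(too_big, mid, hi)
        return (lo, hi), None

    # scan's `length` must be a concrete int, which is why T is static.
    (lo, hi), _ = lax.scan(step, (lo, hi), xs=None, length=T)
    return (lo + hi) / 2


print(bisect_sqrt2(1.0, 2.0, T=30))  # approximately sqrt(2)
```

Each distinct value of T triggers a separate compilation, which is the usual cost of making an argument static.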
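Solution
This error is due to an issue in JAX that I hope will be fixed soon: static arguments cannot be passed by keyword. In other words, you should change this: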
def toy(input, weight):
    return (weight * entmax(input, axis=0, alpha=1.5, T=20)).sum()
to this:
def toy(input, weight):
    return (weight * entmax(input, 0, 1.5, 20)).sum()
The same fix should be applied to your call to map_row.
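A minimal sketch of the positional-call fix for the vmap case, assuming a recent JAX release. scale_row and apply_rows are hypothetical stand-ins for map_row and _entmax: instead of partial(map_row, alpha=alpha, T=T), a lambda forwards alpha and T positionally, so T lands in static slot 2.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical stand-in for map_row: argument 2 (T) is static.
@partial(jax.jit, static_argnums=(2,))
def scale_row(row, alpha, T):
    # T is a Python int here, so it can drive a Python loop.
    for _ in range(T):
        row = row * alpha
    return row


def apply_rows(x, axis, alpha, T):
    # Pass alpha and T positionally rather than via partial(..., T=T).
    return jax.vmap(lambda row: scale_row(row, alpha, T), in_axes=axis)(x)


x = jnp.ones((3, 4))
print(apply_rows(x, 0, 2.0, 3))  # every entry equals 2.0 ** 3
```

Recent JAX releases use the function signature to match static_argnums to parameter names (and also accept static_argnames directly), so keyword calls like the ones in the question work there; the positional form works either way.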
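At that point, you will run into a ValueError caused by passing traced variables to functions that require static arguments; the solution will be similar to the one in How to handle JAX reshape with JIT.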
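A sketch of the general idea behind that kind of fix, assuming a recent JAX release: a value that an inner jitted function needs as static must also be static in every jitted caller, otherwise it arrives as a tracer. inner and outer are hypothetical names, and this is not the code from the linked answer. In the question's code, a likely candidate is T in _entmax_jvp_impl, which is not listed in that function's static_argnums.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnums=(1,))
def inner(x, n):
    # n must be a concrete int because it determines the output shape.
    return jnp.reshape(x, (n, -1))


# If n were an ordinary (traced) argument of outer, forwarding it into
# inner's static slot would fail. Marking it static here as well keeps
# it a Python int all the way down.
@partial(jax.jit, static_argnums=(1,))
def outer(x, n):
    return inner(x, n) * 2


print(outer(jnp.arange(6.0), 2).shape)  # (2, 3)
```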
One more note: the error message for this static_argnums mistake was improved recently, and it will be clearer in the next release:
ValueError: jitted function has static_argnums=(2,),donate_argnums=() but was called with only 1 positional arguments.