问题描述
我正在尝试区分一个函数,该函数近似包含在 2 个范围内的高斯分数(截断的高斯),给定均值偏移。
jnp.grad
不允许我区分添加布尔过滤器(注释行),所以我不得不用 sigmoid 代替。
但是,现在当截断边界很高时梯度始终为 nan,我不明白为什么。
在下面的示例中,我正在计算平均值为 0 且 std=1 的高斯梯度,然后我用 x
对其进行移动。
如果我减少边界,那么函数会按预期运行。但这不是解决方案。
当边界较高时,belows
一直变为 1。但是如果是这种情况并且 x
对下面没有影响,那么它对梯度的贡献应该是 0 而不是 nan。但是如果我返回 belows[0][0]
而不是 jnp.mean(filt,axis=0)
,我仍然得到 nan
。
有什么想法吗? 提前致谢(github 上也有一个问题)
import os
from tqdm import tqdm
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=4' # Use 8 cpu devices
import numpy as np
from jax.config import config
config.update("jax_enable_x64",True)
import jax
import jax.numpy as jnp
from jax import vmap
from functools import reduce
def sigmoid(x,scale=100):
return 1 / (1 + jnp.exp(-x*scale))
def above_lower(x,l,scale=100):
return sigmoid(x - l,scale)
def below_upper(x,u,scale=100):
return 1 - sigmoid(x - u,scale)
def combine_soft_filters(a):
return jnp.prod(jnp.stack(a),axis=0)
def fraction_not_truncated(mu,v,limits,stdnorm_samples):
L = jnp.linalg.cholesky(v)
y = vmap(lambda x: jnp.dot(L,x))(stdnorm_samples) + mu
# filt = reduce(jnp.logical_and,[(y[...,i] > l) & (y[...,i] < u) for i,(l,u) in enumerate(limits)])
aboves = [above_lower(y[...,i],l) for i,u) in enumerate(limits)]
belows = [below_upper(y[...,u) for i,u) in enumerate(limits)]
filt = combine_soft_filters(aboves+belows)
return jnp.mean(filt,axis=0)
limits = np.array([
[0.,1000],])
stdnorm_samples = np.random.multivariate_normal([0],np.eye(1),size=1000)
def func(x):
return fraction_not_truncated(jnp.zeros(1)+x,jnp.eye(1),stdnorm_samples)
_x = np.linspace(-2,2,500)
gradfunc = jax.grad(func)
vals = [func(x) for x in tqdm(_x)]
grads = [gradfunc(x) for x in tqdm(_x)]
print(vals)
print(grads)
import matplotlib.pyplot as plt
plt.plot(_x,np.asarray(vals))
plt.ylabel('f(x)')
plt.twinx()
plt.plot(_x,np.asarray(grads),c='r')
plt.ylabel("f(x)'")
plt.title('Fraction not truncated')
plt.axhline(0,color='k',alpha=0.2)
plt.xlabel('shift')
plt.tight_layout()
plt.show()
[DeviceArray(1.,dtype=float64),DeviceArray(1.,dtype=float64)]
[DeviceArray(nan,DeviceArray(nan,dtype=float64)]
解决方法
问题在于您的 sigmoid
函数的实现方式使得自动确定的梯度对于 x
的大负值不稳定:
import jax.numpy as jnp
import jax
def sigmoid(x,scale=100):
return 1 / (1 + jnp.exp(-x*scale))
print(jax.grad(sigmoid)(-1000.0))
# nan
您可以使用 jax.make_jaxpr
函数内省自动确定的梯度产生的操作(注释是我的注释),了解为什么会发生这种情况:
>>> jax.make_jaxpr(jax.grad(sigmoid))(-1000.0)
{ lambda ; a. # a = -1000
let b = neg a # b = 1000
c = mul b 100.0 # c = 100,000
d = exp c # d = inf
e = add d 1.0
_ = div 1.0 e
f = integer_pow[ y=-2 ] e # f = 0
g = mul 1.0 f # g = 0
h = mul g 1.0 # h = 0
i = neg h # i = 0
j = mul i d # j = 0 * inf = NaN
k = mul j 100.0 # k = NaN
l = neg k # l = NaN
in (l,) } # return NaN
这是 64 位浮点运算失败的情况之一:它没有处理 exp(100000)
这样的数字的范围。
那你能做什么?一个重量级的选项是使用 custom derivative rule 来告诉 autodiff 如何以更稳定的方式处理 sigmoid
函数。但是,在这种情况下,一个更简单的选择是根据在 autodiff 转换下表现更好的东西重新表达 sigmoid
函数。一种选择是这样的:
def sigmoid(x,scale=100):
return 0.5 * (jnp.tanh(x * scale / 2) + 1)
在脚本中使用此版本可解决问题。