如何在Jax fori_loop机制中获得中间结果

问题描述

我是 Jax 的新手,也不是 Python 专家。

我在 Mac 笔记本上使用的 jax 版本是 0.2.14。下面是一段简单的代码,至少在我这里能得到一些结果。

但是,正如 jax_metropolis_sampler 方法中的注释所述,我想保存中间结果(即每一步的“位置” position),却不知道如何用 jax.lax.fori_loop 正确实现;我猜我现在的写法一定很糟糕。

我相信一定有人能给出更好的方案,充分利用 JAX 的并行能力。另外,我目前还没有研究 MixtureModel_jax 的前向/反向模式自动微分。

提前致谢

import jax
import jax.numpy as jnp
from functools import partial

class MixtureModel_jax():
    """Weighted mixture of 1-D Gaussian components evaluated with JAX.

    Component parameters are stored as column vectors of shape
    ``(num_distr, 1)``, so evaluating at an array ``x`` broadcasts to one
    row per component.
    """

    def __init__(self, locs, scales, weights, *args, **kwargs):
        """Build the mixture.

        Args:
            locs: component means, one per component.
            scales: component standard deviations, one per component.
            weights: component weights; normalized here to sum to 1.
            *args, **kwargs: forwarded to ``super().__init__``.
        """
        super().__init__(*args, **kwargs)
        # jnp.array([v]).T turns a length-k sequence into a (k, 1) column.
        self.loc = jnp.array([locs]).T
        self.scale = jnp.array([scales]).T
        weights = jnp.array([weights]).T
        self.weights = weights / jnp.sum(weights)

        self.num_distr = len(locs)

    def pdf(self, x):
        """Return the mixture density at ``x`` (weighted sum over components)."""
        probs = jax.scipy.stats.norm.pdf(x, loc=self.loc, scale=self.scale)
        return jnp.dot(self.weights.T, probs).squeeze()

    def logpdf(self, x):
        """Return the log mixture density at ``x``.

        Combines the per-component log densities with log-sum-exp over the
        component axis (axis 0), which is numerically stabler than
        ``log(pdf(x))``.
        """
        # Each component must be centred at its own mean, exactly as in pdf().
        log_probs = jax.scipy.stats.norm.logpdf(x, loc=self.loc, scale=self.scale)
        return jax.scipy.special.logsumexp(jnp.log(self.weights) + log_probs, axis=0)

@partial(jax.jit, static_argnums=(1,))
def jax_metropolis_kernel(rng_key, logpdf, position, log_prob):
    """Move the chain by one step using the Random Walk Metropolis algorithm.

    Args:
        rng_key: JAX PRNG key consumed by this step.
        logpdf: callable returning the log density of a position. It is a
            static argument of ``jit``, so it must be hashable and passing a
            different callable triggers a retrace.
        position: current position of the chain.
        log_prob: ``logpdf(position)``, passed in so it is not recomputed.

    Returns:
        ``(position, log_prob)`` after the step: the proposal and its log
        density if it was accepted, otherwise the inputs unchanged.
    """
    proposal_key, accept_key = jax.random.split(rng_key)

    # Gaussian random-walk proposal with a fixed step size of 0.1.
    move_proposals = jax.random.normal(proposal_key, shape=position.shape) * 0.1
    proposal = position + move_proposals
    proposal_log_prob = logpdf(proposal)

    # Accept with probability min(1, p(proposal) / p(position)), compared in
    # log space: P(log U < d) = min(1, exp(d)) for U ~ Uniform[0, 1).
    log_uniform = jnp.log(jax.random.uniform(accept_key))
    do_accept = log_uniform < proposal_log_prob - log_prob

    # jnp.where keeps the step branch-free so it traces under jit/vmap.
    position = jnp.where(do_accept, proposal, position)
    log_prob = jnp.where(do_accept, proposal_log_prob, log_prob)
    return position, log_prob

@partial(jax.jit, static_argnums=(1, 2))
def jax_metropolis_sampler(rng_key, n_samples, logpdf, initial_position):
    """Generate samples using the Random Walk Metropolis algorithm.

    Every intermediate position is kept. ``jax.lax.scan`` compiles a single
    step and stacks the per-step outputs, instead of unrolling a Python loop
    ``n_samples`` times under ``jit``.

    Args:
        rng_key: JAX PRNG key for this chain.
        n_samples: number of Metropolis steps. Static under ``jit`` because
            it fixes the length of the output.
        logpdf: callable returning the log density of a position. Static
            under ``jit`` (must be hashable).
        initial_position: starting position of the chain.

    Returns:
        Array of shape ``(n_samples,) + initial_position.shape`` whose row i
        is the position after step i + 1. Indexing and ``len`` behave like a
        list of per-step positions.
    """

    def mh_update(state, _):
        key, position, log_prob = state
        # Carry one key forward and give the kernel a fresh subkey, so no
        # key is ever consumed twice.
        key, step_key = jax.random.split(key)
        new_position, new_log_prob = jax_metropolis_kernel(
            step_key, logpdf, position, log_prob
        )
        # (new carry, per-step output collected by scan)
        return (key, new_position, new_log_prob), new_position

    initial_state = (rng_key, initial_position, logpdf(initial_position))
    _, positions = jax.lax.scan(mh_update, initial_state, None, length=n_samples)
    return positions

mixture_gaussian_model = MixtureModel_jax([0, 1.5], [0.5, 0.1], [8, 2])


n_dim = 1
n_samples = 50
n_chains = 7
rng_key = jax.random.PRNGKey(42)

# One independent PRNG key per chain.
rng_keys = jax.random.split(rng_key, n_chains)
# Column j is the starting position of chain j (hence in_axes=1 below).
initial_position = jnp.zeros((n_dim, n_chains))

# Map over the per-chain keys (axis 0) and starting columns (axis 1);
# n_samples and the log density are shared by every chain.
run_mcmc = jax.vmap(jax_metropolis_sampler, in_axes=(0, None, None, 1), out_axes=0)
positions = run_mcmc(rng_keys, n_samples, mixture_gaussian_model.logpdf, initial_position)

print(len(positions))
print(positions[0].shape)

解决方法

最好的做法是在 fori_loop 的循环状态中携带之前所有的位置。需要注意的是,fori_loop 要求循环状态在每次迭代中形状保持不变,因此应预先分配一个固定大小的数组并按索引写入,而不是用 vstack 让数组不断增长。示例如下:

def mh_update(i,state):
    key,positions,log_prob = state
    _,key = jax.random.split(key)
    new_position,new_log_prob = jax_metropolis_kernel(key,logpdf,positions[-1],log_prob)
    positions = jnp.vstack([positions,new_position])
    return (key,new_log_prob)

logp = logpdf(initial_position)
initial_state = (rng_key,initial_position[jnp.newaxis],logp)
rng_key,log_prob = jax.lax.fori_loop(0,n_samples,mh_update,initial_state)
return positions
,

下面是我在 @jakevdp 的提示下最终得到的解决方案:

@partial(jax.jit, static_argnums=(1, 2))
def jax_metropolis_sampler(rng_key, n_samples, logpdf, initial_position):
    """Run Random Walk Metropolis, keeping every position via ``fori_loop``.

    Args:
        rng_key: JAX PRNG key for this chain.
        n_samples: number of stored positions, including the starting one.
            Static under ``jit`` because it fixes the buffer size.
        logpdf: callable returning the log density of a position. Static
            under ``jit`` (must be hashable).
        initial_position: starting position of the chain.

    Returns:
        Array of shape ``(n_samples,) + initial_position.shape``. Row 0 is
        ``initial_position`` and row i is the position after step i.
    """

    def mh_update_sol2(i, state):
        key, positions, log_prob = state
        # Give the kernel a fresh subkey instead of reusing the carried key.
        key, step_key = jax.random.split(key)
        new_position, new_log_prob = jax_metropolis_kernel(
            step_key, logpdf, positions[i - 1], log_prob
        )
        positions = positions.at[i].set(new_position)
        return (key, positions, new_log_prob)

    logp = logpdf(initial_position)
    # fori_loop needs a fixed-shape carry, so preallocate the whole buffer.
    # Row 0 must hold the starting point because step 1 reads positions[0]
    # and logp was computed there.
    all_positions = jnp.zeros((n_samples,) + initial_position.shape)
    all_positions = all_positions.at[0].set(initial_position)
    initial_state = (rng_key, all_positions, logp)
    _, all_positions, _ = jax.lax.fori_loop(1, n_samples, mh_update_sol2, initial_state)

    return all_positions

n_dim = 1
n_samples = 100_000
n_chains = 100
rng_key = jax.random.PRNGKey(42)

# One independent PRNG key per chain.
rng_keys = jax.random.split(rng_key, n_chains)
# Column j is the starting position of chain j (hence in_axes=1 below).
initial_position = jnp.zeros((n_dim, n_chains))

# Map over the per-chain keys (axis 0) and starting columns (axis 1);
# n_samples and the log density are shared by every chain.
run_mcmc = jax.vmap(jax_metropolis_sampler, in_axes=(0, None, None, 1), out_axes=0)
all_positions = run_mcmc(rng_keys, n_samples, mixture_gaussian_model.logpdf, initial_position)

# Remove size-1 axes (n_dim is 1 here).
all_positions = all_positions.squeeze()

 

然后就可以绘制这 100 条链了:

# NOTE(review): `plt` is assumed to be matplotlib.pyplot, imported elsewhere
# (not shown in this snippet).
x_axis = jnp.arange(-3, 3, 0.001)
# One step-histogram per chain, overlaid on the true mixture density.
for i, chain in enumerate(all_positions):
    plt.hist(chain, bins=50, density=True, histtype='step', label=f"chain [{i}]")
plt.plot(x_axis, mixture_gaussian_model.pdf(x_axis), 'r-', lw=5, alpha=0.6, label='true pdf')
plt.legend()
plt.show()

(原文此处为一张结果图片,抓取时未能保留。)

感谢您的帮助。

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...