在 jax vmap 函数中调试数组

问题描述

亲爱的 jax 专家,我需要您的帮助。

这是一个工作示例(我已经按照建议来简化我的代码,虽然我不是 jax 和 Python 方面的专家,以猜测 vmap 所涉及的机制的核心是什么)

def jax_kernel(rng_key,logpdf,position,log_prob):
    key,subkey = jax.random.split(rng_key)
    move_proposals = jax.random.normal(key,shape=position.shape)* 0.1   
    proposal = position + move_proposals
    proposal_log_prob = logpdf(proposal)
    return proposal,proposal_log_prob

def jax_sampler(rng_key,n_samples,initial_position):
    
    def mh_update(i,state):
        key,positions,log_prob = state
        _,key = jax.random.split(key)        
        print(f"mh_update: positions[{i-1}]:",jnp.asarray(positions[i-1]))
        new_position,new_log_prob = jax_kernel(key,positions[i-1],log_prob)
            
        positions=positions.at[i].set(new_position)
        return (key,new_log_prob)
    
    # all positions structure should be set before lax.fori_loop
    print("initial_position shape:",initial_position.shape)       
    all_positions = jnp.zeros((n_samples,)+initial_position.shape)  
    all_positions=all_positions.at[0,0].set(1.)
    all_positions=all_positions.at[0,1].set(2.)
    all_positions=all_positions.at[0,2].set(2.)
    print("all_positions init:",all_positions.shape)
    logp = logpdf(all_positions[0])
    
    # use of a for-loop to be able to debug mh_update instead of a jax.fori_loop
    initial_state = (rng_key,all_positions,logp)
    val = initial_state
    for i in range(1,n_samples):
        val = mh_update(i,val)
    rng_key,log_prob = val
    # return all the positions of the parameters (n_chains,n_dim)
    return all_positions

def func(par):
    xi = jnp.asarray(sci_stats.uniform.rvs(size=10))
    val = xi*par[1]+par[0]
    return jnp.sum(jax.scipy.stats.norm.logpdf(x=val,loc=yi,scale=par[2]))
    

n_dim = 3          # number of parameters ie. (a,b,s)
n_samples = 5      # number of samples per chain
n_chains = 4       # number of MCMC chains
rng_key = jax.random.PRNGKey(42)
rng_keys = jax.random.split(rng_key,n_chains) 
initial_position = jnp.ones((n_dim,n_chains))                      
print("main initial_position shape",initial_position.shape)
run = jax.vmap(jax_sampler,in_axes=(0,None,1),out_axes=0) 
all_positions = run(rng_keys,lambda p: func(p),initial_position)
print("all_positions:",all_positions)

然后我的问题是关于维度演化 print(f"mh_update: positions[{i-1}]:",jnp.asarray(positions[i-1]))。我不明白为什么positions[i-1]从维度 n_dim 开始,然后切换到 n_chains x n_dim

提前感谢您的评论

这里是完整的输出

main initial_position shape (3,4)
initial_position shape: (3,)
all_positions init: (5,3)
mh_update: positions[0]: [1. 2. 2.]
mh_update: positions[1]: Traced<ShapedArray(float32[3])>with<BatchTrace(level=1/0)>
  with val = DeviceArray([[0.9354116,1.7876872,1.8443539 ],[0.9844745,2.073029,1.9511036 ],[0.98202926,2.0109322,2.094176  ],[0.9536771,1.9731759,2.093319  ]],dtype=float32)
       batch_dim = 0
mh_update: positions[2]: Traced<ShapedArray(float32[3])>with<BatchTrace(level=1/0)>
  with val = DeviceArray([[1.0606856,1.6707807,1.8377957],[1.0465866,1.9754674,1.7009288],[1.1107644,2.0142047,2.190575 ],[1.0089972,1.9953227,1.996874 ]],dtype=float32)
       batch_dim = 0
mh_update: positions[3]: Traced<ShapedArray(float32[3])>with<BatchTrace(level=1/0)>
  with val = DeviceArray([[1.0731456,1.644405,2.1343162],[1.0599504,2.0121546,1.6867112],[1.0585173,1.9661485,2.1573594],[1.1213307,1.9335203,1.9683584]],dtype=float32)
       batch_dim = 0
all_positions: [[[1.         2.         2.        ]
  [0.9354116  1.7876872  1.8443539 ]
  [1.0606856  1.6707807  1.8377957 ]
  [1.0731456  1.644405   2.1343162 ]
  [1.0921828  1.5742197  2.058759  ]]

 [[1.         2.         2.        ]
  [0.9844745  2.073029   1.9511036 ]
  [1.0465866  1.9754674  1.7009288 ]
  [1.0599504  2.0121546  1.6867112 ]
  [1.0835105  2.0051234  1.4766487 ]]

 [[1.         2.         2.        ]
  [0.98202926 2.0109322  2.094176  ]
  [1.1107644  2.0142047  2.190575  ]
  [1.0585173  1.9661485  2.1573594 ]
  [1.1728328  1.981367   2.180744  ]]

 [[1.         2.         2.        ]
  [0.9536771  1.9731759  2.093319  ]
  [1.0089972  1.9953227  1.996874  ]
  [1.1213307  1.9335203  1.9683584 ]
  [1.1148386  1.9598911  2.1721165 ]]]

解决方法

在第一次迭代中,您打印在 vmapped 函数中构建的具体数组。它是一个形状为 float32(3,) 数组。

在第一次迭代之后,您已经通过对 vmap 输入的操作构造了一个新数组。当您 vmap 这样的输入时,JAX 将您的输入数组替换为 tracer,它是您输入的抽象表示;打印值如下所示:

Traced<ShapedArray(float32[3])>with<BatchTrace(level=1/0)>
  with val = DeviceArray([[1.0731456,1.644405,2.1343162],[1.0599504,2.0121546,1.6867112],[1.0585173,1.9661485,2.1573594],[1.1213307,1.9335203,1.9683584]],dtype=float32)

float32[3] 表示此跟踪器表示形状为 (3,) 的 float32 值数组:也就是说,它仍然具有与第一次迭代中相同的类型和形状。但在这种情况下,它不是一个包含三个元素的具体数组,它是一个批处理跟踪器,表示 vmapped 输入的每次迭代。 vmap 转换的强大之处在于,JAX 在一次传递中有效地跟踪了 vmapped 计算的所有 隐含迭代:在跟踪器表示中,val 的行有效地向您展示了所有 vmapped 迭代的中间值。

要进一步了解 JAX 跟踪的工作原理,请阅读 JAX 文档中的 How To Think In JAX

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...