【问题标题】:How to get intermediate results in Jax fori_loop mechanism如何在 Jax fori_loop 机制中获取中间结果
【发布时间】:2021-07-08 14:08:57
【问题描述】:

我是 Jax 的新手,也不是 Python 专家。

我在我的 mac 笔记本电脑上运行 jax 版本“0.2.14”。请在下面找到一个简单的代码,至少对我来说给出了一些结果。

但是,正如 jax_metropolis_sampler 方法中的注释所述,我想保存中间结果“位置”(positions),但我不知道如何用 jax.lax.fori_loop 正确地做到这一点;我猜我现在的做法肯定很糟糕。

我很确定有人可以给我一个能更好地利用 jax 并行性的解决方案。我暂时还没有研究 MixtureModel_jax 的前向/反向微分。

提前致谢

import jax
import jax.numpy as jnp
from functools import partial

class MixtureModel_jax():
    """One-dimensional mixture of Gaussians evaluated with JAX.

    Component parameters are stored as column vectors (one row per
    component) so that evaluating at an array of points broadcasts over
    the components along axis 0.
    """

    def __init__(self, locs, scales, weights, *args, **kwargs):
        """Store the component parameters and normalise the weights.

        Args:
            locs: sequence of component means.
            scales: sequence of component standard deviations.
            weights: sequence of (possibly unnormalised) component weights.
        """
        super().__init__(*args, **kwargs)
        self.loc = jnp.array([locs]).T
        self.scale = jnp.array([scales]).T
        self.weights = jnp.array([weights]).T
        # Rescale so that the mixture weights sum to one.
        norm = jnp.sum(self.weights)
        self.weights = self.weights / norm

        self.num_distr = len(locs)

    def pdf(self, x):
        """Return the mixture density at ``x`` (weighted sum of components)."""
        probs = jax.scipy.stats.norm.pdf(x, loc=self.loc, scale=self.scale)
        return jnp.dot(self.weights.T, probs).squeeze()

    def logpdf(self, x):
        """Return the log of the mixture density at ``x``.

        Computed as a log-sum-exp over components of
        ``log(weight) + component log-density``.
        """
        log_probs = jax.scipy.stats.norm.logpdf(x, loc=self.loc, scale=self.scale)
        # jnp (not np): only jax.numpy is imported, and the value must stay
        # traceable when this method is called from jitted code.
        return jax.scipy.special.logsumexp(jnp.log(self.weights) + log_probs, axis=0)

@partial(jax.jit, static_argnums=(1,))
def jax_metropolis_kernel(rng_key, logpdf, position, log_prob):
    """Moves the chain by one step using the Random Walk Metropolis algorithm.

    Args:
        rng_key: PRNG key; it is split once, one half drives the proposal
            and the other the accept/reject draw.
        logpdf: callable returning the log-density at a position. It is a
            static argument of the jitted function.
        position: current position of the chain.
        log_prob: log-density at ``position``.

    Returns:
        A ``(position, log_prob)`` pair: the proposal and its log-density
        if the move was accepted, otherwise the inputs unchanged.
    """
    key, subkey = jax.random.split(rng_key)

    # Gaussian random-walk proposal with a fixed step scale of 0.1.
    move_proposals = jax.random.normal(key, shape=position.shape) * 0.1
    proposal = position + move_proposals
    proposal_log_prob = logpdf(proposal)

    # Accept when log(u) < log p(proposal) - log p(current), u ~ U[0, 1).
    log_uniform = jnp.log(jax.random.uniform(subkey))
    do_accept = log_uniform < proposal_log_prob - log_prob

    position = jnp.where(do_accept, proposal, position)
    log_prob = jnp.where(do_accept, proposal_log_prob, log_prob)
    return position, log_prob

@partial(jax.jit, static_argnums=(1, 2))
def jax_metropolis_sampler(rng_key, n_samples, logpdf, initial_position):
    """Run a Random Walk Metropolis chain and collect every position.

    Args:
        rng_key: PRNG key for the chain.
        n_samples: number of updates to perform (static argument).
        logpdf: callable returning the log-density (static argument).
        initial_position: starting position of the chain.

    Returns:
        A Python list with ``n_samples`` entries, the chain position after
        each successive update.
    """

    def advance(chain_state):
        # Derive a fresh key from the carried one, then take one kernel step.
        prev_key, current, current_logp = chain_state
        next_key = jax.random.split(prev_key)[1]
        moved, moved_logp = jax_metropolis_kernel(next_key, logpdf, current, current_logp)
        return next_key, moved, moved_logp

    chain_state = (rng_key, initial_position, logpdf(initial_position))

    # Alternative that only yields the final position: run the same update
    # through jax.lax.fori_loop(0, n_samples, ...) starting from the state
    # above and return its position entry.
    #
    # Proposal to save intermediate positions: a plain Python loop that
    # appends each position to a list (slow and horrible I guess!).
    samples = []
    for _ in range(n_samples):
        chain_state = advance(chain_state)
        samples.append(chain_state[1])

    return samples

# Two-component mixture: N(0, 0.5) with weight 8 and N(1.5, 0.1) with weight 2
# (weights are normalised by the model).
mixture_gaussian_model = MixtureModel_jax([0, 1.5], [0.5, 0.1], [8, 2])


n_dim = 1
n_samples = 50
n_chains = 7
rng_key = jax.random.PRNGKey(42)

# One PRNG key per chain.
rng_keys = jax.random.split(rng_key, n_chains)
initial_position = jnp.zeros((n_dim, n_chains))

# Vectorise the sampler over chains: keys are mapped along axis 0 and the
# initial positions along axis 1; n_samples and logpdf are shared (None).
run_mcmc = jax.vmap(jax_metropolis_sampler,
                    in_axes=(0, None, None, 1),
                    out_axes=0)
positions = run_mcmc(rng_keys, n_samples,
                     lambda x: mixture_gaussian_model.logpdf(x),
                     initial_position)

print(len(positions))
print(positions[0].shape)

【问题讨论】:

    标签: python jax


    【解决方案1】:

    最好的方法是在fori_loop 函数中携带先前位置的列表。像这样的:

    # Fragment meant to replace the body of jax_metropolis_sampler: the list of
    # previous positions is carried through the fori_loop state.
    # NOTE: the carried `positions` array gains one row per iteration; the
    # discussion under this answer reports that fori_loop rejects this with
    # "TypeError: scan carry output and input must have the same types".
    def mh_update(i, state):
        """Take one Metropolis step from the last stored position and append it."""
        key, positions, log_prob = state
        _, key = jax.random.split(key)
        new_position, new_log_prob = jax_metropolis_kernel(key, logpdf, positions[-1], log_prob)
        # Stack the new position under the previous ones (changes the shape).
        positions = jnp.vstack([positions, new_position])
        return (key, positions, new_log_prob)
    
    logp = logpdf(initial_position)
    # Add a leading axis so the initial position is the first stored row.
    initial_state = (rng_key, initial_position[jnp.newaxis], logp)
    rng_key, positions, log_prob = jax.lax.fori_loop(0, n_samples, 
                                                     mh_update, 
                                                     initial_state)
    return positions
    

    【讨论】:

    • 好吧 @jakevdp,我试过了,报错如下:``` TypeError: scan carry output and input must have the same types, got (ShapedArray(int32[], weak_type=True), (ShapedArray(uint32[2]), ShapedArray(float32[2,1]), ShapedArray(float32[1]))) and (ShapedArray(int32[], weak_type=True), (ShapedArray(uint32[2]), ShapedArray(float32[1,1]), ShapedArray(float32[1]))). ``` 所以我猜,正如文档中所述,lax.fori_loop 不允许 “positions” 的大小在循环中逐渐增长。你怎么看?
    • 啊,有道理。要解决它,您可以在开始时初始化完整的位置数组并在每次迭代时填充一行
    【解决方案2】:

    这是我在@jakevdp 提示后设法得到的解决方案

    @partial(jax.jit, static_argnums=(1, 2))
    def jax_metropolis_sampler(rng_key, n_samples, logpdf, initial_position):
        """Run a Random Walk Metropolis chain and return every position.

        A buffer with one row per sample is allocated up front and filled
        inside ``jax.lax.fori_loop``, so the loop state keeps a fixed shape.

        Args:
            rng_key: PRNG key for the chain.
            n_samples: number of stored positions, including the starting
                one (static argument).
            logpdf: callable returning the log-density (static argument).
            initial_position: starting position of the chain.

        Returns:
            Array of shape ``(n_samples,) + initial_position.shape`` whose
            row 0 is ``initial_position`` and whose row ``i`` is the
            position after ``i`` updates.
        """

        def mh_update_sol2(i, state):
            # Step from the previously stored row and write the result to row i.
            key, positions, log_prob = state
            _, key = jax.random.split(key)
            new_position, new_log_prob = jax_metropolis_kernel(
                key, logpdf, positions[i - 1], log_prob)
            positions = positions.at[i].set(new_position)
            return (key, positions, new_log_prob)

        logp = logpdf(initial_position)
        all_positions = jnp.zeros((n_samples,) + initial_position.shape)
        if n_samples > 0:
            # Row 0 must hold the starting point: the first update reads
            # positions[0], and logp above was evaluated at initial_position.
            all_positions = all_positions.at[0].set(initial_position)
        initial_state = (rng_key, all_positions, logp)
        _, all_positions, _ = jax.lax.fori_loop(1, n_samples,
                                                mh_update_sol2,
                                                initial_state)

        return all_positions
    
    # Sampling configuration for the vectorised run.
    n_dim = 1
    n_samples = 100_000
    n_chains = 100
    rng_key = jax.random.PRNGKey(42)

    # One PRNG key per chain; every chain starts at zero.
    rng_keys = jax.random.split(rng_key, n_chains)
    initial_position = jnp.zeros((n_dim, n_chains))

    # Map the sampler over chains: keys along axis 0, initial positions along
    # axis 1; n_samples and the log-density are shared.
    run_mcmc = jax.vmap(
        jax_metropolis_sampler,
        in_axes=(0, None, None, 1),
        out_axes=0,
    )
    all_positions = run_mcmc(
        rng_keys,
        n_samples,
        lambda x: mixture_gaussian_model.logpdf(x),
        initial_position,
    )

    # Drop the length-one axes of the result.
    all_positions = jnp.squeeze(all_positions)
    
     
    

    那么,在你可以画出 100 条链之后……

    # Overlay one step histogram per chain, then the model density on top.
    x_axis = jnp.arange(-3, 3, 0.001)
    for i, chain in enumerate(all_positions):
        plt.hist(chain, bins=50, density=True, histtype='step', label=f"chain [{i}]")
    true_density = mixture_gaussian_model.pdf(x_axis)
    plt.plot(x_axis, true_density, 'r-', lw=5, alpha=0.6, label='true pdf')
    plt.legend()
    plt.show()
    

    感谢您的帮助。

    【讨论】:

      猜你喜欢
      • 2021-08-07
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2014-11-03
      • 2019-12-21
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多