【问题标题】:Efficient use of JAX for conditional function evaluation based on an array of integersEfficient use of JAX for conditional function evaluation based on an array of integers
【发布时间】:2023-02-12 09:33:42
【问题描述】:

我想根据一个整数数组和其他以实数作为这些函数输入的数组来有效地执行条件函数评估。我希望找到一种基于 JAX 的解决方案,该解决方案能够比我在下面描述的 for 循环方法提供显着的性能改进:

import jax
from jax import vmap;
import jax.numpy as jnp
import jax.random as random

def g_0(x, y, z, u):
    return x + y + z + u

def g_1(x, y, z, u):
    return x * y * z * u

def g_2(x, y, z, u):
    return x - y + z - u

def g_3(x, y, z, u):
    return x / y / z / u

g_i = [g_0, g_1, g_2, g_3]
g_i_jit = [jax.jit(func) for func in g_i]

def g_git(i, x, y, z, u):
    return g_i_jit[i](x=x, y=y, z=z, u=u)

def g(i, x, y, z, u):
    return g_i[i](x=x, y=y, z=z, u=u)


len_xyz = 3000
x_ar = random.uniform(random.PRNGKey(0), shape=(len_xyz,))
y_ar = random.uniform(random.PRNGKey(1), shape=(len_xyz,))
z_ar = random.uniform(random.PRNGKey(2), shape=(len_xyz,))

len_u = 1000
u_0 = random.uniform(random.PRNGKey(3), shape=(len_u,))
u_1 = jnp.repeat(u_0, len_xyz)
u_ar = u_1.reshape(len_u, len_xyz)


len_i = 50
i_ar = random.randint(random.PRNGKey(5), shape=(len_i,), minval=0, maxval= len(g_i)) #related to g_range-1


total = jnp.zeros((len_u, len_xyz))

for i in range(len_i):
    total= total + g_git(i_ar[i], x_ar, y_ar, z_ar, u_ar)

“i_ar”的作用是充当索引,从列表g_i中选择四个函数之一。 “i_ar”是一个整数数组,每个整数代表 g_i 列表中的一个索引。另一方面,x_ar、y_ar、z_ar 和 u_ar 是实数数组,它们是 i_ar 选择的函数的输入。

我怀疑 i_ar 和 x_ar、y_ar、z_ar 和 u_ar 之间的这种本质差异可能很难找到一种 JAX 方式来更有效地替换上面的 for 循环。关于如何使用 JAX(或其他东西)替换 for 循环以更有效地获取“总计”的任何想法?

我天真地尝试过使用 vmap:

g_git_vmap = jax.vmap(g_git)
total = jnp.zeros((len_u, len_xyz))
total = jnp.sum(g_git_vmap(i_ar, x_ar, y_ar, z_ar, u_ar), axis=0)

但这导致错误消息并导致无处可去。

【问题讨论】:

    标签: numpy multidimensional-array conditional-statements vectorization jax


    【解决方案1】:

    可能最好的方法是使用lax.switch,它允许基于索引数组在多个函数之间动态切换。

    这是您的原始函数与基于 lax.switch 的方法的比较,以及 Colab GPU 运行时的计时:

    def f_original(i, x, y, z, u):
      total = jnp.zeros((len(u), len(x)))
      for i in range(len(i)):
        total= total + g_git(i_ar[i], x, y, z, u)
      return total
    
    @jax.jit
    def f_switch(i, x, y, z, u):
      g = lambda i: jax.lax.switch(i, g_i, x, y, z, u)
      return jax.vmap(g)(i).sum(0)
    
    out1 = f_original(i_ar, x_ar, y_ar, z_ar, u_ar)
    out2 = f_switch(i_ar, x_ar, y_ar, z_ar, u_ar)
    np.testing.assert_allclose(out1, out2, rtol=5E-3)
    
    %timeit f_original(i_ar, x_ar, y_ar, z_ar, u_ar).block_until_ready()
    # 71 ms ± 23.2 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
    
    %timeit f_switch(i_ar, x_ar, y_ar, z_ar, u_ar).block_until_ready()
    # 4.69 ms ± 37.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
    

    【讨论】:

      猜你喜欢
      • 2022-12-01
      • 2022-12-02
      • 2022-12-01
      • 2022-12-01
      • 2022-12-19
      • 2022-11-20
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多