【问题标题】:How to reduce JAX compile time when using for loop?使用 for 循环时如何减少 JAX 编译时间?
【发布时间】:2021-11-03 07:45:08
【问题描述】:

这是一个基本示例。

@jax.jit
def block(arg1, arg2):
   for x1 in range(cons1):
       for x2 in range(cons2):
          for x3 in range(cons3):
             --do something--
   return result

当 cons 较小时,编译时间约为一分钟。使用更大的缺点,编译时间要长得多——10 分钟。我需要更高的缺点。可以做什么? 从我正在阅读的内容来看,循环是原因。它们在编译时展开。 有什么解决方法吗?还有 jax.fori_loop。但我不明白如何使用它。有 jax.experimental.loops 模块,但我还是无法理解它。

我对这一切都很陌生。因此,感谢所有帮助。 如果您能提供一些如何使用 jax 循环的示例,将不胜感激。

另外,什么是好的编译时间?几分钟内就可以了吗? 在其中一个示例中,编译时间为 262 秒,剩余运行时间约为 0.1-0.2 秒。

运行时的任何收益都会被编译时间所掩盖。

【问题讨论】:

    标签: python performance loops jit jax


    【解决方案1】:

    我不确定这是否与numba 相同,但这可能是类似的情况。

    当我使用numba.jit编译器并且有大数据输入时,我首先在一些小的示例数据上编译函数,然后使用它。

    伪代码:

    func_being_compiled(small_amount_of_data)  # compile-only purpose
    func_being_compiled(large_amount_of_data)
    
    

    【讨论】:

      【解决方案2】:

      JAX 的 JIT 编译器会扁平化所有 Python 循环。要了解我的意思,请看一下通过jax.make_jaxpr 运行的这个简单函数,它是一种检查 JAX 的跟踪器如何解释 python 代码的方法(请参阅Understanding Jaxprs 了解更多信息):

      import jax
      
      def f(x):
        for i in range(5):
          x += i
        return x
      
      print(jax.make_jaxpr(f)(0))
      # { lambda  ; a.
      #   let b = add a 0
      #       c = add b 1
      #       d = add c 2
      #       e = add d 3
      #       f = add e 4
      #   in (f,) }
      

      请注意,循环是扁平的:每一步都成为发送到 XLA 编译器的显式操作。 XLA 编译时间会随着函数中操作数量的增加而增加,因此三重嵌套的 for 循环会导致编译时间变长是有道理的。

      那么,如何解决这个问题?好吧,不幸的是,答案取决于你的 --do something-- 正在做什么,所以我猜不出来。

      一般来说,最好的选择是使用向量化数组操作,而不是循环这些向量中的值;例如,这是一种添加两个向量的非常慢的方法:

      import jax.numpy as jnp
      
      def f_slow(x, y):
        z = []
        for xi, yi in zip(xi, yi):
          z.append(xi + yi)
        return jnp.array(z)
      

      这是一种更快的方法来做同样的事情:

      def f_fast(x, y):
        return x + y
      

      如果您的操作不适合矢量化,另一种选择是使用 lax control flow 运算符代替 for 循环:这会将循环向下推入 XLA。这在 CPU 上可以有相当好的性能,但与等效的向量化数组操作相比,在加速器上的速度较慢。

      有关 JAX 和 Python 控制流语句(如forifwhile 等)的更多讨论,请参阅? JAX - The Sharp Bits ?: Control Flow

      【讨论】:

      • 对于无法向量化的操作,jax.lax.fori_loop 与 python for 循环相比显着减少了编译时间。而且,确实,它不需要减少计算时间。
      猜你喜欢
      • 2016-07-04
      • 2013-10-21
      • 1970-01-01
      • 2018-03-27
      • 2022-08-17
      • 1970-01-01
      • 2019-10-07
      • 2014-06-06
      • 2011-01-16
      相关资源
      最近更新 更多