【发布时间】:2021-11-03 07:45:08
【问题描述】:
这是一个基本示例。
@jax.jit
def block(arg1, arg2):
for x1 in range(cons1):
for x2 in range(cons2):
for x3 in range(cons3):
--do something--
return result
当 cons 较小时,编译时间约为一分钟。使用更大的缺点,编译时间要长得多——10 分钟。我需要更高的缺点。可以做什么? 从我正在阅读的内容来看,循环是原因。它们在编译时展开。 有什么解决方法吗?还有 jax.fori_loop。但我不明白如何使用它。有 jax.experimental.loops 模块,但我还是无法理解它。
我对这一切都很陌生。因此,感谢所有帮助。 如果您能提供一些如何使用 jax 循环的示例,将不胜感激。
另外,什么是好的编译时间?几分钟内就可以了吗? 在其中一个示例中,编译时间为 262 秒,剩余运行时间约为 0.1-0.2 秒。
运行时的任何收益都会被编译时间所掩盖。
【问题讨论】:
标签: python performance loops jit jax