【发布时间】:2020-08-31 13:54:45
【问题描述】:
我想使用 JAX 在 CPU 上加速我的 numpy 代码,然后在 GPU 上。这是我在本地计算机上运行的示例代码(仅 CPU):
import jax.numpy as jnp
from jax import random, jix
import numpy as np
import time
size = 3000
key = random.PRNGKey(0)
x = random.normal(key, (size,size), dtype=jnp.float64)
start=time.time()
test = jnp.dot(x, x.T).block_until_ready()
print('Time of jnp: {}s'.format(time.time() - start))
x2=np.random.normal((size,size))
start=time.time()
test2 = np.dot(x2, x2.T)
print('Time of np: {}s'.format(time.time() - start))
我收到警告,时间成本如下:
/.../lib/python3.7/site-packages/jax/lib/xla_bridge.py:130:
UserWarning: No GPU/TPU found, falling back to CPU.
warnings.warn('No GPU/TPU found, falling back to CPU.')
Time: 0.45157814025878906s
Time: 0.005244255065917969s
我在这里做错了吗? JAX 是否也应该在 CPU 上加速 numpy 代码?
【问题讨论】:
-
机会很高,numpy 正在使用 (Open-)BLAS 并且对于
np.dot()没有太多可优化的地方。 -
@sascha 但是 JAX 比 NumPy 慢很多是没有意义的。我还没弄清楚原因。
-
这并不让我感到惊讶。点(向量或 matmul;无关紧要)对于每种 cpu-arch 都是完全手动编码的,您不会使用自动编译器击败它。如果没有太多关于 JAX 的知识,它可能是关于调度、优化临时变量和其他一些东西 -> 当多个“内核”融合时会产生出色的代码。但是 dot 是如此的基本,以至于没有机会达到它。补充说明:我提到了openblas:但在某些dist中,使用了Intels MKL:您是否希望某些自动编译器在您的(也许)intel cpu上击败手动编码的matmul代码(由intel-devs)?
-
另请阅读:this
-
也许 numpy 也快得多,因为
x的形状是(3000, 3000),x2的形状是(2,):)