【问题标题】:Why does jax.numpy.dot() run slower than numpy.dot() on CPU?为什么 jax.numpy.dot() 在 CPU 上运行比 numpy.dot() 慢?
【发布时间】:2020-08-31 13:54:45
【问题描述】:

我想使用 JAX 在 CPU 上加速我的 numpy 代码,然后在 GPU 上。这是我在本地计算机上运行的示例代码(仅 CPU):

import jax.numpy as jnp
from jax import random, jix
import numpy as np
import time

size = 3000

key = random.PRNGKey(0)
x =  random.normal(key, (size,size), dtype=jnp.float64)

start=time.time()
test = jnp.dot(x, x.T).block_until_ready()
print('Time of jnp: {}s'.format(time.time() - start))

x2=np.random.normal((size,size))

start=time.time()
test2 = np.dot(x2, x2.T)
print('Time of np: {}s'.format(time.time() - start))

我收到警告,时间成本如下:

/.../lib/python3.7/site-packages/jax/lib/xla_bridge.py:130: 
UserWarning: No GPU/TPU found, falling back to CPU.
warnings.warn('No GPU/TPU found, falling back to CPU.')
Time: 0.45157814025878906s
Time: 0.005244255065917969s

我在这里做错了吗? JAX 是否也应该在 CPU 上加速 numpy 代码?

【问题讨论】:

  • 机会很高,numpy 正在使用 (Open-)BLAS 并且对于 np.dot() 没有太多可优化的地方。
  • @sascha 但是 JAX 比 NumPy 慢很多是没有意义的。我还没弄清楚原因。
  • 这并不让我感到惊讶。点(向量或 matmul;无关紧要)对于每种 cpu-arch 都是完全手动编码的,您不会使用自动编译器击败它。如果没有太多关于 JAX 的知识,它可能是关于调度、优化临时变量和其他一些东西 -> 当多个“内核”融合时会产生出色的代码。但是 dot 是如此的基本,以至于没有机会达到它。补充说明:我提到了openblas:但在某些dist中,使用了Intels MKL:您是否希望某些自动编译器在您的(也许)intel cpu上击败手动编码的matmul代码(由intel-devs)?
  • 另请阅读:this
  • 也许 numpy 也快得多,因为x 的形状是(3000, 3000)x2 的形状是(2,) :)

标签: python numpy


【解决方案1】:

Jax 和 Numpy 之间可能存在性能差异,但在原始帖子中,时间差异主要归结为数组创建中的错误。 Jax 使用的数组的形状为 3000x3000,而 Numpy 使用的数组是长度为 2 的一维数组。numpy.random.normal 的第一个参数是loc(即从中采样的高斯平均值)。应使用关键字参数size= 来指示数组的形状。

numpy.random.normal(loc=0.0, scale=1.0, size=None)

进行此更改后,Jax 和 Numpy 之间的性能差异将减少。

import time
import jax
import jax.numpy as jnp
import numpy as np

size = 3000

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (size, size), dtype=jnp.float64)

start = time.time()
test = jnp.dot(x, x.T).block_until_ready()
print("Time of jnp: {:0.4f} s".format(time.time() - start))

x2 = np.random.normal(size=(size, size)).astype(np.float64)

start = time.time()
test2 = np.dot(x2, x2.T)
print("Time of np: {:0.4f} s".format(time.time() - start))

一次运行的输出是

Time of jnp: 2.3315 s
Time of np: 2.8811 s

在测量定时性能时,应该收集多次运行,因为函数的性能是时间的分布而不是单个值。这可以通过 Python 标准库 timeit.timeit 函数或 IPython 和 Jupyter Notebook 中的 %timeit 魔法来完成。

import time
import jax
import jax.numpy as jnp
import numpy as np

size = 3000

key = jax.random.PRNGKey(0)
xjnp = jax.random.normal(key, shape=(size, size), dtype=jnp.float64)
xnp = np.random.normal(size=(size, size)).astype(np.float64)

%timeit jnp.dot(xjnp, xjnp.T).block_until_ready()
# 2.03 s ± 39.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit np.dot(xnp, xnp.T)
# 3.41 s ± 501 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

xjnp = xjnp.astype(jnp.float32)
xnp = xnp.astype(np.float32)

%timeit jnp.dot(xjnp, xjnp.T).block_until_ready()
# 2.05 s ± 74.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit np.dot(xnp, xnp.T)
# 1.73 s ± 383 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Numpy 中似乎有针对 32 位浮点数的优化点运算。

【讨论】:

  • 非常感谢。我发现唯一不同的是时间表现。在我的笔记本电脑(8 核)上,numpy.dot 需要 0.3176 秒来运行代码(dtype = float64),这比 jax.numpy.dot(0.4529 秒)要快。如果我将 dtype 更改为 float32,numpy.dot 会更快,只需要 0.2112 秒,但 jax.numpy.dot 仍然需要 0.4537 秒。
猜你喜欢
  • 2020-04-14
  • 2021-02-09
  • 1970-01-01
  • 1970-01-01
  • 2021-02-01
  • 2016-10-23
  • 2012-08-17
  • 1970-01-01
  • 2016-08-08
相关资源
最近更新 更多