【发布时间】:2021-10-05 12:06:11
【问题描述】:
我正在尝试了解jax.vmap/pmap(jax:https://jax.readthedocs.io/)的行为。 CUDA有threadId让你知道哪个线程在执行代码,jax中有没有类似的概念? (jax.process_id 不是)
【问题讨论】:
标签: jax
我正在尝试了解jax.vmap/pmap(jax:https://jax.readthedocs.io/)的行为。 CUDA有threadId让你知道哪个线程在执行代码,jax中有没有类似的概念? (jax.process_id 不是)
【问题讨论】:
标签: jax
不,JAX 中没有真正类似于 CUDA 线程 ID 的东西。 XLA 编译器在较低级别处理有关 GPU 线程分配的详细信息,我不知道有任何直接的 API 可以将此信息返回到 JAX 的 Python 运行时。
JAX 确实提供更高级别的设备分配处理的一种情况是使用pmap;在这种情况下,如果您想要依赖于执行映射代码的设备的逻辑,您可以将一组设备 ID 显式传递给 pmapped 函数。例如,我在一个 8 设备系统上运行了以下代码:
import jax
import jax.numpy as jnp
num_devices = jax.device_count()
def f(device, data):
return data + device
device_index = jnp.arange(num_devices)
data = jnp.zeros((num_devices, 10))
jax.pmap(f)(device_index, data)
# ShardedDeviceArray([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
# [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
# [2., 2., 2., 2., 2., 2., 2., 2., 2., 2.],
# [3., 3., 3., 3., 3., 3., 3., 3., 3., 3.],
# [4., 4., 4., 4., 4., 4., 4., 4., 4., 4.],
# [5., 5., 5., 5., 5., 5., 5., 5., 5., 5.],
# [6., 6., 6., 6., 6., 6., 6., 6., 6., 6.],
# [7., 7., 7., 7., 7., 7., 7., 7., 7., 7.]], dtype=float32)
【讨论】: