【问题标题】:Is there a CUDA threadId alike in Jax (google)?Jax(谷歌)中是否有类似的 CUDA threadId?
【发布时间】:2021-10-05 12:06:11
【问题描述】:

我正在尝试了解jax.vmap/pmap(jax:https://jax.readthedocs.io/)的行为。 CUDA有threadId让你知道哪个线程在执行代码,jax中有没有类似的概念? (jax.process_id 不是)

【问题讨论】:

    标签: jax


    【解决方案1】:

    不,JAX 中没有真正类似于 CUDA 线程 ID 的东西。 XLA 编译器在较低级别处理有关 GPU 线程分配的详细信息,我不知道有任何直接的 API 可以将此信息返回到 JAX 的 Python 运行时。

    JAX 确实提供更高级别的设备分配处理的一种情况是使用pmap;在这种情况下,如果您想要依赖于执行映射代码的设备的逻辑,您可以将一组设备 ID 显式传递给 pmapped 函数。例如,我在一个 8 设备系统上运行了以下代码:

    import jax
    import jax.numpy as jnp
    
    num_devices = jax.device_count()
    
    def f(device, data):
      return data + device
    
    device_index = jnp.arange(num_devices)
    data = jnp.zeros((num_devices, 10))
    
    jax.pmap(f)(device_index, data)
    
    # ShardedDeviceArray([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
    #                     [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
    #                     [2., 2., 2., 2., 2., 2., 2., 2., 2., 2.],
    #                     [3., 3., 3., 3., 3., 3., 3., 3., 3., 3.],
    #                     [4., 4., 4., 4., 4., 4., 4., 4., 4., 4.],
    #                     [5., 5., 5., 5., 5., 5., 5., 5., 5., 5.],
    #                     [6., 6., 6., 6., 6., 6., 6., 6., 6., 6.],
    #                     [7., 7., 7., 7., 7., 7., 7., 7., 7., 7.]], dtype=float32)
    

    【讨论】:

      猜你喜欢
      • 2019-06-03
      • 1970-01-01
      • 2019-11-29
      • 1970-01-01
      • 1970-01-01
      • 2011-05-20
      • 2012-11-25
      • 2017-01-04
      • 1970-01-01
      相关资源
      最近更新 更多