【问题标题】:Execute function specifically on CPU in Jax在 Jax 中专门在 CPU 上执行函数
【发布时间】:2022-11-23 08:01:35
【问题描述】:

我有一个函数,基本上可以实例化一个巨大的数组并做其他事情。我在 TPU 上运行我的代码,所以基本上我的内存是有限的。

如何专门在 CPU 上执行我的功能?

如果我做:

y = jax.device_put(my_function(), device=jax.devices("cpu")[0])

我猜 my_function() 首先在 TPU 上执行,结果放在 CPU 上,这给我带来了内存错误。

在我的代码开头使用 jax.config.update('jax_platform_name', 'cpu') 似乎没有效果。

另请注意,我无法修改my_function()

谢谢!

【问题讨论】:

    标签: python memory cpu tpu jax


    【解决方案1】:

    我要在这里做一个猜测。我也不能运行它所以你可能不得不摆弄它

    with jax.default_device(jax.devices("cpu")[0]):
        y = my_function()
    

    请参阅文档herehere

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2019-12-20
      • 2012-03-16
      • 1970-01-01
      • 1970-01-01
      • 2017-01-08
      • 1970-01-01
      • 1970-01-01
      • 2021-06-21
      相关资源
      最近更新 更多