【发布时间】:2022-11-23 08:01:35
【问题描述】:
我有一个函数,基本上可以实例化一个巨大的数组并做其他事情。我在 TPU 上运行我的代码,所以基本上我的内存是有限的。
如何专门在 CPU 上执行我的功能?
如果我做:
y = jax.device_put(my_function(), device=jax.devices("cpu")[0])
我猜 my_function() 首先在 TPU 上执行,结果放在 CPU 上,这给我带来了内存错误。
在我的代码开头使用 jax.config.update('jax_platform_name', 'cpu') 似乎没有效果。
另请注意,我无法修改my_function()
谢谢!
【问题讨论】: