【发布时间】:2021-12-15 08:26:20
【问题描述】:
我对使用 JAX 训练神经网络很感兴趣。我查看了tf.data.Dataset,但它只提供 tf 张量。我寻找一种将数据集更改为 JAX numpy 数组的方法,我发现许多使用 Dataset.as_numpy_generator() 将 tf 张量转换为 numpy 数组的实现。但是我想知道这是否是一个好习惯,因为 numpy 数组存储在 CPU 内存中,这不是我想要的训练(我使用 GPU)。所以我发现的最后一个想法是通过调用jnp.array 手动重铸数组,但这并不是很优雅(我担心GPU内存中的副本)。有人对此有更好的想法吗?
快速代码说明:
import os
import jax.numpy as jnp
import tensorflow as tf
def generator():
for _ in range(2):
yield tf.random.uniform((1, ))
ds = tf.data.Dataset.from_generator(generator, output_types=tf.float32,
output_shapes=tf.TensorShape([1]))
ds1 = ds.take(1).as_numpy_iterator()
ds2 = ds.skip(1)
for i, batch in enumerate(ds1):
print(type(batch))
for i, batch in enumerate(ds2):
print(type(jnp.array(batch)))
# returns:
<class 'numpy.ndarray'> # not good
<class 'jaxlib.xla_extension.DeviceArray'> # good but not elegant
【问题讨论】:
-
欢迎来到 SO;如果下面的答案解决了您的问题,请接受 - 请参阅What should I do when someone answers my question?
标签: python tensorflow numpy-ndarray jax