【发布时间】:2021-03-29 13:46:32
【问题描述】:
我使用下面的代码在我的TensorFlow数据集中加载一堆图片,效果很好:
def load(image_file):
image = tf.io.read_file(image_file)
image = tf.image.decode_jpeg(image)
image = tf.cast(image , tf.float32)
return image
train_dataset = tf.data.Dataset.list_files(PATH+'train/*.jpg')
train_dataset = train_dataset.map(load , num_parallel_calls=tf.data.experimental.AUTOTUNE)
我想知道如何使用类似的代码来加载一堆 CSV 文件。每个 CSV 文件的形状为 256 x 256,可以假定为灰度图像。我不知道我应该在“加载”函数中使用什么来代替“tf.image.decode_jpeg”。 非常感谢您的帮助。
【问题讨论】:
-
您可以将 csv 文件读入 numpy 数组 (riptutorial.com/numpy/example/22990/reading-csv-files),然后使用
train_dataset = tf.data.Dataset.from_tensor_slices((images, image_labels))构建数据集 -
感谢您的建议。我试图避免一次读取所有数据以避免内存过载。我的理解是,当我们使用
tf.data.Dataset.list_files(PATH+'train/*.jpg')加载数据集时,它会在训练期间以批量方式直接从磁盘加载数据,而不是先将整个数据加载到RAM上,然后再发送一批到 GPU/CPU。对吗?
标签: python csv dataset tensorflow2.0 tensorflow-datasets