【问题标题】:How to shuffle batches with ImageDataGenerator?如何使用 ImageDataGenerator 洗牌?
【发布时间】:2021-01-28 17:30:32
【问题描述】:

我正在使用 ImageDataGenerator 和 flow_from_dataframe 来加载数据集。

flow_from_dataframeshuffle=True 一起使用会打乱数据集中的图像。

我想洗牌。如果我有 12 张图片和batch_size=3,那么我有 4 个批次:

batch1 = [image1, image2, image3]
batch2 = [image4, image5, image6]
batch3 = [image7, image8, image9]
batch4 = [image10, image11, image12]

我想打乱批次而不打乱每批中的图像,所以我得到例如:

batch2 = [image4, image5, image6]
batch1 = [image1, image2, image3]
batch4 = [image10, image11, image12]
batch3 = [image7, image8, image9]

ImageDataGenerator 和 flow_from_dataframe 有可能吗?有我可以使用的预处理功能吗?

【问题讨论】:

  • 在model.fit()期间尝试使用参数shuffle = True
  • @sahil_angra 洗牌整个集合

标签: python tensorflow keras


【解决方案1】:

考虑使用tf.data.Dataset API。您可以在洗牌之前进行批处理操作。

import tensorflow as tf

file_names = [f'image_{i}' for i in range(1, 10)]

ds = tf.data.Dataset.from_tensor_slices(file_names).batch(3).shuffle(3)

for _ in range(3):
    for batch in ds:
        print(batch.numpy())
    print()
[b'image_4' b'image_5' b'image_6']
[b'image_7' b'image_8' b'image_9']
[b'image_1' b'image_2' b'image_3']

[b'image_1' b'image_2' b'image_3']
[b'image_4' b'image_5' b'image_6']
[b'image_7' b'image_8' b'image_9']

[b'image_1' b'image_2' b'image_3']
[b'image_4' b'image_5' b'image_6']
[b'image_7' b'image_8' b'image_9']

然后,您可以使用映射操作从文件名中加载图像:

def read_image(file_name):
  image = tf.io.read_file(file_name)
  image = tf.image.decode_image(image)
  image = tf.image.convert_image_dtype(image, tf.float32)
  image = tf.image.resize_with_crop_or_pad(image, target_height=224, target_width=224)
  label = tf.strings.split(file_path, os.sep)[0]
  label = tf.cast(tf.equal(label, class_categories), tf.int32)
  return image, label

ds = ds.map(read_image)

【讨论】:

  • 我已经尝试过了,它可以工作,但你应该在map 之后执行.batch(3).shuffle(3),否则read_image 将收到与批处理大小一样多的文件名,在您的示例3中。
猜你喜欢
  • 2020-09-30
  • 1970-01-01
  • 2012-09-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 2013-04-28
  • 2011-10-19
  • 1970-01-01
相关资源
最近更新 更多