【问题标题】:How to repeat data with flow_from_directory in Keras如何在 Keras 中使用 flow_from_directory 重复数据
【发布时间】:2019-09-17 04:44:01
【问题描述】:

我正在尝试使用 keras flow_from_directory 来训练模型。但它不会重复 纪元之后的数据(即当所有数据都被迭代时)。我找不到任何 选择这样做。下面是我在训练时生成数据的代码。 例如,如果总图像 = 70 批量大小 = 32 然后在第 1 次和第 2 次迭代中给出 32 张图像,但在第三次迭代中给出 6 张图像。

# data generation from directory without labels  
trn = datagen.flow_from_directory(os.path.join(BASE, 'train_gen'),
                                         batch_size=batch_size,
                                         target_size=(inp_shape[:2]),
                                         class_mode=None)
X = trn.next() # getting a batch of data.

我希望数据生成器在数据耗尽后开始重复数据。

实际上我正在尝试训练一个 GAN,其中从 Generator-Model 生成一批图像,然后将其与一批真实图像连接起来,然后传递给 Discriminator-Model 和 GAN-Model 进行训练。我不知道如何在其中使用 fit_generator,代码如下:

def train(self, inp_shape, batch_size=1, n_epochs=1000):
    BASE = '/content/gdrive/My Drive/Dataset/GAN'

    datagen = ImageDataGenerator(rescale=1./255)
    trn_dist = datagen.flow_from_directory(os.path.join(BASE, 'train_gen'),
                                                      batch_size=batch_size,
                                                      target_size=(inp_shape[:2]),
                                                      seed = 1360000,
                                                      class_mode=None)

    val_dist = datagen.flow_from_directory(os.path.join(BASE, 'test_gen'),
                                                      batch_size=batch_size,
                                                      target_size=(inp_shape[:2]),
                                                      class_mode=None)

    trn_real = datagen.flow_from_directory(os.path.join(BASE, 'train_real'),
                                                      batch_size=batch_size,
                                                      target_size=(inp_shape[:2]),
                                                      seed = 1360000,
                                                      class_mode=None)

    for e in range(n_epochs):

      real_images = trn_real.next()

      dist_images = trn_dist.next()

      gen_images = self.generator.predict(dist_images)

      factor = inp_shape[0]/250
      gen_res = ndi.zoom(gen_images, (1, factor, factor, 1), order=2)      

      X = np.concatenate([real_images, gen_res])

      y = np.zeros(2*batch_size)
      y[:batch_size] = 1.

      self.discriminator.trainable = True
      self.discriminator.fit(X, y, batch, n_epochs)

      self.discriminator.trainable = False

      self.model.fit(gen_res, y[:batch_size])
      print ('> training --- epoch=%d/%d' % (e, n_epochs))
      if e > 0 and e % 2000 == 0:
        self.model.save('%s/models/gan_model_%d_.h5'%(BASE, e))

PS:我是 Gans 新手,如果我做错了什么请纠正我。

【问题讨论】:

    标签: tensorflow keras deep-learning data-generation


    【解决方案1】:

    要弄清楚这个问题,首先你需要知道flow_from_directory的参数。 batch_size 确定要加载以进行计算的样本数量,epoch 确定您使用 Keras 传递所有数据的次数。从本质上讲,如果您设置了 epoch=2batch_size=32,则意味着 Keras 将通过将您的数据拆分为 mini-batches 中的 32 个数据样本来检查您的所有数据两次。那么您的代码中缺少的本质上是 epoch 参数。 我建议也设置 steps_per_epoch 和 validation_data。 steps_per_epoch 确定每个 epoch 中的批次数,而不是访问每个 epoch 中的所有样本,设置 steps_per_epoch 如下。

    model.fit_generator(train_generator, steps_per_epoch=train_generator.samples/train_generator.batch_size, epochs=10, validation_data=validation_generator, validation_steps=validation_generator.samples/validation_generator.batch_size)
    

    【讨论】:

    • 感谢您的澄清,我正在尝试训练 GAN,其中必须将一批图像传递给模型。所以我对 n_epochs 使用循环,而不是使用 fit_generator。请找到更新后的代码。
    【解决方案2】:

    flow_from_directory 方法与fit_generator 函数一起使用。 fit_generator 函数允许您指定 epoch 的数量。

    model.fit_generator(trn, epochs=epochs)
    

    其中model 指的是您要训练的模型对象。应该能解决你的问题。这些函数在 Keras 文档中有很好的解释

    【讨论】:

    • 我正在尝试训练 GAN,其中来自两个来源的输入被连接并调整大小,然后传递给模型,因此我一次只能处理批处理。那么就没有办法在flow_from_directory中设置repeat模式了?
    【解决方案3】:

    您始终可以在 fit_generator 方法中指定 steps_per_epoch 参数。这将使您能够在 steps_per_epoch > total_samples // batch_size 时重复数据。

    【讨论】:

      【解决方案4】:

      我发现了一个技巧,可以让第二个生成器使用较少的图像来“重置”它的索引,从而输出 32 个图像而不是前面提到的 6 个图像。

      查看您的代码,我认为 trn_real 是具有更多图像的生成器,trn_dist 是具有较少图像的生成器。在每次迭代时,比较批形状,如果它们不相等(意味着生成器到达索引的末尾,因此输出较小的图像),然后按如下方式重置生成器:

      real_images = trn_real.next()
      dist_images = trn_dist.next()
      if real_images.shape != dist_images.shape:
          trn_dist.reset() # reset the generator with lesser images
          dist_images = trn_dist.next()
      

      【讨论】:

        猜你喜欢
        • 2020-01-01
        • 2018-03-05
        • 2018-06-30
        • 1970-01-01
        • 2018-10-30
        • 2018-08-07
        • 2021-09-30
        • 2019-12-20
        • 2019-01-09
        相关资源
        最近更新 更多