【问题标题】:what is difference between batch size in data pipeline and batch size in midel.fit()?数据管道中的批量大小和 midel.fit() 中的批量大小有什么区别?
【发布时间】:2020-06-13 16:12:23
【问题描述】:

这两个是相同的batch-size,还是有不同的含义?

BATCH_SIZE=10
dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
dataset = dataset.batch(BATCH_SIZE)

第二次

history = model.fit(train_ds,
  epochs=EPOCHS,                    
  validation_data=create_dataset(X_valid, y_valid_bin),
  max_queue_size=1,
  workers=1,
  batch_size=10,
  use_multiprocessing=False)

我遇到了 Ram 无法运行的问题... 训练图像示例 333000 内存 30GB 12GB 显卡 批量大小应该是多少?

Full Code Here

【问题讨论】:

  • 第一个示例创建一个生成器,一次生成一批 10 个文件名。您可以在 model.fit 方法中将其与生成器函数一起使用。第二个示例还应该创建一个包含 10 个图像的批次,并按顺序在每个批次上训练您的模型。可能是您的验证数据太大。
  • @TirthPatel 两者都需要提供
  • 我知道训练集和验证集都可以传递给 keras 模型的 fit 方法,但您可以使用自定义生成器函数对其进行自定义。我要说的是,您是否尝试过查看验证数据并查看其大小。您还可以在 keras 中批量处理您的验证数据。 This thread discusses that in detail
  • @TirthPatel 是的!验证数据集很大..
  • 我已将验证数据集限制为 300,但仍会填满整个 ram,我将添加有问题的 git repo 链接,您可以查看。

标签: tensorflow machine-learning neural-network data-science batchsize


【解决方案1】:

数据集(批量大小)

批量大小仅表示将通过您定义的管道的数据量。在 Dateset 的情况下,批量大小表示在一次迭代中将有多少数据传递给模型。例如,如果您形成一个数据生成器并将批量大小设置为 8。现在在每个迭代数据生成器上给出 8 条数据记录。

Model.fit(批量大小)

并且在model.fit中,当我们设置批量大小时,这意味着模型将在传递等于批量大小的数据记录后计算损失。如果您了解深度学习模型,它们将计算前馈的特定损失,而不是通过反向传播,它们会自我改进。现在,如果您在 model.fit 中设置批量大小 8,则 8 条数据记录将传递给模型,并根据这 8 条数据记录计算损失,然后模型会从该损失中得到改善。

示例:

现在,如果您将 dateset 批量大小设置为 4 并将 model.fit 批量大小设置为 8。现在您的 dateset 生成器必须迭代 2 次才能将 8 个图像提供给模型,而 model.fit 只执行 1 次迭代计算损失。

内存问题

你的图片尺寸是多少?尝试减少 batch_size,因为每个 epoch 的步长与 ram 无关,但批量大小是。因为如果您提供 10 个批量大小,则必须将 10 个图像加载到 ram 上进行处理,并且您的 ram 无法同时加载 10 个图像。尝试将批量大小设为 4 或 2。这可能会对您有所帮助

【讨论】:

  • 很好,我明白了,但是,实际上我正面临 RAM 无法运行的问题。所以我也尝试了 STEPS_PER_EPOCH 参数... STEPS_PER_EPOCH=total_no_img/batch_size STEPS_PER_EPOCH=50000/10 STEPS_PER_EPOCH=5000 还是一样多少适合ram?这里我只从文件中读取 50000 条记录。
  • 根据您的回答,我们需要在两个地方都给出批量大小?
  • 你的图片尺寸是多少?尝试减少 batch_size,因为每个 epoch 的步长与 ram 无关,但批量大小是。因为如果您提供 10 个批量大小,则必须将 10 个图像加载到 ram 上进行处理,并且您的 ram 无法同时加载 10 个图像。尝试将批量大小设为 4 或 2。这可能会对您有所帮助。
  • 图片尺寸(240,240,3)
  • 12 GB 内存无法加载 10 个输入图像 (240,240,3)。将批量大小减少到 4。
猜你喜欢
  • 2021-07-21
  • 2021-10-28
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 2015-09-14
  • 2022-11-24
相关资源
最近更新 更多