【问题标题】:tensorflow 2.0, model.fit() : Your input ran out of datatensorflow 2.0, model.fit() : 你的输入用完了数据
【发布时间】:2020-02-14 15:15:24
【问题描述】:

我对 TensorFlow 和 Keras 完全陌生,我正在尝试尝试一些我在网上找到的代码。

特别是我正在使用时尚MNIST - 由 60000 个示例和 10000 个示例的测试集组成。它们中的每一个都是 28x28 灰度图像。

我正在关注本教程“https://towardsdatascience.com/building-your-first-neural-network-in-tensorflow-2-tensorflow-for-hackers-part-i-e1e2f1dfe7a0”,直到定义

history = model.fit(
train_dataset.repeat(), 
epochs=10, 
steps_per_epoch=500,
validation_data=val_dataset.repeat(), 
validation_steps=2)

只要我理解,我需要使用 train_dataset.repeat() 作为输入数据集,否则我将没有足够的训练示例使用这些值作为超参数(epochs、steps_per_epochs)。

我的问题是:如何避免必须使用 .repeat()? 我需要如何更改超参数?

为了简单起见,我在这里处理代码:

def preprocess(x,y):

    x = tf.cast(x,tf.float32) / 255.0
    y = tf.cast(y, tf.float32)

    return x,y 

def create_dataset(xs, ys, n_classes=10):

    ys = tf.one_hot(ys, depth=n_classes)

    return tf.data.Dataset.from_tensor_slices((xs, ys)).map(preprocess).shuffle(len(ys)).batch(128)


model.compile(optimizer = 'adam', loss =tf.losses.CategoricalCrossentropy(from_logits= True), metrics =['accuracy'])

history1 = model.fit(train_dataset.repeat(), 
                    epochs=10, 
                    steps_per_epoch=500,
                    validation_data=val_dataset.repeat(), 
                    validation_steps=2)

谢谢!

【问题讨论】:

    标签: python tensorflow keras neural-network


    【解决方案1】:

    如果您不想使用 .repeat(),您需要让您的模型在每个 epoch 只传递一次您的整个数据。

    为了做到这一点,您需要计算模型通过整个数据集需要多少步,计算很简单:

    steps_per_epoch = len(train_dataset) // batch_size
    

    因此,如果 train_dataset 为 60 000 个样本,batch_size 为 128,则每个 epoch 需要 468 步。

    通过这样设置此参数,您可以确保不会超出数据集的大小。

    【讨论】:

    • 我试过了!我运行: ''' history1 = model.fit(train_dataset, epochs=10, steps_per_epoch=468, validation_data=val_dataset, validation_steps=2) ''' 但它运行了 1 个 epoch,一旦它开始第二个它就给出我的错误:WARNING:tensorflow:Your input run out of data;中断训练。确保您的数据集或生成器至少可以生成 steps_per_epoch * epochs 批次(在本例中为 4680 个批次)。在构建数据集时,您可能需要使用 repeat() 函数。
    【解决方案2】:

    我遇到了同样的问题,这就是我发现的。 tf.keras.Model.fit 的文档:“如果 x 是 tf.data 数据集,并且 'steps_per_epoch' 为 None,则 epoch 将运行直到输入数据集耗尽。”

    换句话说,如果我们使用 tf.data.dataset 作为训练数据,我们不需要指定“steps_per_epoch”,tf 会计算出有多少步。同时,tf 会在下一个 epoch 开始时自动重复数据集,因此您可以指定任何 'epoch'。

    当传递无限重复的数据集(例如 dataset.repeat())时,您必须指定 steps_per_epoch 参数。

    【讨论】:

      猜你喜欢
      • 2020-12-28
      • 2020-06-19
      • 1970-01-01
      • 1970-01-01
      • 2021-11-25
      • 1970-01-01
      • 2020-09-18
      • 2021-05-19
      • 1970-01-01
      相关资源
      最近更新 更多