【问题标题】:What does reset actually mean in Tensorflow 2 dataset?重置在 Tensorflow 2 数据集中的实际含义是什么?
【发布时间】:2019-06-01 18:46:52
【问题描述】:

我正在关注 tensorflow 2 Keras documentation。我的模型如下所示:

train_dataset = tf.data.Dataset.from_tensor_slices((np.array([_my_cus_func(i) for i in X_train]), y_train))
train_dataset = train_dataset.map(lambda vals,lab: _process_tensors(vals,lab), num_parallel_calls=4)
train_dataset = train_dataset.shuffle(buffer_size=10000)
train_dataset = train_dataset.batch(64,drop_remainder=True)
train_dataset = train_dataset.prefetch(1)
model=get_compiled_model()
model.fit(train_dataset, epochs=100)

文档说

请注意,Dataset 在每个 epoch 结束时都会重置,因此可以 重用下一个纪元。

如果您只想对特定数量的批次进行训练 这个数据集,你可以传递 steps_per_epoch 参数,它 指定模型应该使用它运行多少个训练步骤 进入下一个 epoch 之前的数据集。

如果你这样做,数据集不会在每个 epoch 结束时重置, 相反,我们只是继续绘制下一批。数据集将 最终用完数据(除非它是一个无限循环 数据集)。

重置实际上意味着什么? tensorflow 会在每个 epoch 之后从张量切片中读取数据吗?还是只重新洗牌并运行map 函数?我希望 tensorflow 在 epoch 之后从 numpy 读取数据并运行 _my_cus_func。我宁愿在dataset map or apply api 上传递_my_cus_func,但我更愿意在python 列表或numpy 数组上这样做。

【问题讨论】:

    标签: python tensorflow keras tensorflow-datasets tensorflow2.0


    【解决方案1】:

    在这种情况下,重置意味着从头开始迭代数据集。在您的特定情况下,代码缺少 repeat() 函数。所以,如果你像这样指定steps_per_epoch 参数

    model.fit(train_dataset, steps_per_epoch=N, epochs=100)
    

    它将在数据集上迭代 N 步,如果 N 小于实际示例数,它将终止训练。如果 N 较大,它将完成一个 epoch,但在用完 data 时仍会终止。如果添加重复,

    train_dataset = train_dataset.shuffle(buffer_size=10000).repeat()
    

    当达到实际示例数时,它将在数据集上开始新的循环,而不是在新纪元开始时。

    【讨论】:

    • 谢谢。理想情况下,N 应该是 sample size/batch size 对吧?
    • 确切地说,样本总数/批量大小
    猜你喜欢
    • 2023-04-01
    • 1970-01-01
    • 2021-10-07
    • 2013-01-04
    • 2015-02-21
    • 2013-06-13
    • 2015-01-20
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多