【问题标题】:Shuffling tfrecords files洗牌 tfrecords 文件
【发布时间】:2018-02-08 18:27:34
【问题描述】:

我有 5 个 tfrecords 文件,每个对象一个。在训练时,我想从所有 5 个 tfrecord 中平均读取数据,即如果我的批量大小为 50,我应该从第一个 tfrecord 文件中获取 10 个样本,从第二个 tfrecord 文件中获取 10 个样本,依此类推。目前,它只是从所有三个文件中顺序读取,即我从同一记录中获得 50 个样本。有没有办法从不同的 tfrecords 文件中采样?

【问题讨论】:

    标签: tensorflow deep-learning tensorflow-datasets tfrecord


    【解决方案1】:

    我建议您阅读@mrry 在tf.data 上的tutorial。在幻灯片 42 上,他解释了如何使用 tf.data.Dataset.interleave() 同时读取多个 tfrecord 文件。

    例如,如果您有 5 个文件,其中包含:

    file0.tfrecord: [0, 1]
    file1.tfrecord: [2, 3]
    file2.tfrecord: [4, 5]
    file3.tfrecord: [6, 7]
    file4.tfrecord: [8, 9]
    

    你可以这样写数据集:

    files = ["file{}.tfrecord".format(i) for i in range(5)]
    files = tf.data.Dataset.from_tensor_slices(files)
    dataset = files.interleave(lambda x: tf.data.TFRecordDataset(x),
                               cycle_length=5, block_length=1)
    
    dataset = dataset.map(_parse_function)  # parse the record
    

    interleave的参数为: - cycle_length:同时读取的文件数。如果您想从所有文件中读取以创建批处理,请将其设置为文件数(在您的情况下,这是您应该做的,因为每个文件都包含一种类型的标签) - block_length:每次我们从一个文件中读取,都会从这个文件中读取block_length元素

    我们可以测试它是否按预期工作:

    iterator = dataset.make_one_shot_iterator()
    x = iterator.get_next()
    
    with tf.Session() as sess:
        for _ in range(num_samples):
            print(sess.run(x))
    

    将打印:

    0
    2
    4
    6
    8
    1
    3
    5
    7
    9
    

    【讨论】:

    • 您在此处链接的演示文稿非常酷。它的来源是什么? tf 作者还有更多这样的介绍吗?
    • @3voC:来源 Derek Murray (mrry),在 TensorFlow 团队工作并创建了 tf.data API。
    • 确保您提供给tf.data.Dataset.list_files() 的文件名模式是准确的。我提供了path/to/train//*.tfrecords,但它失败了。 (注意双斜线)
    猜你喜欢
    • 2019-06-28
    • 2021-06-14
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2013-11-26
    • 1970-01-01
    • 1970-01-01
    • 2011-07-09
    相关资源
    最近更新 更多