【问题标题】:Tensorflow: How to find the size of a tf.data.Dataset API objectTensorflow:如何查找 tf.data.Dataset API 对象的大小
【发布时间】:2018-11-27 21:14:57
【问题描述】:

我了解 Dataset API 是一种迭代器,它不会将整个数据集加载到内存中,因此无法找到数据集的大小。我说的是存储在文本文件或 tfRecord 文件中的大型数据语料库。这些文件通常使用tf.data.TextLineDataset 或类似的东西读取。查找使用 tf.data.Dataset.from_tensor_slices 加载的数据集的大小很简单。

我询问数据集大小的原因如下: 假设我的数据集大小是 1000 个元素。批量大小 = 50 个元素。然后训练步骤/批次(假设 1 个 epoch)= 20。在这 20 个步骤中,我想以指数方式将我的学习率从 0.1 衰减到 0.01

tf.train.exponential_decay(
    learning_rate = 0.1,
    global_step = global_step,
    decay_steps = 20,
    decay_rate = 0.1,
    staircase=False,
    name=None
)

在上面的代码中,我有“和”想设置decay_steps = number of steps/batches per epoch = num_elements/batch_size。只有事先知道数据集中元素的数量,才能计算出这一点。

提前知道大小的另一个原因是使用tf.data.Dataset.take()tf.data.Dataset.skip() 方法将数据分成训练集和测试集。

PS:我不是在寻找暴力方法,例如遍历整个数据集并更新计数器以计算元素数量或 putting a very large batch size and then finding the size of the resultant dataset 等。

【问题讨论】:

    标签: python tensorflow tensorflow-datasets


    【解决方案1】:

    您可以手动指定数据集的大小吗?

    我如何加载我的数据:

    sample_id_hldr = tf.placeholder(dtype=tf.int64, shape=(None,), name="samples")
    
    sample_ids = tf.Variable(sample_id_hldr, validate_shape=False, name="samples_cache")
    num_samples = tf.size(sample_ids)
    
    data = tf.data.Dataset.from_tensor_slices(sample_ids)
    # "load" data by id:
    # return (id, data) for each id
    data = data.map(
        lambda id: (id, some_load_op(id))
    )
    

    在这里,您可以通过使用占位符初始化一次 sample_ids 来指定所有示例 ID。
    您的样本 ID 可能是例如文件路径或简单数字 (np.arange(num_elems))

    然后可以在num_samples 中获得元素的数量。

    【讨论】:

    • 感谢您的回答。但是,您似乎没有得到我的问题,对此感到抱歉。我将再次修改我的问题。我没有使用from_tensor_slices,它在您有小数据集时使用。在这种情况下,找到数据集的大小是微不足道的。我的问题是关于读取存储在文本文件中的大量数据,这些数据只能使用tf.data.TextLineDataset() 读取。在这种情况下,如何求出整个数据集的大小?
    【解决方案2】:

    您可以使用以下方法轻松获取数据样本数:

    dataset.__len__()
    

    你可以像这样获取每个元素:

    for step, element in enumerate(dataset.as_numpy_iterator()):
    ...     print(step, element)
    

    你也可以得到一个样本的形状:

    dataset.element_spec
    

    如果你想获取特定的元素,你也可以使用 shard 方法。

    【讨论】:

    • dataset.__len__() 不适用于使用tf.data.TextLineDataset 创建的数据集。错误是TypeError: dataset length is unknown.
    • 我发现 TextLineDataset 读取每个文件并将该文件的每一行作为样本,所以你可以做的是 len(dataset.as_numpy_iterator())。但是,每行的字符数会有所不同,然后您可以使用一些固定的大小使其可训练。
    【解决方案3】:

    我知道这个问题已经有两年了,但也许这个答案会有用。

    如果您使用tf.data.TextLineDataset 读取数据,那么获取样本数量的一种方法可能是计算您正在使用的所有文本文件中的行数。

    考虑以下示例:

    import random
    import string
    import tensorflow as tf
    
    filenames = ["data0.txt", "data1.txt", "data2.txt"]
    
    # Generate synthetic data.
    for filename in filenames:
        with open(filename, "w") as f:
            lines = [random.choice(string.ascii_letters) for _ in range(random.randint(10, 100))]
            print("\n".join(lines), file=f)
    
    dataset = tf.data.TextLineDataset(filenames)
    

    尝试使用len 获取长度会引发TypeError

    len(dataset)
    

    但是可以相对快速地计算出文件中的行数。

    # https://stackoverflow.com/q/845058/5666087
    def get_n_lines(filepath):
        i = -1
        with open(filepath) as f:
            for i, _ in enumerate(f):
                pass
        return i + 1
    
    n_lines = sum(get_n_lines(f) for f in filenames)
    

    在上面,n_lines 等于迭代数据集时找到的元素数

    for i, _ in enumerate(dataset):
        pass
    n_lines == i + 1
    

    【讨论】:

      【解决方案4】:

      这是我为解决问题所做的,将以下行添加到您的数据集tf.data.experimental.assert_cardinality(len_of_data) 这将解决问题,

      ast = Audioset(df) # the generator class
      db = tf.data.Dataset.from_generator(ast, output_types=(tf.float32, tf.float32, tf.int32))
      db = db.apply(tf.data.experimental.assert_cardinality(len(ast))) # number of samples
      db = db.batch(batch_size)
      

      数据集 len 根据 batch_size 改变,要获得数据集 len 只需运行 len(db)。 查看here了解更多详情

      【讨论】:

        猜你喜欢
        • 1970-01-01
        • 2020-07-25
        • 1970-01-01
        • 2013-05-28
        • 2015-07-05
        • 1970-01-01
        • 1970-01-01
        • 2019-09-30
        • 2018-06-26
        相关资源
        最近更新 更多