【问题标题】:How to get batch size back from a tensorflow dataset?如何从张量流数据集中获取批量大小?
【发布时间】:2018-09-29 10:44:17
【问题描述】:

推荐使用tensorflow数据集作为输入管道,可设置如下:

# Specify dataset
dataset  = tf.data.Dataset.from_tensor_slices((features, labels))
# Suffle
dataset  = dataset.shuffle(buffer_size=1e5)
# Specify batch size
dataset  = dataset.batch(128)
# Create an iterator
iterator = dataset.make_one_shot_iterator()
# Get next batch
next_batch = iterator.get_next()

我应该能够获得批量大小(来自数据集本身或从它创建的迭代器,即iteratornext_batch)。也许有人想知道数据集或其迭代器中有多少批次。或者有多少批次已经被调用,还有多少批次留在迭代器中?可能还想一次性获取特定元素,甚至整个数据集。

我在 tensorflow 文档中找不到任何内容。这可能吗?如果没有,有谁知道这是否已作为 tensorflow GitHub 上的问题提出要求?

【问题讨论】:

    标签: tensorflow issue-tracking tensorflow-datasets


    【解决方案1】:

    试试这个

    import tensorflow as tf
    import numpy as np
    
    features=np.array([[3.0, 0.0], [1.0, 2.0], [0.0, 0.0]], dtype="float32")
    labels=np.array([[0], [0], [1]], dtype="float32")
    dataset = tf.data.Dataset.from_tensor_slices((features, labels))
    
    batch_size = 2
    dataset = dataset.batch(batch_size)
    iterator = dataset.make_initializable_iterator()
    batch_data = iterator.get_next()
    with tf.Session() as sess:
        sess.run(iterator.initializer)
        print(np.shape(sess.run(batch_data)[0])[0])
    你会看到

    【讨论】:

      【解决方案2】:

      至少在 TF2 中,数据集的类型是静态定义的,可通过 tf.data.Dataset.element_spec 访问。

      这是一个有点复杂的返回类型,因为它具有匹配您的数据集的元组嵌套。

      >>> tf.data.Dataset.from_tensor_slices([[[1]],[[2]]]).element_spec.shape
      TensorShape([1, 1])
      

      如果您的数据被组织为一个元组[图像,标签],那么您将获得一个 TensorSpecs 元组。如果您确定返回类型的嵌套,您可以对其进行索引。例如

      >>> image = tf.data.Dataset.from_tensor_slices([[1],[2],[3],[4]]).batch(2, drop_remainder=True)
      >>> label = tf.data.Dataset.from_tensor_slices([[1],[2],[3],[4]]).batch(2, drop_remainder=True)
      >>> train = tf.data.Dataset.zip((image, label))
      >>> train.element_spec[0].shape[0]
      2
      

      【讨论】:

        【解决方案3】:

        在 TF2 中,tf.data.Datasets 是 iterables,所以你可以通过简单的操作得到一个批次:

        batch = next(iter(dataset))
        

        然后计算批量大小是微不足道的,因为它变成了first dimension 的大小:

        batch_size = batch.shape[0]
        

        所以一个完整的例子应该是这样的:

        # Specify dataset
        dataset  = tf.data.Dataset.from_tensor_slices((features, labels))
        # Suffle
        dataset  = dataset.shuffle(buffer_size=1e5)
        # Specify batch size
        dataset  = dataset.batch(128)
        # Calculate and print batch size
        batch_size = next(iter(dataset)).shape[0]
        print('Batch size:', batch_size) # prints 128
        

        或者,如果你需要它作为一个函数:

        def calculate_batch_size(dataset):
            return next(iter(dataset)).shape[0]
        

        请注意,iterating 在数据集上需要急切执行。此外,此解决方案假定您的数据集是批处理的,如果不是这样,可能会出错。如果在批处理后对数据集执行其他操作来改变其元素的形状,您也可能会遇到错误。

        【讨论】:

          猜你喜欢
          • 1970-01-01
          • 2019-04-24
          • 1970-01-01
          • 2021-12-10
          • 1970-01-01
          • 2017-11-02
          • 1970-01-01
          • 1970-01-01
          • 2020-05-10
          相关资源
          最近更新 更多