【问题标题】:How does the distorted_inputs() function in the TensorFlow CIFAR-10 example tutorial get 128 images per batch?TensorFlow CIFAR-10 示例教程中的 distorted_inputs() 函数如何获得每批 128 张图像?
【发布时间】:2016-04-07 23:00:16
【问题描述】:

我在 TensorFlow getting started guide for CNN 浏览 CIFAR-10 示例

现在在 cifar10_train.py 的 train 函数中,我们得到的图像是

images,labels = cifar10.distorted_inputs()

distorted_inputs() 函数中,我们在队列中生成文件名,然后读取单个记录

 # Create a queue that produces the filenames to read.
 filename_queue = tf.train.string_input_producer(filenames)

 # Read examples from files in the filename queue.
 read_input = cifar10_input.read_cifar10(filename_queue)
 reshaped_image = tf.cast(read_input.uint8image, tf.float32)

当我添加调试代码时,read_input 变量仅包含 1 条记录,其中包含图像及其高度、宽度和标签名称。

该示例然后对读取的图像/记录应用一些失真,然后将其传递给_generate_image_and_label_batch() 函数。

然后,此函数返回一个形状为 [batch_size, 32, 32, 3] 的 4D 张量,其中 batch_size = 128

上述函数在返回批处理时使用tf.train.shuffle_batch() 函数。

我的问题是tf.train.shuffle_batch() 函数中的额外记录来自哪里?我们没有向它传递任何文件名或阅读器对象。

有人能解释一下我们是如何从 1 条记录变成 128 条记录的吗?我查看了文档但不明白。

【问题讨论】:

  • 我也有同样的问题,很高兴我找到了这个

标签: machine-learning neural-network tensorflow


【解决方案1】:

tf.train.shuffle_batch() 函数可用于生成(一个或多个)包含一批输入的张量。在内部,tf.train.shuffle_batch() 创建了一个tf.RandomShuffleQueue,它使用图像和标签张量在其上调用q.enqueue() 以将单个元素(图像-标签对)排入队列。然后它返回q.dequeue_many(batch_size) 的结果,它将batch_size 随机选择的元素(图像-标签对)连接成一批图像和一批标签。

请注意,虽然从代码中看起来read_inputfilename_queue 有函数关系,但还有一个额外的问题。简单地评估tf.train.shuffle_batch() 的结果将永远阻塞,因为没有元素被添加到内部队列中。为了简化这一点,当您调用 tf.train.shuffle_batch() 时,TensorFlow 会在图中的内部集合中添加一个 QueueRunner。稍后对tf.train.start_queue_runners()(例如here in cifar10_train.py)的调用将启动一个线程,将元素添加到队列中,并使训练继续进行。 Threading and Queues HOWTO 提供有关其工作原理的更多信息。

【讨论】:

  • 谢谢,这解决了很多问题。因此,排队的工作方式就像您首先创建一个事情将如何进行的流程,然后您只需对线程说 GO,它们就会开始运行并获取和处理来自文件名、记录或其他任何地方的数据。我想对了吗?
  • distorted_inputs() 函数有点微妙:它返回两个符号张量(imageslabels,每次求值时取不同的值)。因此,虽然该函数只被调用一次,但如果您运行多个步骤(例如,通过调用 sess.run([images, labels]),或通过将它们传递给使用队列运行器的操作,如 tf.train.shuffle_batch()),它们将从文件中获取后续记录(s)。
  • reading from files HOWTO 有一些关于预处理工作原理的详细信息。对连接两个文件的支持有些有限,但您可以使用 tf.train.shuffle_batch_join() 函数实现所需的功能。
  • 比这更令人困惑 - 当您启动“队列运行器”时,将为每个队列运行器创建调用 sess.run(enqueue_op) 的额外线程。正是对sess.run()这些 调用(以及不是 您自己代码中的调用)导致直到tf.train.shuffle_batch() 的操作得以执行。抱歉,如果这听起来很复杂:我们正在努力寻找简化所有这些的方法!
  • 好吧,开始更有意义了,我昨晚成功设置了训练和测试! (准确率高达 98%!)。我敢肯定很难平衡性能和易用性。到目前为止,tensorflow 比任何其他 DNN 框架都高出数英里,感谢您的辛勤工作。
猜你喜欢
  • 2017-02-12
  • 2016-04-16
  • 2018-08-31
  • 1970-01-01
  • 1970-01-01
  • 2018-06-16
  • 2017-12-04
  • 2017-03-09
  • 1970-01-01
相关资源
最近更新 更多