【问题标题】:What is going on tf.train.shuffle_batch()tf.train.shuffle_batch() 发生了什么
【发布时间】:2018-10-15 17:06:58
【问题描述】:

我对 tensorflow 框架非常陌生,我尝试使用这段代码来阅读和探索 CIFAR-10 数据集。

import tensorflow as tf
import numpy as np
import os
import matplotlib.pyplot as plt

sess=tf.Session()

batch_size = 128
output_every = 50
generations = 20000
eval_every = 500
image_height = 32
image_width = 32
crop_height = 24
crop_width = 24
num_channels = 3
num_targets = 10
data_dir="CIFAR10"


image_vec_length = image_height * image_width * num_channels
record_length = 1 + image_vec_length

def read_cifar_files(filename_queue, distort_images = True):
   reader = tf.FixedLengthRecordReader(record_bytes=record_length*10)
   key, record_string = reader.read(filename_queue)
   record_bytes = tf.decode_raw(record_string, tf.uint8)

# Extract label
   image_label = tf.cast(tf.slice(record_bytes, [image_vec_length-1],[1]),tf.int32)

# Extract image
   sliced=tf.slice(record_bytes, [0],[image_vec_length])
   image_extracted = tf.reshape(sliced, [num_channels, image_height,image_width])

# Reshape image
   image_uint8image = tf.transpose(image_extracted, [1, 2, 0])
   reshaped_image = tf.cast(image_uint8image, tf.float32)

# Randomly Crop image
   final_image = tf.image.resize_image_with_crop_or_pad(reshaped_image, crop_width, crop_height)
   if distort_images:

# Randomly flip the image horizontally, change the brightness and contrast
     final_image = tf.image.random_flip_left_right(final_image)
     final_image = tf.image.random_brightness(final_image,max_delta=63)
     final_image = tf.image.random_contrast(final_image,lower=0.2, upper=1.8)

# standardization
     final_image = tf.image.per_image_standardization(final_image)
     return  final_image, image_label

当我在没有 tf.train.shuffle_batch() 的情况下运行以下 input_pipeline() 函数时,它会给我一个形状为 (24,24,3) 的图像张量。

def input_pipeline(batch_size, train_logical=True):
    files=[os.path.join(data_dir,"data_batch_{}.bin".format(i)) for i in range(1,6)]
    filename_queue = tf.train.string_input_producer(files)
    image,label = read_cifar_files(filename_queue)
    return(image,label)


example_batch,label_batch=input_pipeline(batch_size)
threads = tf.train.start_queue_runners(sess=sess)
img,label=sess.run([example_batch, label_batch])

#output=(24,24,3) 
print(img.shape) 

但是当我使用 tf.train.shuffle_batch() 函数运行相同的 input_pipeline() 函数时,它会给出包含 128 个形状为 (128、24、24、3) 的图像的图像张量。

def input_pipeline(batch_size, train_logical=True):
    files=[os.path.join(data_dir,"data_batch_{}.bin".format(i)) for i in range(1,6)]
    filename_queue = tf.train.string_input_producer(files)
    image,label = read_cifar_files(filename_queue)

    min_after_dequeue = 1000
    capacity = min_after_dequeue + 3 * batch_size
    example_batch, label_batch = tf.train.shuffle_batch([image,label], batch_size, capacity, min_after_dequeue)
    return(example_batch, label_batch)

这怎么可能。似乎 tf.train.shuffle_batch() 从 read_cifar_files() 获取单个图像张量并返回具有 128 个图像的张量。那么 tf.train.shuffle_batch() 函数发生了什么。

【问题讨论】:

    标签: python-3.x tensorflow


    【解决方案1】:

    在 Tensorflow 中,张量只是图的一个节点。 tf.train.shuffle_batch() 函数将 2 个节点作为输入,这些节点通过图形连接到数据。

    所以它不会将“单个图像”作为输入,而是一个能够加载图像的图形。然后它向图中添加一个新操作,该操作将执行 n = batch_size 时间输入图,打乱批次并返回大小为 [bach_size, input_shape] 的输出张量。

    然后,当您在会话中运行该函数时,数据将根据图表加载,这意味着您每次调用 tf.train.shuffle_batch() 时,都会读取磁盘上的 n = batch_size 图像。

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2016-02-28
      • 2014-02-28
      • 2010-10-02
      • 2020-03-14
      • 2011-04-09
      • 1970-01-01
      • 1970-01-01
      • 2014-05-09
      相关资源
      最近更新 更多