【发布时间】:2017-03-30 06:02:46
【问题描述】:
我正在尝试基于来自 Tensorflow 的 CIFAR10 示例创建一个火车操作,该示例使用 tf.RandomShuffleQueue,我的标签来自 (Accessing filename from file queue in Tensor Flow) 中提到的文件名。我该如何使用此代码?
当我尝试运行以下代码时,path 是一个包含许多文件的目录:
filenames = [path, f) for f in os.listdir(path)][1:]
file_fifo = tf.train.string_input_producer(filenames,
shuffle=False,
capacity=len(filenames))
reader = tf.WholeFileReader()
key, value = reader.read(file_fifo)
image = tf.image.decode_png(value, channels=3, dtype=tf.uint8)
image.set_shape([config.image_height, config.image_width, config.image_depth])
image = tf.cast(image, tf.float32)
image = tf.divide(image, 255.0)
labels = [int(os.path.basename(f).split('_')[-1].split('.')[0]) for f in filenames]
label_fifo = tf.FIFOQueue(len(filenames), tf.int32, shapes=[[]])
label_enqueue = label_fifo.enqueue_many([tf.constant(labels)])
label = label_fifo.dequeue()
bq = tf.RandomShuffleQueue(capacity=16 * batch_size,
min_after_dequeue=8 * batch,
dtypes=[tf.float32, tf.int32])
batch_enqueue_op = bq.enqueue([image, label_enqueue])
runner = tf.train.queue_runner.QueueRunner(bq, [batch_enqueue_op] * num_threads)
tf.train.add_queue_runner(runner)
# Read 'batch' labels + images from the example queue.
images, labels = batch_queue.dequeue_many(FLAGS.batch_size)
labels = tf.reshape(labels, [FLAGS.batch_size, 1])
我得到了明显的错误,因为我知道我的代码没有多大意义。我有两个 FIFO 队列 @987654325@ 和 label_fifo,但我不知道如何使我的 label_fifo 输入 tf.RandomShuffleQueue。
有人可以帮忙吗?谢谢你:-)
【问题讨论】:
标签: tensorflow