【问题标题】:TensorFlow image reading queue emptyTensorFlow 图像读取队列为空
【发布时间】:2017-07-05 16:08:35
【问题描述】:

我正在尝试使用管道将图像读取到 CNN。我使用string_input_producer() 来获取文件名队列,但它似乎没有做任何事情就挂在那里。下面是我的代码,请给我一些建议如何使它工作。

def read_image_file(filename_queue, labels):
    reader = tf.WholeFileReader()
    key, value = reader.read(filename_queue)
    image = tf.image.decode_png(value, channels=3)
    image = tf.cast(image, tf.float32)
    resized_image = tf.image.resize_images(image, [224, 112])
    with tf.Session() as sess:
        label = getLabel(labels, key.eval())
    return resized_image, label

def input_pipeline(filename_queue, queue_names, batch_size, num_epochs, labels):
    image, label = read_image_file(filename_queue, labels)
    min_after_dequeue = 10 * batch_size
    capacity = 20 * batch_size
    image_batch, label_batch = tf.train.shuffle_batch(
        [image, label], batch_size=batch_size, num_threads=1, capacity=capacity,
        min_after_dequeue=min_after_dequeue)
    return image_batch, label_batch

train_queue = tf.train.string_input_producer(trainnames, shuffle=True, num_epochs=epochs)

train_batch, train_label = input_pipeline(train_queue, trainnames, batch_size, epochs, labels)

prediction = AlexNet(x)

#Training
with tf.name_scope("cost_function") as scope:
    cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=train_label, logits=prediction(train_batch)))
    tf.summary.scalar("cost_function", cost)

    train_step = tf.train.MomentumOptimizer(learning_rate, 0.9).minimize(cost)

#Accuracy
with tf.name_scope("accuracy") as scope:
    correct_prediction = tf.equal(tf.argmax(prediction,1), tf.argmax(y_,1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    tf.summary.scalar("accuracy", accuracy)

    merged = tf.summary.merge_all()

#Session
with tf.Session() as sess:
    print('started')
    sess.run(tf.global_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord, start=True)
    sess.run(threads)

    try:
        for step in range(steps_per_epch * epochs):
            print('step: %d' %step)
            sess.run(train_step)
    except tf.errors.OutOfRangeError as ex:
        pass

    coord.request_stop()
    coord.join(threads)

【问题讨论】:

    标签: machine-learning tensorflow computer-vision


    【解决方案1】:

    您的代码并非完全独立,因为未定义 get_label 方法。

    但是您遇到的问题很可能来自read_image_file 方法中的这些行:

    with tf.Session() as sess:
        label = getLabel(labels, key.eval())
    

    key.eval 部分尝试将尚未开始的队列元素出列。 在定义输入管道之前,您不应创建任何会话(也不应尝试评估 key(可能还有 labels))。 get_label 方法应该只对labelskey 执行张量操作并返回一个label 张量..

    例如,您可以使用这些tensor string operations,以便它们成为图表的一部分。

    【讨论】:

    • 我在我的代码中定义了getLabel,但没有在此处附加,它基本上是从文件名(字符串)中提取标签,但key是一个张量。所以我做了key.eval() 来获取文件名的字符串。现在好像不行了,有没有其他方法可以从张量中获取字符串?
    • 您可能需要将所有字符串操作替换为字符串张量操作,因此它们将成为图表的一部分并在运行时执行。
    • 不知什么原因无法打开tensorflow的网站。你知道我可以使用哪些操作从字符串张量中提取标签吗?
    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2016-03-24
    • 1970-01-01
    • 2015-01-03
    • 2017-12-09
    • 2012-12-08
    • 2016-04-11
    相关资源
    最近更新 更多