【问题标题】:Why the labels from tensorflow's shuffle do not work?为什么 tensorflow 的 shuffle 中的标签不起作用?
【发布时间】:2017-01-10 18:17:33
【问题描述】:

每个人! 从 shuffle 我得到图像和标签,带有 conv 的图像效果很好。

来自 tfreocords 的图像和标签

...
train_images, train_labels = shuffle(train_all_images, train_all_labels)
...

但是train_labels 不能如下工作:

numpy.sum(numpy.argmax(predictions, 1) ==  train_labels)

结果总是错误的,因为它根本无法从train_labels 获取值。

一些细节如下:

train_all_images, train_all_labels = read_and_decode("train")

train_images, train_labels = shuffle(train_all_images, train_all_labels)

......一些训练模型

optimizer = tf.train.MomentumOptimizer(learning_rate,
                                       0.9).minimize(loss,
                                                     global_step=batch)
train_prediction = tf.nn.softmax(logits)

with tf.Session() as sess:
    tf.global_variables_initializer().run()
    tf.train.start_queue_runners(sess)
    print('Initialized!')

    for step in xrange(int(num_epochs * train_size) // BATCH_SIZE):
        sess.run(optimizer)
        if step % EVAL_FREQUENCY == 0:              
            l, lr, predictions = sess.run([loss, learning_rate, train_prediction])

            print('Minibatch loss: %.3f, learning rate: %.6f' % (l, lr))
            print('Minibatch error: %.1f%%' % error_rate(predictions, train_labels))
            sys.stdout.flush()

def error_rate(predictions, labels):
    return 100.0 - ( 100.0 *
    numpy.sum(numpy.argmax(predictions, 1) == labels) /
    predictions.shape[0])

【问题讨论】:

  • 您能否提供一个完整的、可重现的无效代码示例?从这两行来看,目前尚不清楚发生了什么(例如,sess.run() 你在打什么电话等)。
  • 嗨,伙计。我又修改了我的帖子。请再检查一遍好吗?关键是如何获取 shuffle 后 train_labels 的值?这很混乱!非常感谢!
  • 如何调用error_rate() 函数?是否有有助于查找问题的错误消息?
  • error_rate() 仅用于计算预测错误率

标签: python numpy tensorflow


【解决方案1】:

原因是你必须使用 tensoflow 方法而不是 numpy,下面的 accaray 效果很好。

def correct_rate(out, labels):
  arg = tf.argmax(out, 1)
  arg = tf.cast(arg, tf.int32)
  correct_prediction = tf.equal(labels, arg)
  accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  return accuracy

accr = correct_rate(logits, train_labels)
print(sess.run(accr))

【讨论】:

    猜你喜欢
    • 2021-01-29
    • 2020-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2016-03-25
    • 2017-12-06
    • 2018-03-22
    • 2018-05-05
    相关资源
    最近更新 更多