【问题标题】:TensorFlow select Labels from mnist datasetTensorFlow 从 mnist 数据集中选择标签
【发布时间】:2017-06-14 06:32:50
【问题描述】:

我正在使用 tensorflow.examples.tutorials.mnist 来训练具有 5 个隐藏层的 nn。

这是我训练神经网络的方式:

with tf.Session() as sess:
init.run()
for epoch in range(n_epochs):
    for iteration in range(len(mnist.test.labels)//batch_size):
        X_batch, y_batch = mnist.train.next_batch(batch_size)
        sess.run(training_op, feed_dict={X: X_batch, y: y_batch})
    acc_train = accuracy.eval(feed_dict={X: X_batch, y: y_batch})
    acc_test = accuracy.eval(feed_dict={X: mnist.test.images, y: mnist.test.labels})
    print(epoch, "Train accuracy:", acc_train, "Test accuracy:", acc_test)

我想训练神经网络只识别从 0 到 4 的数字。我将 logits 层更改为有 5 个输出。

如何过滤 TensorFlow 提供的 mnist 数据集,以便仅获取 0 到 4 之间的数字?

【问题讨论】:

    标签: python tensorflow


    【解决方案1】:

    有很多方法可以做到这一点。其中之一是当您提取 X_batch, y_batch = mnist.train.next_batch(batch_size) 时。在此步骤中,您的y_batch 将获得有关数字值的信息(数字值或数字的 one-hot)。

    您迭代批处理中的示例并检查该数字是否是您关心的数字。如果是,请将其添加到您的cleaned_up_batch。效率不是很高,但它会起作用。


    回复评论:

    效率不高,因为您可能需要多次过滤相同的数据。我认为这不是问题,因为 MNIST 非常小。正常的做法是只过滤一次,创建一个新数据集并编写自己的函数以从中获取下一批(实际上非​​常简单,因为您只需从数据集中随机选择 k 个元素)

    【讨论】:

    • 谢谢你,这是一条路,但你是对的,效率不高,你知道其他方法吗?我在想我也许可以将一些参数传递给 next_batch 但找不到
    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 2022-01-01
    • 2018-09-11
    • 2018-09-01
    • 1970-01-01
    • 2013-10-28
    • 2019-10-07
    • 1970-01-01
    相关资源
    最近更新 更多