【问题标题】:Tensorflow - Using batching to make predictionsTensorflow - 使用批处理进行预测
【发布时间】:2017-01-12 11:39:33
【问题描述】:

我正在尝试使用经过训练的卷积神经网络进行预测,该网络对示例专家 tensorflow 教程中的示例稍作修改。我已按照https://www.tensorflow.org/versions/master/how_tos/reading_data/index.html 的说明从 CSV 文件中读取数据。

我已经训练了模型并评估了它的准确性。然后我保存了模型并将其加载到一个新的 python 脚本中以进行预测。我仍然可以使用上面链接中详述的批处理方法还是应该使用feed_dict 代替?我在网上看到的大多数教程都使用后者。

我的代码如下所示,我基本上复制了用于从训练数据中读取的代码,这些数据以行的形式存储在单个 .csv 文件中。 Conv_nn 只是一个包含专家 MNIST 教程中详细介绍的卷积神经网络的类。除了我运行图表的部分之外,大部分内容可能不是很有用。

我怀疑我严重混淆了训练和预测 - 我不确定测试图像是否正确输入到预测操作中,或者对两个数据集使用相同的批处理操作是否有效。

filename_queue =  tf.train.string_input_producer(["data/test.csv"],num_epochs=None)

reader = tf.TextLineReader()
key, value = reader.read(filename_queue)

# Defaults force key value and label to int, all others to float.
record_defaults = [[1]]+[[46]]+[[1.0] for i in range(436)]
# Reads in a single row from the CSV and outputs a list of scalars.
csv_list = tf.decode_csv(value, record_defaults=record_defaults)
# Packs the different columns into separate feature tensors.
location = tf.pack(csv_list[2:4])
bbox = tf.pack(csv_list[5:8])
pix_feats = tf.pack(csv_list[9:])
onehot = tf.one_hot(csv_list[1], depth=98)
keep_prob = 0.5


# Creates batches of images and labels.
image_batch, label_batch = tf.train.shuffle_batch(
    [pix_feats, onehot],
    batch_size=50,num_threads=4,capacity=50000,min_after_dequeue=10000)

# Creates a graph of variables and operation nodes.
nn = Conv_nn(x=image_batch,keep_prob=keep_prob,pixels=33*13,outputs=98)

# Launch the default graph.
with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    saver.restore(sess, 'model1.ckpt')
    print("Model restored.")

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess,coord=coord)

    prediction=tf.argmax(nn.y_conv,1)

    pred = sess.run([prediction])

    coord.request_stop()
    coord.join(threads)

【问题讨论】:

  • 我找到了一个(可疑的)解决方案。只需将tf.train.shuffle_batch 更改为tf.train.batch 并将batch_size 设置为您要预测的数据集的大小。这给出了预测标签的1 x batch_size numpy 数组。如果有人发现此解决方案有任何问题或更好的方法,请随时发布。

标签: python machine-learning tensorflow


【解决方案1】:

这个问题很老了,但我还是要回答,因为它已经被浏览了近 1000 次。

因此,如果您的模型有 Y 标签和 X 输入,那么

prediction=tf.argmax(Y,1)
result = prediction.eval(feed_dict={X: [data]}, session=sess)

这会评估单个输入,例如单个 mnist 图像,但它可以是批处理。

【讨论】:

    猜你喜欢
    • 2018-11-23
    • 1970-01-01
    • 2020-01-07
    • 1970-01-01
    • 2018-02-15
    • 1970-01-01
    • 1970-01-01
    • 2020-09-25
    • 2021-02-13
    相关资源
    最近更新 更多