【发布时间】:2018-01-19 21:56:19
【问题描述】:
我有一个用make_one_shot_iterator() 创建的tf.data.Iterator,并想用它来训练我的(现有)模型。
目前我的训练是这样的
input_node = tf.placeholder(tf.float32, shape=(None, height, width, channels))
net = models.ResNet50UpProj({'data': input_node}, batch_size, keep_prob=True,is_training=True)
labels = tf.placeholder(tf.float32, shape=(None, width, height, 1))
huberloss = tf.losses.huber_loss(predictions=net.get_output(),labels=labels)
然后调用
sess.run(train_op, feed_dict={labels:output_img, input_node:input_img})
训练后我可以得到这样的预测:
pred = sess.run(net.get_output(), feed_dict={input_node: img})
现在我用迭代器尝试了类似的方法
next_element = iterator.get_next()
像这样传递输入数据:
net = models.ResNet50UpProj({'data': next_element[0]}, batch_size, keep_prob=True,is_training=True)
这样定义损失函数:
huberloss = tf.losses.huber_loss(predictions=net.get_output(),labels=next_element[1])
并像在每次调用 this 时自动迭代迭代器一样简单地执行训练:
sess.run(train_op)
我的问题是:训练后我无法做出任何预测。或者更确切地说,我不知道在我的情况下使用迭代器的正确方法。
【问题讨论】:
标签: python tensorflow tensorflow-datasets