【发布时间】:2017-09-28 04:05:00
【问题描述】:
我的 CNN 项目基于 AlexNet 模型实现 here。
我有两个主要功能,training 和 prediction,我想问你关于预测部分的指标,它们从与训练集相比不同目录中的测试集读取图像。
这是prediction 代码:
def prediction(self):
with tf.Session() as sess:
# Construct model
pred = self.alex_net_model(self.img_pl, self.weights, self.biases, self.keep_prob)
# Restore model.
ckpt = tf.train.get_checkpoint_state("ckpt_dir")
if(ckpt):
self.saver.restore(sess, MODEL_CKPT)
print "Model restored"
else:
print "No model checkpoint found to restore - ERROR"
return
### Metrics ###
y_p = tf.argmax(pred,1) # the value predicted
target_names = ['class 0', 'class 1', 'class 2']
list_pred_total = []
list_true_total = []
# Accuracy Precision Recall F1-score by TEST IMAGES
for step, elems in enumerate(self.BatchIteratorTesting(BATCH_SIZE)):
batch_imgs_test, batch_labels_test = elems
y_pred = sess.run([y_p], feed_dict={self.img_pl: batch_imgs_test, self.keep_prob: 1.0})
#print(len(y_pred))
list_pred_total.extend(y_pred)
y_true = np.argmax(batch_labels_test,1)
#print(len(y_true))
list_true_total.extend(y_true)
#### TODO: METRICS FOR PRECISION RECALL F1-SCORE ####
我的问题是:
- 我怎样才能像在
training中那样正确调用classification_report? - 为什么
y_pred是 1 个元素的列表,而y_true是一个 len 64(批量大小)的 numpy 数组?
如果这两个 len 不同,我不能做metrics.classification_report(list_true_total, list_pred_total, target_names=target_names)。
希望能解决我的疑惑。
【问题讨论】:
标签: python numpy tensorflow scikit-learn