【发布时间】:2018-10-15 07:18:38
【问题描述】:
我正在使用新的 tensoflow 输入管道准备我的数据集,这是我的代码:
train_data = tf.data.Dataset.from_tensor_slices(train_images)
train_labels = tf.data.Dataset.from_tensor_slices(train_labels)
train_set = tf.data.Dataset.zip((train_data,train_labels)).shuffle(500).batch(30)
valid_data = tf.data.Dataset.from_tensor_slices(valid_images)
valid_labels = tf.data.Dataset.from_tensor_slices(valid_labels)
valid_set = tf.data.Dataset.zip((valid_data,valid_labels)).shuffle(200).batch(20)
test_data = tf.data.Dataset.from_tensor_slices(test_images)
test_labels = tf.data.Dataset.from_tensor_slices(test_labels)
test_set = tf.data.Dataset.zip((test_data, test_labels)).shuffle(200).batch(20)
# create general iterator
iterator = tf.data.Iterator.from_structure(train_set.output_types, train_set.output_shapes)
next_element = iterator.get_next()
train_init_op = iterator.make_initializer(train_set)
valid_init_op = iterator.make_initializer(valid_set)
test_init_op = iterator.make_initializer(test_set)
现在我想在训练后为我的 CNN 模型的验证集创建一个混淆矩阵,这是我尝试做的:
sess.run(valid_init_op)
valid_img, valid_label = next_element
finalprediction = tf.argmax(train_predict, 1)
actualprediction = tf.argmax(valid_label, 1)
confusion_matrix = tf.confusion_matrix(labels=actualprediction,predictions=finalprediction,
num_classes=num_classes,dtype=tf.int32,name=None, weights=None)
print(sess.run(confusion_matrix, feed_dict={keep_prob: 1.0}))
以这种方式它会创建混淆矩阵,但仅适用于一批验证集。为此,我尝试收集列表中的所有验证集批次,然后使用该列表创建混淆矩阵:
val_label_list = []
sess.run(valid_init_op)
for i in range(valid_iters):
while True:
try:
elem = sess.run(next_element[1])
val_label_list.append(elem)
except tf.errors.OutOfRangeError:
print("End of append.")
break
val_label_list = np.array(val_label_list)
val_label_list = val_label_list.reshape(40,2)
现在val_label_list 包含我的验证集所有批次的标签,我可以使用它来创建混淆矩阵:
finalprediction = tf.argmax(train_predict, 1)
actualprediction = tf.argmax(val_label_list, 1)
confusion = tf.confusion_matrix(labels=actualprediction,predictions=finalprediction,
num_classes=num_classes, dtype=tf.int32,name="Confusion_Matrix")
但是现在当我想运行混淆矩阵并打印它时:
print(sess.run(confusion, feed_dict={keep_prob: 1.0}))
它给了我一个错误:
OutOfRangeError: End of sequence
[[Node: IteratorGetNext_5 = IteratorGetNext[output_shapes=[[?,10,32,32], [?,2]], output_types=[DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](Iterator_5)]]
谁能告诉我如何处理这个错误?或任何其他解决我原来问题的解决方案?
【问题讨论】:
标签: python-3.x tensorflow