【Posted】: 2021-04-15 11:48:02
【Question】:
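I am trying to build a classification model with VGG16, but at the end of the project I get an error when plotting the confusion matrix. The code is given below.
The imported packages and modules are: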
import os
import keras
import numpy as np
import tensorflow as tf
from keras.models import Model, Sequential
import matplotlib.pyplot as plt
from keras.optimizers import Adam
from keras.applications import MobileNet
from sklearn.metrics import confusion_matrix
from keras.layers.core import Dense, Activation
from keras.metrics import categorical_crossentropy
from sklearn.model_selection import train_test_split
from keras.preprocessing.image import ImageDataGenerator
from keras.applications.mobilenet import preprocess_input
from tensorflow.keras.preprocessing import image_dataset_from_directory
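Note: for brevity, I have omitted the dataset-loading code (where train_batches, valid_batches and test_batches are created).
VGG16 is loaded as follows: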
vgg16_model = keras.applications.vgg16.VGG16()
vgg16_model.summary()
Now, define the model:
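model = Sequential()
# Copy every VGG16 layer except its final 1000-class ImageNet softmax
for layer in vgg16_model.layers[:-1]:
    model.add(layer)
# Freeze the copied layers so only the new output layer is trained
for layer in model.layers:
    layer.trainable = False
# New 2-class output layer (diseaseAffectedEggplant / freshEggplant)
model.add(Dense(2, activation='softmax'))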
Compile the model:
model.compile(Adam(lr=.0001), loss='categorical_crossentropy', metrics=['accuracy'])
Fit the model:
model.fit_generator(train_batches, steps_per_epoch=4, validation_data=valid_batches, validation_steps=4, epochs=10, verbose=2)
Now the confusion matrix:
test_imgs, test_labels = next(test_batches)
plots(test_imgs, titles=test_labels)
test_labels = test_labels[:,0]
predictions = model.predict_generator(test_batches, steps=1, verbose=0)
cm = confusion_matrix(test_labels, np.round(predictions[:,0]))
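For reference, the sketch below spells out what the confusion-matrix step computes, in plain Python so it runs without Keras. It assumes one-hot labels (as implied by `categorical_crossentropy` and the `[:,0]` slicing) and that class indices follow the alphanumeric order of the folder names, which is what `flow_from_directory` does when no `classes` list is passed; since the loading code is omitted, that ordering is an assumption. Under it, column 0 is diseaseAffectedEggplant and column 1 is freshEggplant, and taking the argmax of each row keeps that order. By contrast, `test_labels[:,0]` and `np.round(predictions[:,0])` encode diseaseAffectedEggplant as 1, so `confusion_matrix` would place it in row 1, the reverse of the order in `cm_plot_labels`. With NumPy, the equivalent would be `confusion_matrix(np.argmax(labels, axis=1), np.argmax(predictions, axis=1))`, applied to the one-hot labels returned by `next(test_batches)` before the `[:,0]` slice. The batch values below are made up for illustration.

```python
# Minimal sketch: turn one-hot labels and softmax outputs into class indices,
# then count them into a confusion matrix (rows = true class, columns =
# predicted class, the same layout sklearn.metrics.confusion_matrix uses).


def to_indices(rows):
    """Return the index of the largest entry in each row (argmax per sample)."""
    return [max(range(len(row)), key=row.__getitem__) for row in rows]


def confusion_counts(true_idx, pred_idx, n_classes):
    """Build an n_classes x n_classes matrix of (true, predicted) counts."""
    cm = [[0] * n_classes for _ in range(n_classes)]
    for t, p in zip(true_idx, pred_idx):
        cm[t][p] += 1
    return cm


# Hypothetical batch of 4 images, assuming column 0 = diseaseAffectedEggplant
# and column 1 = freshEggplant.
labels = [[1, 0], [0, 1], [1, 0], [0, 1]]
probs = [[0.9, 0.1], [0.3, 0.7], [0.4, 0.6], [0.2, 0.8]]

cm = confusion_counts(to_indices(labels), to_indices(probs), n_classes=2)
print(cm)  # → [[1, 1], [0, 2]]
```

Here one diseased image is misclassified as fresh, which shows up in row 0, column 1.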
This is where I get an error; please focus on the code below:
cm_plot_labels = ['diseaseAffectedEggplant','freshEggplant']
plot_confusion_matrix(cm, cm_plot_labels, title="Confusion Matrix")  # this is the line where I get the error
The error is as follows:
---------------------------------------------------------------------------
NameError Traceback (most recent call last)
<ipython-input-28-43b96d543746> in <module>()
1 cm_plot_labels = ['diseaseAffectedEggplant','freshEggplant']
----> 2 plot_confusion_matrix(cm, cm_plot_labels, title="Confusion Matrix")
NameError: name 'plot_confusion_matrix' is not defined
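The NameError means that no function called `plot_confusion_matrix` exists in the notebook session: neither Keras nor any of the imports above provides one with this signature. It is usually a small helper that tutorials define themselves before calling it. (scikit-learn 0.22 to 1.1 shipped a `sklearn.metrics.plot_confusion_matrix`, but it takes a fitted estimator plus `X` and `y` rather than a precomputed matrix, and it was removed in 1.2.) A minimal sketch of such a helper, using only matplotlib and NumPy, follows; its name and parameters simply mirror the failing call and are otherwise hypothetical. It has to be run before the line that calls it.

```python
import itertools

import matplotlib.pyplot as plt
import numpy as np


def plot_confusion_matrix(cm, classes, title="Confusion matrix", cmap="Blues"):
    """Draw a confusion matrix: rows are true labels, columns are predicted labels.

    Hypothetical helper (not part of Keras or scikit-learn); it must be defined
    in the notebook before it is called.
    """
    cm = np.asarray(cm)
    fig, ax = plt.subplots()
    im = ax.imshow(cm, interpolation="nearest", cmap=cmap)
    fig.colorbar(im, ax=ax)
    ax.set_title(title)

    # One tick per class, labelled in the same order as the matrix rows/columns.
    ticks = np.arange(len(classes))
    ax.set_xticks(ticks)
    ax.set_xticklabels(classes, rotation=45, ha="right")
    ax.set_yticks(ticks)
    ax.set_yticklabels(classes)

    # Write each count in its cell, white on dark cells for readability.
    thresh = cm.max() / 2.0
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        ax.text(j, i, str(cm[i, j]), ha="center", va="center",
                color="white" if cm[i, j] > thresh else "black")

    ax.set_ylabel("True label")
    ax.set_xlabel("Predicted label")
    fig.tight_layout()
    return ax
```

After defining it, the original call `plot_confusion_matrix(cm, cm_plot_labels, title="Confusion Matrix")` followed by `plt.show()` should draw the plot. The labels are only correct if `cm_plot_labels` lists the classes in the same order as the matrix rows. Alternatively, with scikit-learn 0.22 or later, `ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=cm_plot_labels).plot()` from `sklearn.metrics` draws a similar plot from a precomputed matrix without a custom helper.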
Tags: python machine-learning deep-learning classification vgg-net