[Posted]: 2020-06-02 18:17:15
[Question]:
I'm working in Google Colab and trying to train a model built from VGG blocks, like this:
METRICS = [
    keras.metrics.TruePositives(name='tp'),
    keras.metrics.FalsePositives(name='fp'),
    keras.metrics.TrueNegatives(name='tn'),
    keras.metrics.FalseNegatives(name='fn'),
    keras.metrics.BinaryAccuracy(name='accuracy'),
    keras.metrics.Precision(name='precision'),
    keras.metrics.Recall(name='recall'),
    keras.metrics.AUC(name='auc'),
]
# function for creating a vgg block
def vgg_block(layer_in, n_filters, n_conv):
    # add convolutional layers
    for _ in range(n_conv):
        layer_in = Conv2D(n_filters, (3, 3), padding='same', activation='relu')(layer_in)
    # add max pooling layer
    layer_in = MaxPooling2D((2, 2), strides=(2, 2))(layer_in)
    return layer_in
# define model input
visible = Input(shape=(256, 256, 3))
# add vgg module
layer = vgg_block(visible, 64, 2)
#####################################
flat = Flatten()(layer)
hidden1 = Dense(128, activation='relu')(flat)
output = Dense(1, activation='sigmoid')(hidden1)
model = Model(inputs=visible, outputs=output)
# summary() prints the table itself and returns None, so no print() is needed
model.summary()
# plot model architecture
plot_model(model, show_shapes=True, to_file='vgg_block.png')
model.compile(loss='binary_crossentropy',
              optimizer='rmsprop',
              metrics=METRICS)
# New lines to keep the best model in terms of validation accuracy
from keras.callbacks import ModelCheckpoint
filepath="weights-improvement-{epoch:02d}-{val_accuracy:.2f}.h5"
checkpoint = ModelCheckpoint(filepath, monitor='val_accuracy', verbose=1, save_best_only=True, mode='max')
callbacks_list = [checkpoint]
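For reference, here is a small sketch of how I understand the `filepath` template to be filled in. As far as I know, `ModelCheckpoint` expands it with `str.format`, passing a 1-based epoch number and the logged metric values, so the `val_accuracy` placeholder relies on the metric being named `'accuracy'` in `METRICS`. The epoch and accuracy values below are hypothetical, chosen only for illustration.

```python
# Template string taken from the question.
filepath = "weights-improvement-{epoch:02d}-{val_accuracy:.2f}.h5"

# Hypothetical values; in training these would come from the epoch logs.
example_name = filepath.format(epoch=3, val_accuracy=0.876)
print(example_name)  # weights-improvement-03-0.88.h5
```

If the logs had no `val_accuracy` key (for example, if the metric were named differently), the same `format` call would raise a `KeyError`.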
However, when I try to call model.fit_generator, it raises an error. The code I use is:
history = model.fit_generator(
    train_generator,
    steps_per_epoch=2000 // batch_size,
    epochs=20,
    validation_data=validation_generator,
    validation_steps=800 // batch_size,
    callbacks=callbacks_list
)
I've tried everything I could think of, but I don't know what to do next. It gives me the following error:
NotFoundError: 2 root error(s) found.
(0) Not found: Resource localhost/total/N10tensorflow3VarE does not exist.
[[{{node metrics/accuracy/AssignAddVariableOp}}]]
[[metrics/precision/Mean/_87]]
(1) Not found: Resource localhost/total/N10tensorflow3VarE does not exist.
[[{{node metrics/accuracy/AssignAddVariableOp}}]]
0 successful operations.
0 derived errors ignored.
Any help would be appreciated. I'm new to this. What can I do? Thanks!
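One thing that may be worth ruling out: errors of this form are often reported when objects from the standalone `keras` package and from `tensorflow.keras` end up in the same model. The imports are not shown in full above, so this is only an assumption. The sketch below is a hypothetical helper (`keras_import_flavours` is not part of any library) that scans a cell's source and reports which of the two flavours it imports. In the sample cell, the first import line is a guess at how the `keras` name was obtained; only the second line appears in the code above.

```python
import ast


def keras_import_flavours(source):
    """Return which Keras flavours `source` imports.

    'keras' is the standalone package, 'tf.keras' is tensorflow.keras.
    """
    flavours = set()
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, ast.Import):
            names = [alias.name for alias in node.names]
        elif isinstance(node, ast.ImportFrom) and node.module and node.level == 0:
            names = [node.module + "." + alias.name for alias in node.names]
        else:
            continue
        for name in names:
            parts = name.split(".")
            if parts[0] == "keras":
                flavours.add("keras")
            elif parts[0] == "tensorflow" and "keras" in parts[1:]:
                flavours.add("tf.keras")
    return flavours


# Sample cell: the first line is an assumption, the second is from the question.
cell = """
from tensorflow import keras
from keras.callbacks import ModelCheckpoint
"""
print(sorted(keras_import_flavours(cell)))  # ['keras', 'tf.keras']
```

If both flavours show up, switching every import to a single one (for example, only `tensorflow.keras`) would be a reasonable first experiment.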
[Comments]:
Tags: python tensorflow keras google-colaboratory