【发布时间】:2019-10-15 14:31:49
【问题描述】:
您好,我正在为图像分类任务使用相关权重。我正在使用 Tensorflow 版本 1.14.0,我正在使用 mobilenetv1_050_224 来执行以下source 中的任务。
IMAGE_SHAPE = (400, 400)
n_classes = 10
classifier_url = 'https://tfhub.dev/google/imagenet/mobilenet_v1_050_224/classification/3'
base_model = hub.Module(classifier_url, tags=['train'])
base_model.trainable = False
classifier = tf.keras.Sequential([
hub.KerasLayer(base_model, input_shape=IMAGE_SHAPE+(3,)),
keras.layers.Dense(n_classes, activation='softmax')
])
#print (base_model.summary())
print (classifier.summary())
我训练了这个模型,并且能够使用迁移学习在我的数据集上获得良好的训练/验证准确性。以下是学习部分的代码。
train_datagen = keras.preprocessing.image.ImageDataGenerator(
rescale=1./255)
validation_datagen = keras.preprocessing.image.ImageDataGenerator(rescale=1./255)
train_generator = train_datagen.flow(
x = train_dataset,
y = train_labels,
batch_size=batch_size,
seed=1)
validation_generator = validation_datagen.flow(
x = validation_dataset, # Source directory for the validation images
y = valid_labels,
batch_size=batch_size,
seed=1)
classifier.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.01, beta_1=0.9, beta_2=0.999),
loss='categorical_crossentropy',
metrics=['accuracy'])
epochs = 2
steps_per_epoch = train_generator.n // batch_size
validation_steps = validation_generator.n // batch_size
model = classifier.fit_generator(train_generator,
steps_per_epoch = steps_per_epoch,
epochs=epochs,
workers=4,
validation_data=validation_generator,
validation_steps=validation_steps)
但是当我尝试保存模型时:
export_path = '/tmp/simple_keras_model.h5'
classifier.save(export_path, save_format='h5')
我收到以下错误:
NotImplementedError:只能为
hub.KerasLayer(handle, ...)使用字符串handle。得到
type(handle):
我被它困住了,无法绕开它。这方面的任何线索都会有所帮助。谢谢你。
【问题讨论】:
-
你能不能把路径中的\tmp\部分去掉,用
simple_keras_model.h5试试保存。
标签: python tensorflow