Posted: 2019-01-13 10:56:28
Question:
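I am trying to save and restore the weights of a given model in Keras. Saving works with model.save_weights(filepath, ...), and the weights really are loaded again. I confirmed this by writing model.get_weights() to a file once after saving and once after restoring, then comparing the two files.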
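The check described above can be sketched with NumPy alone. This is a minimal sketch. It assumes the weights are a list of NumPy arrays, which is what model.get_weights() returns. The helper names dump_weights, load_dumped_weights and weights_match are hypothetical, and random arrays stand in for a real model.

```python
import os
import tempfile

import numpy as np


def dump_weights(weights, path):
    """Write a list of arrays (e.g. from model.get_weights()) to an .npz file."""
    np.savez(path, *weights)


def load_dumped_weights(path):
    """Read the arrays back in their original order."""
    with np.load(path) as data:
        # np.savez stores positional arrays under the names arr_0, arr_1, ...
        return [data["arr_%d" % i] for i in range(len(data.files))]


def weights_match(a, b):
    """True if both lists have the same length and pairwise identical arrays."""
    return len(a) == len(b) and all(np.array_equal(x, y) for x, y in zip(a, b))


# Stand-ins for model.get_weights() taken after saving and after restoring.
rng = np.random.default_rng(0)
before = [rng.standard_normal((3, 3)), rng.standard_normal(3)]
path = os.path.join(tempfile.mkdtemp(), "weights_before.npz")
dump_weights(before, path)
after = load_dumped_weights(path)
print(weights_match(before, after))  # True
```

If the arrays match but the agent still behaves like an untrained one, the difference lies outside the stored layer weights.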
Even so, the restored model performs just as badly as it did at the very start of training. Am I missing something?
def __init__(self, **args):
    # Next, we build our model. We use the same model that was described by Mnih et al. (2015).
    self.model.add(Convolution2D(32, (3, 3), strides=(1, 1)))
    self.model.add(Activation('relu'))
    self.model.add(Convolution2D(64, (3, 3), strides=(1, 1)))
    self.model.add(Activation('relu'))
    self.model.add(Convolution2D(64, (3, 3), strides=(1, 1)))
    self.model.add(Activation('relu'))
    self.model.add(Flatten())
    self.model.add(Dense(512))
    self.model.add(Activation('relu'))
    self.model.add(Dense(self.nb_actions))  # nb_actions))
    self.model.add(Activation('linear'))
    print(self.model.summary())
    if os.path.isfile("/home/abcd/model.weights"):
        self.model.load_weights("/home/abcd/model.weights")
    self.compile(Adam(lr=.00025), metrics=['mae'])
...
def compile(self, optimizer, metrics=[]):
    metrics += [mean_q]  # register default metrics
    # We never train the target model, hence we can set the optimizer and loss arbitrarily.
    self.target_model = clone_model(self.model, self.custom_model_objects)
    if os.path.isfile("/home/abcd/target_model.weights"):
        self.target_model.load_weights("/home/abcd/target_model.weights")
    self.target_model.compile(optimizer='sgd', loss='mse')
    self.model.compile(optimizer='sgd', loss='mse')
    # Compile model.
    if self.target_model_update < 1.:
        # We use the `AdditionalUpdatesOptimizer` to efficiently soft-update the target model.
        updates = get_soft_target_model_updates(self.target_model, self.model, self.target_model_update)
        optimizer = AdditionalUpdatesOptimizer(optimizer, updates)

    def clipped_masked_error(args):
        y_true, y_pred, mask = args
        loss = huber_loss(y_true, y_pred, self.delta_clip)
        loss *= mask  # apply element-wise mask
        return K.sum(loss, axis=-1)

    # Create trainable model. The problem is that we need to mask the output since we only
    # ever want to update the Q values for a certain action. The way we achieve this is by
    # using a custom Lambda layer that computes the loss. This gives us the necessary flexibility
    # to mask out certain parameters by passing in multiple inputs to the Lambda layer.
    y_pred = self.model.output
    y_true = Input(name='y_true', shape=(self.nb_actions,))
    mask = Input(name='mask', shape=(self.nb_actions,))
    loss_out = Lambda(clipped_masked_error, output_shape=(1,), name='loss')([y_true, y_pred, mask])
    ins = [self.model.input] if type(self.model.input) is not list else self.model.input
    trainable_model = Model(inputs=ins + [y_true, mask], outputs=[loss_out, y_pred])
    assert len(trainable_model.output_names) == 2
    combined_metrics = {trainable_model.output_names[1]: metrics}
    losses = [
        lambda y_true, y_pred: y_pred,  # loss is computed in Lambda layer
        lambda y_true, y_pred: K.zeros_like(y_pred),  # we only include this for the metrics
    ]
    if os.path.isfile("/home/abcd/trainable_model.weights"):
        trainable_model.load_weights("/home/abcd/trainable_model.weights")
    trainable_model.compile(optimizer=optimizer, loss=losses, metrics=combined_metrics)
    self.trainable_model = trainable_model
    self.compiled = True
...
def final(self, state):
    "Called at the end of each game."
    # call the super-class final method
    PacmanQAgent.final(self, state)
    # did we finish training?
    if self.episodesSoFar == self.numTraining:
        # you might want to print your weights here for debugging
        "*** YOUR CODE HERE ***"
        self.training = False
        # Save the model
        self.model.save_weights("/home/abcd/model.weights", True)
        self.trainable_model.save_weights("/home/abcd/trainable_model.weights", True)
        self.target_model.save_weights("/home/abcd/target_model.weights", True)
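For readers less familiar with keras-rl, the two pieces of arithmetic inside compile can be mirrored in NumPy. This is a sketch, not keras-rl's own code. It assumes huber_loss is a standard Huber loss with threshold delta_clip: quadratic below the threshold, linear above it. It also assumes the soft target update blends the weights as tau * online + (1 - tau) * target. The function names huber, clipped_masked_error and soft_update are illustrative.

```python
import numpy as np


def huber(y_true, y_pred, delta):
    """Element-wise Huber loss: quadratic for |error| < delta, linear beyond."""
    err = np.abs(y_true - y_pred)
    return np.where(err < delta, 0.5 * err ** 2, delta * (err - 0.5 * delta))


def clipped_masked_error(y_true, y_pred, mask, delta):
    """Mirror of the Lambda-layer loss: only masked (taken) actions contribute."""
    return np.sum(huber(y_true, y_pred, delta) * mask, axis=-1)


def soft_update(target_weights, online_weights, tau):
    """Move each target array a fraction tau toward the matching online array."""
    return [tau * w + (1.0 - tau) * t for t, w in zip(target_weights, online_weights)]


# One sample, three actions; only action 1 was taken, so only it is penalised.
y_true = np.array([[0.0, 2.0, 0.0]])
y_pred = np.array([[5.0, 0.5, -3.0]])
mask = np.array([[0.0, 1.0, 0.0]])
print(clipped_masked_error(y_true, y_pred, mask, delta=1.0))  # [1.]
```

The large errors on actions 0 and 2 contribute nothing because their mask entries are 0. That is why the Lambda layer receives the mask as a separate input.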
Discussion:
-
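This shouldn't be happening. Can you make sure you are loading the weights correctly, as described in the docs at keras.io/getting-started/faq/…? Otherwise, please share your code so others can see the problem.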
-
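@AnkishBansal Yes, I follow it, except that I use absolute paths and the extension ".weights" instead of ".h5". Could that be a problem? I load the weights before compiling the model. Is that the right way to do it? The whole system consists of 3 models, which you can see in the compile function here: github.com/keras-rl/keras-rl/blob/master/rl/agents/dqn.py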
-
@Skusku Please show the relevant code; otherwise it's hard to tell.
-
@a_guest I have added the code as requested.
-
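The only problem I can see is that you are using a .weights extension, which I have never seen before. I may be wrong about this, but have you tried the .h5 extension?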
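One way to test the extension theory directly is to check whether the saved file is HDF5, whatever its name. This is a minimal plain-Python sketch. It assumes the file was written by an HDF5-based save_weights with no user block, so the standard 8-byte HDF5 signature sits at offset 0. looks_like_hdf5 is a hypothetical helper. A file with a user block, where the signature sits at offset 512, 1024, ..., would be misreported.

```python
import os
import tempfile

# The 8-byte signature that starts an HDF5 file (when no user block is present).
HDF5_SIGNATURE = b"\x89HDF\r\n\x1a\n"


def looks_like_hdf5(path):
    """Return True if the file at `path` begins with the HDF5 signature."""
    with open(path, "rb") as f:
        return f.read(len(HDF5_SIGNATURE)) == HDF5_SIGNATURE


# Demo with fake files; in the question this would be "/home/abcd/model.weights".
tmp = tempfile.mkdtemp()
fake_h5 = os.path.join(tmp, "model.weights")
with open(fake_h5, "wb") as f:
    f.write(HDF5_SIGNATURE + b"\x00" * 8)
not_h5 = os.path.join(tmp, "notes.txt")
with open(not_h5, "wb") as f:
    f.write(b"plain text")
print(looks_like_hdf5(fake_h5), looks_like_hdf5(not_h5))  # True False
```

If h5py is installed, h5py.is_hdf5(path) performs the same check more thoroughly.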
Tags: keras