【Question title】: Model and weights do not load from checkpoint
【Published】: 2020-08-24 21:15:12
【Question description】:

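I am training a reinforcement learning model on the OpenAI Gym cartpole environment. The .h5 files for my weights and my model both appear in the target directory, yet running tf.train.get_checkpoint_state("C:/Users/dgt/Documents") returns None.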

Here is my full code:

## Slightly modified from the following repository - https://github.com/gsurma/cartpole

from __future__ import absolute_import, division, print_function, unicode_literals

import os
import random
import gym
import numpy as np
import tensorflow as tf

from collections import deque
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint


ENV_NAME = "CartPole-v1"

GAMMA = 0.95
LEARNING_RATE = 0.001

MEMORY_SIZE = 1000000
BATCH_SIZE = 20

EXPLORATION_MAX = 1.0
EXPLORATION_MIN = 0.01
EXPLORATION_DECAY = 0.995

checkpoint_path = "training_1/cp.ckpt"


class DQNSolver:

    def __init__(self, observation_space, action_space):
        # save_dir = args.save_dir
        # self.save_dir = save_dir
        # if not os.path.exists(save_dir):
        #     os.makedirs(save_dir)
        self.exploration_rate = EXPLORATION_MAX

        self.action_space = action_space
        self.memory = deque(maxlen=MEMORY_SIZE)

        self.model = Sequential()
        self.model.add(Dense(24, input_shape=(observation_space,), activation="relu"))
        self.model.add(Dense(24, activation="relu"))
        self.model.add(Dense(self.action_space, activation="linear"))
        self.model.compile(loss="mse", optimizer=Adam(learning_rate=LEARNING_RATE))

    def remember(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    def act(self, state):
        if np.random.rand() < self.exploration_rate:
            return random.randrange(self.action_space)
        q_values = self.model.predict(state)
        return np.argmax(q_values[0])

    def experience_replay(self):
        if len(self.memory) < BATCH_SIZE:
            return
        batch = random.sample(self.memory, BATCH_SIZE)
        for state, action, reward, state_next, terminal in batch:
            q_update = reward
            if not terminal:
                q_update = (reward + GAMMA * np.amax(self.model.predict(state_next)[0]))
            q_values = self.model.predict(state)
            q_values[0][action] = q_update
            self.model.fit(state, q_values, verbose=0)
        self.exploration_rate *= EXPLORATION_DECAY
        self.exploration_rate = max(EXPLORATION_MIN, self.exploration_rate)


def cartpole():
    env = gym.make(ENV_NAME)
    #score_logger = ScoreLogger(ENV_NAME)
    observation_space = env.observation_space.shape[0]
    action_space = env.action_space.n
    dqn_solver = DQNSolver(observation_space, action_space)
    
    checkpoint = tf.train.get_checkpoint_state("C:/Users/dgt/Documents")
    print('checkpoint:', checkpoint)
    if checkpoint and checkpoint.model_checkpoint_path:
        dqn_solver.model = keras.models.load_model('cartpole.h5')
        dqn_solver.model = model.load_weights('cartpole_weights.h5')        
    run = 0
    i = 0
    while i<2:
        i = i + 1
        #total = 0
        run += 1
        state = env.reset()
        state = np.reshape(state, [1, observation_space])
        step = 0
        while True:
            step += 1
            #env.render()
            action = dqn_solver.act(state)
            state_next, reward, terminal, info = env.step(action)
            #total += reward
            reward = reward if not terminal else -reward
            state_next = np.reshape(state_next, [1, observation_space])
            dqn_solver.remember(state, action, reward, state_next, terminal)
            state = state_next
            dqn_solver.model.save('cartpole.h5')
            dqn_solver.model.save_weights('cartpole_weights.h5')
            if terminal:
                print("Run: " + str(run) + ", exploration: " + str(dqn_solver.exploration_rate) + ", score: " + str(step))
                #score_logger.add_score(step, run)
                break
            dqn_solver.experience_replay()


if __name__ == "__main__":
    cartpole()

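Both cartpole_weights.h5 and cartpole.h5 appear in my target directory. However, I believe another file named "checkpoint" should also appear there, and my understanding is that its absence is why my code does not work.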

【Question discussion】:

    Tags: tensorflow tensorflow2.0 tf.keras


    【Answer 1】:

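    First, the code will not run if you have not saved the weights/model yet, so I commented out the lines below and ran the script once to generate the files.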

        checkpoint = tf.train.get_checkpoint_state(".")
        print('checkpoint:', checkpoint)
        if checkpoint and checkpoint.model_checkpoint_path:
            dqn_solver.model = tf.keras.models.load_model('cartpole.h5')
            dqn_solver.model.load_weights('cartpole_weights.h5')
    

    Note that I also fixed the code above; there were some errors in it before. In particular, this line from your post

        dqn_solver.model = model.load_weights('cartpole_weights.h5')
    

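    may be what causes the problem. Here model is not defined, and even on a real model, model.load_weights('file') modifies the model in place rather than returning it, so assigning its result to dqn_solver.model throws the model away.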
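    A minimal sketch of the in-place behaviour, assuming TensorFlow 2.x with tf.keras. build_model is a hypothetical helper that mirrors the DQNSolver network for CartPole (4 observations, 2 actions), and the .weights.h5 suffix is used because recent Keras versions require it for save_weights:

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


def build_model(observation_space=4, action_space=2):
    # Hypothetical helper: same layer layout as the DQNSolver network.
    return tf.keras.Sequential([
        tf.keras.Input(shape=(observation_space,)),
        tf.keras.layers.Dense(24, activation="relu"),
        tf.keras.layers.Dense(24, activation="relu"),
        tf.keras.layers.Dense(action_space, activation="linear"),
    ])


weights_path = os.path.join(tempfile.mkdtemp(), "cartpole.weights.h5")

trained = build_model()
trained.save_weights(weights_path)

fresh = build_model()
# load_weights copies the saved values into `fresh` itself.
result = fresh.load_weights(weights_path)

# `result` is not a Keras model, so writing `model = model.load_weights(...)`
# would replace the model with this return value.
print("return value type:", type(result).__name__)

loaded_ok = all(
    np.array_equal(a, b) for a, b in zip(trained.get_weights(), fresh.get_weights())
)
print("fresh now holds the saved weights:", loaded_ok)
```

    In the script this means calling dqn_solver.model.load_weights(...) on its own line and leaving dqn_solver.model as it is.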

    Then I checked whether the model weights were being saved and loaded correctly. To do that, you can run

        dqn_solver = DQNSolver(observation_space, action_space)
        dqn_solver.model.trainable_variables
    

    to see the (randomly initialized) weights of the freshly built model. Then you can load the weights with either

        dqn_solver.model = tf.keras.models.load_model('cartpole.h5')

    or

        dqn_solver.model.load_weights('cartpole_weights.h5')
    

    Then inspect trainable_variables again to make sure they differ from the initial weights and that both loading methods give the same values.
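    To make that check repeatable, here is a minimal sketch under the same assumptions (TensorFlow 2.x, hypothetical build_model and same_weights helpers) that compares every weight tensor through get_weights instead of reading trainable_variables by eye:

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


def build_model():
    # Hypothetical helper mirroring the DQNSolver 4 -> 24 -> 24 -> 2 network.
    return tf.keras.Sequential([
        tf.keras.Input(shape=(4,)),
        tf.keras.layers.Dense(24, activation="relu"),
        tf.keras.layers.Dense(24, activation="relu"),
        tf.keras.layers.Dense(2, activation="linear"),
    ])


def same_weights(model_a, model_b):
    # True only if both models hold the same number of tensors with equal values.
    weights_a, weights_b = model_a.get_weights(), model_b.get_weights()
    return len(weights_a) == len(weights_b) and all(
        np.allclose(a, b) for a, b in zip(weights_a, weights_b)
    )


save_dir = tempfile.mkdtemp()
saved = build_model()
saved.save(os.path.join(save_dir, "cartpole.h5"))
saved.save_weights(os.path.join(save_dir, "cartpole.weights.h5"))

fresh = build_model()  # new random initialization
initially_same = same_weights(fresh, saved)
print("fresh model matches saved before loading:", initially_same)

from_model_file = tf.keras.models.load_model(os.path.join(save_dir, "cartpole.h5"))
fresh.load_weights(os.path.join(save_dir, "cartpole.weights.h5"))
print("load_model matches saved:", same_weights(from_model_file, saved))
print("load_weights matches saved:", same_weights(fresh, saved))
```

    The first check should be false (the kernels start from different random values) and the last two true, which is the "different from the initial weights, and equivalent" condition described above.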

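    When you save the model, it stores the full architecture, i.e. the exact configuration of every layer, together with the weights. When you save only the weights, it stores just the weight tensors (for this model, the same tensors you see in trainable_variables). Weights must be loaded into exactly the architecture they were saved from, otherwise loading fails. So if you change the model structure in DQNSolver and then try to load weights saved from the old model, it will not work. If you use load_model instead, it rebuilds the model with the exact saved architecture and sets the weights.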
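    A minimal sketch of the architecture point, assuming TensorFlow 2.x: weights saved from a 24-unit network are loaded into a hypothetical 32-unit variant, which Keras rejects (with a ValueError in the versions I am aware of), while load_model rebuilds the original 24-unit layers from the file regardless of what build_model currently produces:

```python
import os
import tempfile

import tensorflow as tf


def build_model(hidden_units):
    # Hypothetical helper: DQNSolver-style network with a configurable hidden width.
    return tf.keras.Sequential([
        tf.keras.Input(shape=(4,)),
        tf.keras.layers.Dense(hidden_units, activation="relu"),
        tf.keras.layers.Dense(hidden_units, activation="relu"),
        tf.keras.layers.Dense(2, activation="linear"),
    ])


save_dir = tempfile.mkdtemp()
old_model = build_model(24)
old_model.save(os.path.join(save_dir, "cartpole.h5"))
old_model.save_weights(os.path.join(save_dir, "cartpole.weights.h5"))

# Weights only: the target model must have the same layer shapes.
new_model = build_model(32)
try:
    new_model.load_weights(os.path.join(save_dir, "cartpole.weights.h5"))
    weights_loaded = True
except ValueError as err:
    weights_loaded = False
    print("load_weights failed:", err)

# Full model: the architecture comes from the file, not from build_model.
restored = tf.keras.models.load_model(os.path.join(save_dir, "cartpole.h5"))
print("weights loaded into 32-unit model:", weights_loaded)
print("hidden units of restored model:", restored.layers[0].units)
```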

    Edit: the full modified script

    ## Slightly modified from the following repository - https://github.com/gsurma/cartpole
    
    from __future__ import absolute_import, division, print_function, unicode_literals
    
    import os
    import random
    import gym
    import numpy as np
    import tensorflow as tf
    
    from collections import deque
    from tensorflow.keras.models import Sequential
    from tensorflow.keras.layers import Dense
    from tensorflow.keras.optimizers import Adam
    from tensorflow.keras.callbacks import ModelCheckpoint
    
    
    ENV_NAME = "CartPole-v1"
    
    GAMMA = 0.95
    LEARNING_RATE = 0.001
    
    MEMORY_SIZE = 1000000
    BATCH_SIZE = 20
    
    EXPLORATION_MAX = 1.0
    EXPLORATION_MIN = 0.01
    EXPLORATION_DECAY = 0.995
    
    checkpoint_path = "training_1/cp.ckpt"
    
    
    class DQNSolver:
    
        def __init__(self, observation_space, action_space):
            # save_dir = args.save_dir
            # self.save_dir = save_dir
            # if not os.path.exists(save_dir):
            #     os.makedirs(save_dir)
            self.exploration_rate = EXPLORATION_MAX
    
            self.action_space = action_space
            self.memory = deque(maxlen=MEMORY_SIZE)
    
            self.model = Sequential()
            self.model.add(Dense(24, input_shape=(observation_space,), activation="relu"))
            self.model.add(Dense(24, activation="relu"))
            self.model.add(Dense(self.action_space, activation="linear"))
            self.model.compile(loss="mse", optimizer=Adam(learning_rate=LEARNING_RATE))
    
        def remember(self, state, action, reward, next_state, done):
            self.memory.append((state, action, reward, next_state, done))
    
        def act(self, state):
            if np.random.rand() < self.exploration_rate:
                return random.randrange(self.action_space)
            q_values = self.model.predict(state)
            return np.argmax(q_values[0])
    
        def experience_replay(self):
            if len(self.memory) < BATCH_SIZE:
                return
            batch = random.sample(self.memory, BATCH_SIZE)
            for state, action, reward, state_next, terminal in batch:
                q_update = reward
                if not terminal:
                    q_update = (reward + GAMMA * np.amax(self.model.predict(state_next)[0]))
                q_values = self.model.predict(state)
                q_values[0][action] = q_update
                self.model.fit(state, q_values, verbose=0)
            self.exploration_rate *= EXPLORATION_DECAY
            self.exploration_rate = max(EXPLORATION_MIN, self.exploration_rate)
    
    
    def cartpole():
        env = gym.make(ENV_NAME)
        #score_logger = ScoreLogger(ENV_NAME)
        observation_space = env.observation_space.shape[0]
        action_space = env.action_space.n
        dqn_solver = DQNSolver(observation_space, action_space)
    
        # checkpoint = tf.train.get_checkpoint_state(".")
        # print('checkpoint:', checkpoint)
        # if checkpoint and checkpoint.model_checkpoint_path:
        #     dqn_solver.model = tf.keras.models.load_model('cartpole.h5')
        #     dqn_solver.model.load_weights('cartpole_weights.h5')
        run = 0
        i = 0
        while i<2:
            i = i + 1
            #total = 0
            run += 1
            state = env.reset()
            state = np.reshape(state, [1, observation_space])
            step = 0
            while True:
                step += 1
                #env.render()
                action = dqn_solver.act(state)
                state_next, reward, terminal, info = env.step(action)
                #total += reward
                reward = reward if not terminal else -reward
                state_next = np.reshape(state_next, [1, observation_space])
                dqn_solver.remember(state, action, reward, state_next, terminal)
                state = state_next
                dqn_solver.model.save('cartpole.h5')
                dqn_solver.model.save_weights('cartpole_weights.h5')
                if terminal:
                    print("Run: " + str(run) + ", exploration: " + str(dqn_solver.exploration_rate) + ", score: " + str(step))
                    #score_logger.add_score(step, run)
                    break
                dqn_solver.experience_replay()
    
    
    if __name__ == "__main__":
        cartpole()
    
    #%%  to load saved results
    env = gym.make(ENV_NAME)
    #score_logger = ScoreLogger(ENV_NAME)
    observation_space = env.observation_space.shape[0]
    action_space = env.action_space.n
    dqn_solver = DQNSolver(observation_space, action_space)
    
    dqn_solver.model = tf.keras.models.load_model('cartpole.h5')  # or
    dqn_solver.model.load_weights('cartpole_weights.h5')
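    The script above sidesteps tf.train.get_checkpoint_state and loads the files directly. That function only reads the "checkpoint" index file written by TensorFlow's own checkpoint utilities (for example tf.train.CheckpointManager); saving to .h5 never creates that file, which is why the original call returned None even though the .h5 files existed. A minimal sketch of two ways to resume, assuming TensorFlow 2.x; build_model and load_or_build are hypothetical helpers, and the training_1 directory name is borrowed from the unused checkpoint_path in the script:

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


def build_model():
    # Hypothetical helper mirroring the DQNSolver network for CartPole.
    return tf.keras.Sequential([
        tf.keras.Input(shape=(4,)),
        tf.keras.layers.Dense(24, activation="relu"),
        tf.keras.layers.Dense(24, activation="relu"),
        tf.keras.layers.Dense(2, activation="linear"),
    ])


def load_or_build(model_path):
    # Option 1: test for the .h5 file itself, not for a checkpoint index file.
    if os.path.exists(model_path):
        return tf.keras.models.load_model(model_path), True
    return build_model(), False


save_dir = tempfile.mkdtemp()
model_path = os.path.join(save_dir, "cartpole.h5")

model, resumed = load_or_build(model_path)  # first run: nothing saved yet
model.save(model_path)
_, resumed_again = load_or_build(model_path)  # second run: file found
print("resumed:", resumed, "then", resumed_again)

# Saving to .h5 writes no "checkpoint" index file, so this is None.
h5_state = tf.train.get_checkpoint_state(save_dir)
print("checkpoint state next to the .h5 file:", h5_state)

# Option 2: TensorFlow-format checkpoints; CheckpointManager maintains the index file.
ckpt_dir = os.path.join(save_dir, "training_1")
os.makedirs(ckpt_dir, exist_ok=True)
manager = tf.train.CheckpointManager(tf.train.Checkpoint(model=model), ckpt_dir, max_to_keep=3)
manager.save()
tf_state = tf.train.get_checkpoint_state(ckpt_dir)
print("latest TF checkpoint:", tf_state.model_checkpoint_path)

# Restore into a freshly built model with the same architecture.
restored = build_model()
tf.train.Checkpoint(model=restored).restore(manager.latest_checkpoint)
weights_restored = all(
    np.allclose(a, b) for a, b in zip(model.get_weights(), restored.get_weights())
)
print("weights restored:", weights_restored)
```

    For the question's script, option 1 amounts to replacing the get_checkpoint_state check with if os.path.exists('cartpole.h5'): before calling load_model.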
    

    【Comments】:

    • Thanks for the detailed answer. For some reason I still can't load the weights. Could you post your whole code here?