【发布时间】:2020-08-24 21:15:12
【问题描述】:
我正在使用 OpenAI 健身房的 cartpole 环境训练强化学习模型。尽管我的权重和模型的 .h5 文件出现在目标目录中,但在运行以下代码后我得到 None - tf.train.get_checkpoint_state("C:/Users/dgt/Documents")。
这是我的全部代码 -
## Slightly modified from the following repository - https://github.com/gsurma/cartpole
from __future__ import absolute_import, division, print_function, unicode_literals
import os
import random
import gym
import numpy as np
import tensorflow as tf
from collections import deque
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint
ENV_NAME = "CartPole-v1"
GAMMA = 0.95
LEARNING_RATE = 0.001
MEMORY_SIZE = 1000000
BATCH_SIZE = 20
EXPLORATION_MAX = 1.0
EXPLORATION_MIN = 0.01
EXPLORATION_DECAY = 0.995
checkpoint_path = "training_1/cp.ckpt"
class DQNSolver:
def __init__(self, observation_space, action_space):
# save_dir = args.save_dir
# self.save_dir = save_dir
# if not os.path.exists(save_dir):
# os.makedirs(save_dir)
self.exploration_rate = EXPLORATION_MAX
self.action_space = action_space
self.memory = deque(maxlen=MEMORY_SIZE)
self.model = Sequential()
self.model.add(Dense(24, input_shape=(observation_space,), activation="relu"))
self.model.add(Dense(24, activation="relu"))
self.model.add(Dense(self.action_space, activation="linear"))
self.model.compile(loss="mse", optimizer=Adam(lr=LEARNING_RATE))
def remember(self, state, action, reward, next_state, done):
self.memory.append((state, action, reward, next_state, done))
def act(self, state):
if np.random.rand() < self.exploration_rate:
return random.randrange(self.action_space)
q_values = self.model.predict(state)
return np.argmax(q_values[0])
def experience_replay(self):
if len(self.memory) < BATCH_SIZE:
return
batch = random.sample(self.memory, BATCH_SIZE)
for state, action, reward, state_next, terminal in batch:
q_update = reward
if not terminal:
q_update = (reward + GAMMA * np.amax(self.model.predict(state_next)[0]))
q_values = self.model.predict(state)
q_values[0][action] = q_update
self.model.fit(state, q_values, verbose=0)
self.exploration_rate *= EXPLORATION_DECAY
self.exploration_rate = max(EXPLORATION_MIN, self.exploration_rate)
def cartpole():
env = gym.make(ENV_NAME)
#score_logger = ScoreLogger(ENV_NAME)
observation_space = env.observation_space.shape[0]
action_space = env.action_space.n
dqn_solver = DQNSolver(observation_space, action_space)
checkpoint = tf.train.get_checkpoint_state("C:/Users/dgt/Documents")
print('checkpoint:', checkpoint)
if checkpoint and checkpoint.model_checkpoint_path:
dqn_solver.model = keras.models.load_model('cartpole.h5')
dqn_solver.model = model.load_weights('cartpole_weights.h5')
run = 0
i = 0
while i<2:
i = i + 1
#total = 0
run += 1
state = env.reset()
state = np.reshape(state, [1, observation_space])
step = 0
while True:
step += 1
#env.render()
action = dqn_solver.act(state)
state_next, reward, terminal, info = env.step(action)
#total += reward
reward = reward if not terminal else -reward
state_next = np.reshape(state_next, [1, observation_space])
dqn_solver.remember(state, action, reward, state_next, terminal)
state = state_next
dqn_solver.model.save('cartpole.h5')
dqn_solver.model.save_weights('cartpole_weights.h5')
if terminal:
print("Run: " + str(run) + ", exploration: " + str(dqn_solver.exploration_rate) + ", score: " + str(step))
#score_logger.add_score(step, run)
break
dqn_solver.experience_replay()
if __name__ == "__main__":
cartpole()
cartpole_weights.h5 和 cartpole.h5 文件都出现在我的目标目录中。但是,我相信另一个名为“检查点”的文件也应该出现。我的理解是,这就是我的代码无法运行的原因。
【问题讨论】:
标签: tensorflow tensorflow2.0 tf.keras