Overview of the Basic Python API
Computing Actions
Accessing Policy State
Accessing Model State
Example: Preprocessing Observations for Feeding into a Model
Example: Querying a Policy's Action Distribution
Example: Getting Q-Values from a DQN Model
References
Overview of the Basic Python API
The Python API gives you the flexibility to build RL setups for a much wider range of scenarios. The most commonly used RLlib extension points are custom environments, preprocessors, and models.
Here is a basic usage example (for a more complete example, see custom_env.py):
import os
os.environ["CUDA_VISIBLE_DEVICES"] = '3'

import ray
import ray.rllib.agents.ppo as ppo
from ray.tune.logger import pretty_print

ray.init()
config = ppo.DEFAULT_CONFIG.copy()
config["num_gpus"] = 1
config["num_workers"] = 2
config["eager"] = False
trainer = ppo.PPOTrainer(config=config, env="CartPole-v0")

# Can optionally call trainer.restore(path) to load a checkpoint.

for i in range(1000):
    # Perform one iteration of training the policy with PPO
    result = trainer.train()
    print(pretty_print(result))

    if i % 100 == 0:
        checkpoint = trainer.save()
        print("checkpoint saved at", checkpoint)

# Also, in case you have trained a model outside of ray/RLlib and have created
# an h5-file with weight values in it, e.g.
# my_keras_model_trained_outside_rllib.save_weights("model.h5")
# (see: https://keras.io/models/about-keras-models/)
# ... you can load the h5-weights into your Trainer's Policy's ModelV2
# (tf or torch) by doing:
trainer.import_model("my_weights.h5")
# NOTE: In order for this to work, your (custom) model needs to implement
# the `import_from_h5` method.
# See https://github.com/ray-project/ray/blob/master/rllib/tests/test_model_imports.py
# for detailed examples for tf- and torch trainers/models.