From b64873ffff1923b02a00c7b683099959a288ff6c Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Thu, 12 Mar 2020 12:34:25 +0100 Subject: [PATCH] Sync callbacks --- docs/misc/changelog.rst | 1 + tests/test_run.py | 2 +- torchy_baselines/a2c/a2c.py | 5 +- torchy_baselines/cem_rl/cem_rl.py | 5 +- torchy_baselines/common/base_class.py | 12 ++-- torchy_baselines/common/callbacks.py | 62 +++++++++++++++------ torchy_baselines/common/type_aliases.py | 5 +- torchy_baselines/common/vec_env/__init__.py | 15 +++-- torchy_baselines/ppo/ppo.py | 12 ++-- torchy_baselines/sac/sac.py | 5 +- torchy_baselines/td3/td3.py | 5 +- 11 files changed, 81 insertions(+), 48 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 3bd119b..38a63d5 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -17,6 +17,7 @@ New Features: Bug Fixes: ^^^^^^^^^^ +- Synced callbacks with Stable-Baselines Deprecations: ^^^^^^^^^^^^^ diff --git a/tests/test_run.py b/tests/test_run.py index 1a3f991..db33cc7 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -23,7 +23,7 @@ def test_cemrl(): @pytest.mark.parametrize("model_class", [A2C, PPO]) @pytest.mark.parametrize("env_id", ['CartPole-v1', 'Pendulum-v0']) def test_onpolicy(model_class, env_id): - model = model_class('MlpPolicy', env_id, policy_kwargs=dict(net_arch=[16]), verbose=1, create_eval_env=True) + model = model_class('MlpPolicy', env_id, seed=0, policy_kwargs=dict(net_arch=[16]), verbose=1, create_eval_env=True) model.learn(total_timesteps=1000, eval_freq=500) diff --git a/torchy_baselines/a2c/a2c.py b/torchy_baselines/a2c/a2c.py index 5f9730d..252f8bf 100644 --- a/torchy_baselines/a2c/a2c.py +++ b/torchy_baselines/a2c/a2c.py @@ -4,8 +4,7 @@ from gym import spaces from typing import Type, Union, Callable, Optional, Dict, Any from torchy_baselines.common import logger -from torchy_baselines.common.callbacks import BaseCallback -from torchy_baselines.common.type_aliases import GymEnv +from torchy_baselines.common.type_aliases import GymEnv, MaybeCallback from torchy_baselines.common.utils import explained_variance from torchy_baselines.ppo.policies import PPOPolicy from torchy_baselines.ppo.ppo import PPO @@ -154,7 +153,7 @@ class A2C(PPO): def learn(self, total_timesteps: int, - callback: Optional[BaseCallback] = None, + callback: MaybeCallback = None, log_interval: int = 100, eval_env: Optional[GymEnv] = None, eval_freq: int = -1, diff --git a/torchy_baselines/cem_rl/cem_rl.py b/torchy_baselines/cem_rl/cem_rl.py index 867f2c4..0069aef 100644 --- a/torchy_baselines/cem_rl/cem_rl.py +++ b/torchy_baselines/cem_rl/cem_rl.py @@ -3,8 +3,7 @@ from typing import Type, Union, Callable, Optional, Dict, Any import torch as th from torchy_baselines.common.base_class import OffPolicyRLModel -from torchy_baselines.common.callbacks import BaseCallback -from torchy_baselines.common.type_aliases import GymEnv +from torchy_baselines.common.type_aliases import GymEnv, MaybeCallback from torchy_baselines.common.noise import ActionNoise from torchy_baselines.td3.td3 import TD3, TD3Policy from torchy_baselines.cem_rl.cem import CEM @@ -121,7 +120,7 @@ class CEMRL(TD3): def learn(self, total_timesteps: int, - callback: Optional[BaseCallback] = None, + callback: MaybeCallback = None, log_interval: int = 4, eval_env: Optional[GymEnv] = None, eval_freq: int = -1, diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py index f1d205a..19ea791 100644 --- a/torchy_baselines/common/base_class.py +++ b/torchy_baselines/common/base_class.py @@ -16,7 +16,7 @@ from torchy_baselines.common.policies import BasePolicy, get_policy_from_name from torchy_baselines.common.utils import set_random_seed, get_schedule_fn, update_learning_rate from torchy_baselines.common.vec_env import DummyVecEnv, VecEnv, unwrap_vec_normalize, VecNormalize from torchy_baselines.common.save_util import data_to_json, json_to_data, recursive_getattr, recursive_setattr -from torchy_baselines.common.type_aliases import GymEnv, TensorDict, RolloutReturn +from torchy_baselines.common.type_aliases import GymEnv, TensorDict, RolloutReturn, MaybeCallback from torchy_baselines.common.callbacks import BaseCallback, CallbackList, ConvertCallback, EvalCallback from torchy_baselines.common.monitor import Monitor from torchy_baselines.common.noise import ActionNoise @@ -281,7 +281,7 @@ class BaseRLModel(ABC): @abstractmethod def learn(self, total_timesteps: int, - callback: Union[None, Callable, List[BaseCallback], BaseCallback] = None, + callback: MaybeCallback = None, log_interval: int = 100, tb_log_name: str = "run", eval_env: Optional[GymEnv] = None, @@ -877,10 +877,6 @@ class OffPolicyRLModel(BaseRLModel): while not done: - # Only stop training if return value is False, not when it is None. - if callback() is False: - return RolloutReturn(0.0, total_steps, total_episodes, None, continue_training=False) - if self.use_sde and self.sde_sample_freq > 0 and n_steps % self.sde_sample_freq == 0: # Sample a new noise matrix self.actor.reset_noise() @@ -913,6 +909,10 @@ class OffPolicyRLModel(BaseRLModel): # Rescale and perform action new_obs, reward, done, infos = env.step(self.unscale_action(clipped_action)) + # Only stop training if return value is False, not when it is None. + if callback.on_step() is False: + return RolloutReturn(0.0, total_steps, total_episodes, None, continue_training=False) + episode_reward += reward # Retrieve reward and episode length if using Monitor wrapper diff --git a/torchy_baselines/common/callbacks.py b/torchy_baselines/common/callbacks.py index 392716c..1a8403a 100644 --- a/torchy_baselines/common/callbacks.py +++ b/torchy_baselines/common/callbacks.py @@ -1,12 +1,13 @@ import os from abc import ABC, abstractmethod +import warnings import typing from typing import Union, List, Dict, Any, Optional import gym import numpy as np -from torchy_baselines.common.vec_env import VecEnv, sync_envs_normalization +from torchy_baselines.common.vec_env import DummyVecEnv, VecEnv, sync_envs_normalization from torchy_baselines.common.evaluation import evaluate_policy from torchy_baselines.common.logger import Logger @@ -22,9 +23,13 @@ class BaseCallback(ABC): """ def __init__(self, verbose: int = 0): super(BaseCallback, self).__init__() + # The RL model self.model = None # type: Optional[BaseRLModel] + # An alias for self.model.get_env(), the environment used for training self.training_env = None # type: Union[gym.Env, VecEnv, None] + # Number of time the callback was called self.n_calls = 0 # type: int + # n_envs * n times env.step() was called self.num_timesteps = 0 # type: int self.verbose = verbose self.locals = None # type: Optional[Dict[str, Any]] @@ -70,9 +75,13 @@ class BaseCallback(ABC): """ return True - def __call__(self) -> bool: + def on_step(self) -> bool: """ - This method will be called by the model. This is the equivalent to the callback function. + This method will be called by the model after each call to ``env.step()``. + + For child callback (of an ``EventCallback``), this will be called + when the event is triggered. + :return: (bool) If the callback returns False, training is aborted early. """ self.n_calls += 1 @@ -128,6 +137,12 @@ class EventCallback(BaseCallback): class CallbackList(BaseCallback): + """ + Class for chaining callbacks. + + :param callbacks: (List[BaseCallback]) A list of callbacks that will be called + sequentially. + """ def __init__(self, callbacks: List[BaseCallback]): super(CallbackList, self).__init__() assert isinstance(callbacks, list) @@ -141,16 +156,21 @@ class CallbackList(BaseCallback): for callback in self.callbacks: callback.on_training_start(self.locals, self.globals) + def _on_rollout_start(self) -> None: + for callback in self.callbacks: + callback.on_rollout_start() + def _on_step(self) -> bool: continue_training = True for callback in self.callbacks: - # # Update variables - # callback.num_timesteps = self.num_timesteps - # callback.n_calls = self.n_calls # Return False (stop training) if at least one callback returns False - continue_training = callback() and continue_training + continue_training = callback.on_step() and continue_training return continue_training + def _on_rollout_end(self) -> None: + for callback in self.callbacks: + callback.on_rollout_end() + def _on_training_end(self) -> None: for callback in self.callbacks: callback.on_training_end() @@ -158,7 +178,7 @@ class CallbackList(BaseCallback): class CheckpointCallback(BaseCallback): """ - Callback for saving a model every `save_freq` steps + Callback for saving a model every ``save_freq`` steps :param save_freq: (int) :param save_path: (str) Path to the folder where the model will be saved. @@ -207,16 +227,17 @@ class EvalCallback(EventCallback): :param eval_env: (Union[gym.Env, VecEnv]) The environment used for initialization :param callback_on_new_best: (Optional[BaseCallback]) Callback to trigger - when there is a new best model according to the `mean_reward` + when there is a new best model according to the ``mean_reward`` :param n_eval_episodes: (int) The number of episodes to test the agent :param eval_freq: (int) Evaluate the agent every eval_freq call of the callback. - :param log_path: (str) Path to a folder where the evaluations (`evaluations.npz`) + :param log_path: (str) Path to a folder where the evaluations (``evaluations.npz``) will be saved. It will be updated at each evaluation. :param best_model_save_path: (str) Path to a folder where the best model according to performance on the eval env will be saved. :param deterministic: (bool) Whether the evaluation should use a stochastic or deterministic actions. :param deterministic: (bool) Whether to render or not the environment during evaluation + :param render: (bool) Whether to render or not the environment during evaluation :param verbose: (int) """ def __init__(self, eval_env: Union[gym.Env, VecEnv], @@ -236,12 +257,16 @@ class EvalCallback(EventCallback): self.deterministic = deterministic self.render = render + # Convert to VecEnv for consistency + if not isinstance(eval_env, VecEnv): + eval_env = DummyVecEnv([lambda: eval_env]) + if isinstance(eval_env, VecEnv): assert eval_env.num_envs == 1, "You must pass only one environment for evaluation" self.eval_env = eval_env self.best_model_save_path = best_model_save_path - # Logs will be written in `evaluations.npz` + # Logs will be written in ``evaluations.npz`` if log_path is not None: log_path = os.path.join(log_path, 'evaluations') self.log_path = log_path @@ -250,9 +275,10 @@ class EvalCallback(EventCallback): self.evaluations_length = [] def _init_callback(self): - # Does not work when eval_env is a gym.Env and training_env is a VecEnv - # assert type(self.training_env) is type(self.eval_env), ("training and eval env are not of the same type", - # "{} != {}".format(self.training_env, self.eval_env)) + # Does not work in some corner cases, where the wrapper is not the same + if not type(self.training_env) is type(self.eval_env): + warnings.warn("Training and eval env are not of the same type" + f"{self.training_env} != {self.eval_env}") # Create folders if needed if self.best_model_save_path is not None: @@ -306,7 +332,7 @@ class StopTrainingOnRewardThreshold(BaseCallback): Stop the training once a threshold in episodic reward has been reached (i.e. when the model is good enough). - It must be used with the `EvalCallback`. + It must be used with the ``EvalCallback``. :param reward_threshold: (float) Minimum expected reward per episode to stop training. @@ -317,8 +343,8 @@ class StopTrainingOnRewardThreshold(BaseCallback): self.reward_threshold = reward_threshold def _on_step(self) -> bool: - assert self.parent is not None, ("`StopTrainingOnMinimumReward` callback must be used " - "with an `EvalCallback`") + assert self.parent is not None, ("``StopTrainingOnMinimumReward`` callback must be used " + "with an ``EvalCallback``") # Convert np.bool to bool, otherwise callback() is False won't work continue_training = bool(self.parent.best_mean_reward < self.reward_threshold) if self.verbose > 0 and not continue_training: @@ -329,7 +355,7 @@ class StopTrainingOnRewardThreshold(BaseCallback): class EveryNTimesteps(EventCallback): """ - Trigger a callback every `n_steps` timesteps + Trigger a callback every ``n_steps`` timesteps :param n_steps: (int) Number of timesteps between two trigger. :param callback: (BaseCallback) Callback that will be called diff --git a/torchy_baselines/common/type_aliases.py b/torchy_baselines/common/type_aliases.py index 60042be..53c152c 100644 --- a/torchy_baselines/common/type_aliases.py +++ b/torchy_baselines/common/type_aliases.py @@ -1,18 +1,21 @@ """ Common aliases for type hint """ -from typing import Union, Dict, Any, NamedTuple, Optional +import typing +from typing import Union, Dict, Any, NamedTuple, Optional, List, Callable import numpy as np import torch as th import gym from torchy_baselines.common.vec_env import VecEnv +from torchy_baselines.common.callbacks import BaseCallback GymEnv = Union[gym.Env, VecEnv] TensorDict = Dict[str, th.Tensor] OptimizerStateDict = Dict[str, Any] +MaybeCallback = Union[None, Callable, List[BaseCallback], BaseCallback] class RolloutBufferSamples(NamedTuple): diff --git a/torchy_baselines/common/vec_env/__init__.py b/torchy_baselines/common/vec_env/__init__.py index 38099af..2cbb349 100644 --- a/torchy_baselines/common/vec_env/__init__.py +++ b/torchy_baselines/common/vec_env/__init__.py @@ -1,4 +1,6 @@ # flake8: noqa F401 +import typing +from typing import Optional from copy import deepcopy from torchy_baselines.common.vec_env.base_vec_env import AlreadySteppingError, NotSteppingError,\ @@ -8,8 +10,12 @@ from torchy_baselines.common.vec_env.subproc_vec_env import SubprocVecEnv from torchy_baselines.common.vec_env.vec_frame_stack import VecFrameStack from torchy_baselines.common.vec_env.vec_normalize import VecNormalize +# Avoid circular import +if typing.TYPE_CHECKING: + from torchy_baselines.common.type_aliases import GymEnv -def unwrap_vec_normalize(env): + +def unwrap_vec_normalize(env: 'GymEnv') -> Optional[VecNormalize]: """ :param env: (gym.Env) :return: (VecNormalize) @@ -23,16 +29,17 @@ def unwrap_vec_normalize(env): # Define here to avoid circular import -def sync_envs_normalization(env, eval_env): +def sync_envs_normalization(env: 'GymEnv', eval_env: 'GymEnv') -> None: """ Sync eval env and train env when using VecNormalize - :param env: (gym.Env) - :param eval_env: (gym.Env) + :param env: (GymEnv) + :param eval_env: (GymEnv) """ env_tmp, eval_env_tmp = env, eval_env while isinstance(env_tmp, VecEnvWrapper): if isinstance(env_tmp, VecNormalize): eval_env_tmp.obs_rms = deepcopy(env_tmp.obs_rms) + eval_env_tmp.ret_rms = deepcopy(env_tmp.ret_rms) env_tmp = env_tmp.venv eval_env_tmp = eval_env_tmp.venv diff --git a/torchy_baselines/ppo/ppo.py b/torchy_baselines/ppo/ppo.py index 5ac1a96..7b7cd74 100644 --- a/torchy_baselines/ppo/ppo.py +++ b/torchy_baselines/ppo/ppo.py @@ -16,7 +16,7 @@ import numpy as np from torchy_baselines.common import logger from torchy_baselines.common.base_class import BaseRLModel -from torchy_baselines.common.type_aliases import GymEnv +from torchy_baselines.common.type_aliases import GymEnv, MaybeCallback from torchy_baselines.common.buffers import RolloutBuffer from torchy_baselines.common.utils import explained_variance, get_schedule_fn from torchy_baselines.common.vec_env import VecEnv @@ -163,10 +163,6 @@ class PPO(BaseRLModel): while n_steps < n_rollout_steps: - if callback() is False: - continue_training = False - return None, continue_training - if self.use_sde and self.sde_sample_freq > 0 and n_steps % self.sde_sample_freq == 0: # Sample a new noise matrix self.policy.reset_noise(env.num_envs) @@ -182,6 +178,10 @@ class PPO(BaseRLModel): clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high) new_obs, rewards, dones, infos = env.step(clipped_actions) + if callback.on_step() is False: + continue_training = False + return None, continue_training + self._update_info_buffer(infos) n_steps += 1 self.num_timesteps += env.num_envs @@ -286,7 +286,7 @@ class PPO(BaseRLModel): def learn(self, total_timesteps: int, - callback: Optional[BaseCallback] = None, + callback: MaybeCallback = None, log_interval: int = 1, eval_env: Optional[GymEnv] = None, eval_freq: int = -1, diff --git a/torchy_baselines/sac/sac.py b/torchy_baselines/sac/sac.py index d0309e8..9d517b3 100644 --- a/torchy_baselines/sac/sac.py +++ b/torchy_baselines/sac/sac.py @@ -7,9 +7,8 @@ import numpy as np from torchy_baselines.common import logger from torchy_baselines.common.base_class import OffPolicyRLModel from torchy_baselines.common.buffers import ReplayBuffer -from torchy_baselines.common.type_aliases import GymEnv +from torchy_baselines.common.type_aliases import GymEnv, MaybeCallback from torchy_baselines.common.noise import ActionNoise -from torchy_baselines.common.callbacks import BaseCallback from torchy_baselines.sac.policies import SACPolicy @@ -253,7 +252,7 @@ class SAC(OffPolicyRLModel): def learn(self, total_timesteps: int, - callback: Optional[BaseCallback] = None, + callback: MaybeCallback = None, log_interval: int = 4, eval_env: Optional[GymEnv] = None, eval_freq: int = -1, diff --git a/torchy_baselines/td3/td3.py b/torchy_baselines/td3/td3.py index 9de88c1..c2cf278 100644 --- a/torchy_baselines/td3/td3.py +++ b/torchy_baselines/td3/td3.py @@ -4,9 +4,8 @@ from typing import List, Tuple, Type, Union, Callable, Optional, Dict, Any from torchy_baselines.common.base_class import OffPolicyRLModel from torchy_baselines.common.buffers import ReplayBuffer -from torchy_baselines.common.callbacks import BaseCallback from torchy_baselines.common.noise import ActionNoise -from torchy_baselines.common.type_aliases import ReplayBufferSamples, GymEnv +from torchy_baselines.common.type_aliases import ReplayBufferSamples, GymEnv, MaybeCallback from torchy_baselines.td3.policies import TD3Policy @@ -264,7 +263,7 @@ class TD3(OffPolicyRLModel): def learn(self, total_timesteps: int, - callback: Optional[BaseCallback] = None, + callback: MaybeCallback = None, log_interval: int = 4, eval_env: Optional[GymEnv] = None, eval_freq: int = -1,