Sync callbacks

This commit is contained in:
Antonin Raffin 2020-03-12 12:34:25 +01:00
parent 18f38f8cf5
commit b64873ffff
11 changed files with 81 additions and 48 deletions

View file

@ -17,6 +17,7 @@ New Features:
Bug Fixes:
^^^^^^^^^^
- Synced callbacks with Stable-Baselines
Deprecations:
^^^^^^^^^^^^^

View file

@ -23,7 +23,7 @@ def test_cemrl():
@pytest.mark.parametrize("model_class", [A2C, PPO])
@pytest.mark.parametrize("env_id", ['CartPole-v1', 'Pendulum-v0'])
def test_onpolicy(model_class, env_id):
model = model_class('MlpPolicy', env_id, policy_kwargs=dict(net_arch=[16]), verbose=1, create_eval_env=True)
model = model_class('MlpPolicy', env_id, seed=0, policy_kwargs=dict(net_arch=[16]), verbose=1, create_eval_env=True)
model.learn(total_timesteps=1000, eval_freq=500)

View file

@ -4,8 +4,7 @@ from gym import spaces
from typing import Type, Union, Callable, Optional, Dict, Any
from torchy_baselines.common import logger
from torchy_baselines.common.callbacks import BaseCallback
from torchy_baselines.common.type_aliases import GymEnv
from torchy_baselines.common.type_aliases import GymEnv, MaybeCallback
from torchy_baselines.common.utils import explained_variance
from torchy_baselines.ppo.policies import PPOPolicy
from torchy_baselines.ppo.ppo import PPO
@ -154,7 +153,7 @@ class A2C(PPO):
def learn(self,
total_timesteps: int,
callback: Optional[BaseCallback] = None,
callback: MaybeCallback = None,
log_interval: int = 100,
eval_env: Optional[GymEnv] = None,
eval_freq: int = -1,

View file

@ -3,8 +3,7 @@ from typing import Type, Union, Callable, Optional, Dict, Any
import torch as th
from torchy_baselines.common.base_class import OffPolicyRLModel
from torchy_baselines.common.callbacks import BaseCallback
from torchy_baselines.common.type_aliases import GymEnv
from torchy_baselines.common.type_aliases import GymEnv, MaybeCallback
from torchy_baselines.common.noise import ActionNoise
from torchy_baselines.td3.td3 import TD3, TD3Policy
from torchy_baselines.cem_rl.cem import CEM
@ -121,7 +120,7 @@ class CEMRL(TD3):
def learn(self,
total_timesteps: int,
callback: Optional[BaseCallback] = None,
callback: MaybeCallback = None,
log_interval: int = 4,
eval_env: Optional[GymEnv] = None,
eval_freq: int = -1,

View file

@ -16,7 +16,7 @@ from torchy_baselines.common.policies import BasePolicy, get_policy_from_name
from torchy_baselines.common.utils import set_random_seed, get_schedule_fn, update_learning_rate
from torchy_baselines.common.vec_env import DummyVecEnv, VecEnv, unwrap_vec_normalize, VecNormalize
from torchy_baselines.common.save_util import data_to_json, json_to_data, recursive_getattr, recursive_setattr
from torchy_baselines.common.type_aliases import GymEnv, TensorDict, RolloutReturn
from torchy_baselines.common.type_aliases import GymEnv, TensorDict, RolloutReturn, MaybeCallback
from torchy_baselines.common.callbacks import BaseCallback, CallbackList, ConvertCallback, EvalCallback
from torchy_baselines.common.monitor import Monitor
from torchy_baselines.common.noise import ActionNoise
@ -281,7 +281,7 @@ class BaseRLModel(ABC):
@abstractmethod
def learn(self, total_timesteps: int,
callback: Union[None, Callable, List[BaseCallback], BaseCallback] = None,
callback: MaybeCallback = None,
log_interval: int = 100,
tb_log_name: str = "run",
eval_env: Optional[GymEnv] = None,
@ -877,10 +877,6 @@ class OffPolicyRLModel(BaseRLModel):
while not done:
# Only stop training if return value is False, not when it is None.
if callback() is False:
return RolloutReturn(0.0, total_steps, total_episodes, None, continue_training=False)
if self.use_sde and self.sde_sample_freq > 0 and n_steps % self.sde_sample_freq == 0:
# Sample a new noise matrix
self.actor.reset_noise()
@ -913,6 +909,10 @@ class OffPolicyRLModel(BaseRLModel):
# Rescale and perform action
new_obs, reward, done, infos = env.step(self.unscale_action(clipped_action))
# Only stop training if return value is False, not when it is None.
if callback.on_step() is False:
return RolloutReturn(0.0, total_steps, total_episodes, None, continue_training=False)
episode_reward += reward
# Retrieve reward and episode length if using Monitor wrapper

View file

@ -1,12 +1,13 @@
import os
from abc import ABC, abstractmethod
import warnings
import typing
from typing import Union, List, Dict, Any, Optional
import gym
import numpy as np
from torchy_baselines.common.vec_env import VecEnv, sync_envs_normalization
from torchy_baselines.common.vec_env import DummyVecEnv, VecEnv, sync_envs_normalization
from torchy_baselines.common.evaluation import evaluate_policy
from torchy_baselines.common.logger import Logger
@ -22,9 +23,13 @@ class BaseCallback(ABC):
"""
def __init__(self, verbose: int = 0):
super(BaseCallback, self).__init__()
# The RL model
self.model = None # type: Optional[BaseRLModel]
# An alias for self.model.get_env(), the environment used for training
self.training_env = None # type: Union[gym.Env, VecEnv, None]
# Number of times the callback was called
self.n_calls = 0 # type: int
# n_envs * n times env.step() was called
self.num_timesteps = 0 # type: int
self.verbose = verbose
self.locals = None # type: Optional[Dict[str, Any]]
@ -70,9 +75,13 @@ class BaseCallback(ABC):
"""
return True
def __call__(self) -> bool:
def on_step(self) -> bool:
"""
This method will be called by the model. This is the equivalent to the callback function.
This method will be called by the model after each call to ``env.step()``.
For a child callback (of an ``EventCallback``), this will be called
when the event is triggered.
:return: (bool) If the callback returns False, training is aborted early.
"""
self.n_calls += 1
@ -128,6 +137,12 @@ class EventCallback(BaseCallback):
class CallbackList(BaseCallback):
"""
Class for chaining callbacks.
:param callbacks: (List[BaseCallback]) A list of callbacks that will be called
sequentially.
"""
def __init__(self, callbacks: List[BaseCallback]):
super(CallbackList, self).__init__()
assert isinstance(callbacks, list)
@ -141,16 +156,21 @@ class CallbackList(BaseCallback):
for callback in self.callbacks:
callback.on_training_start(self.locals, self.globals)
def _on_rollout_start(self) -> None:
for callback in self.callbacks:
callback.on_rollout_start()
def _on_step(self) -> bool:
continue_training = True
for callback in self.callbacks:
# # Update variables
# callback.num_timesteps = self.num_timesteps
# callback.n_calls = self.n_calls
# Return False (stop training) if at least one callback returns False
continue_training = callback() and continue_training
continue_training = callback.on_step() and continue_training
return continue_training
def _on_rollout_end(self) -> None:
for callback in self.callbacks:
callback.on_rollout_end()
def _on_training_end(self) -> None:
for callback in self.callbacks:
callback.on_training_end()
@ -158,7 +178,7 @@ class CallbackList(BaseCallback):
class CheckpointCallback(BaseCallback):
"""
Callback for saving a model every `save_freq` steps
Callback for saving a model every ``save_freq`` steps
:param save_freq: (int)
:param save_path: (str) Path to the folder where the model will be saved.
@ -207,16 +227,17 @@ class EvalCallback(EventCallback):
:param eval_env: (Union[gym.Env, VecEnv]) The environment used for initialization
:param callback_on_new_best: (Optional[BaseCallback]) Callback to trigger
when there is a new best model according to the `mean_reward`
when there is a new best model according to the ``mean_reward``
:param n_eval_episodes: (int) The number of episodes to test the agent
:param eval_freq: (int) Evaluate the agent every ``eval_freq`` calls of the callback.
:param log_path: (str) Path to a folder where the evaluations (`evaluations.npz`)
:param log_path: (str) Path to a folder where the evaluations (``evaluations.npz``)
will be saved. It will be updated at each evaluation.
:param best_model_save_path: (str) Path to a folder where the best model
according to performance on the eval env will be saved.
:param deterministic: (bool) Whether the evaluation should
use stochastic or deterministic actions.
:param deterministic: (bool) Whether to render or not the environment during evaluation
:param render: (bool) Whether to render or not the environment during evaluation
:param verbose: (int)
"""
def __init__(self, eval_env: Union[gym.Env, VecEnv],
@ -236,12 +257,16 @@ class EvalCallback(EventCallback):
self.deterministic = deterministic
self.render = render
# Convert to VecEnv for consistency
if not isinstance(eval_env, VecEnv):
eval_env = DummyVecEnv([lambda: eval_env])
if isinstance(eval_env, VecEnv):
assert eval_env.num_envs == 1, "You must pass only one environment for evaluation"
self.eval_env = eval_env
self.best_model_save_path = best_model_save_path
# Logs will be written in `evaluations.npz`
# Logs will be written in ``evaluations.npz``
if log_path is not None:
log_path = os.path.join(log_path, 'evaluations')
self.log_path = log_path
@ -250,9 +275,10 @@ class EvalCallback(EventCallback):
self.evaluations_length = []
def _init_callback(self):
# Does not work when eval_env is a gym.Env and training_env is a VecEnv
# assert type(self.training_env) is type(self.eval_env), ("training and eval env are not of the same type",
# "{} != {}".format(self.training_env, self.eval_env))
# Does not work in some corner cases, where the wrapper is not the same
if not type(self.training_env) is type(self.eval_env):
warnings.warn("Training and eval env are not of the same type"
f"{self.training_env} != {self.eval_env}")
# Create folders if needed
if self.best_model_save_path is not None:
@ -306,7 +332,7 @@ class StopTrainingOnRewardThreshold(BaseCallback):
Stop the training once a threshold in episodic reward
has been reached (i.e. when the model is good enough).
It must be used with the `EvalCallback`.
It must be used with the ``EvalCallback``.
:param reward_threshold: (float) Minimum expected reward per episode
to stop training.
@ -317,8 +343,8 @@ class StopTrainingOnRewardThreshold(BaseCallback):
self.reward_threshold = reward_threshold
def _on_step(self) -> bool:
assert self.parent is not None, ("`StopTrainingOnMinimumReward` callback must be used "
"with an `EvalCallback`")
assert self.parent is not None, ("``StopTrainingOnMinimumReward`` callback must be used "
"with an ``EvalCallback``")
# Convert np.bool to bool, otherwise ``callback.on_step() is False`` won't work
continue_training = bool(self.parent.best_mean_reward < self.reward_threshold)
if self.verbose > 0 and not continue_training:
@ -329,7 +355,7 @@ class StopTrainingOnRewardThreshold(BaseCallback):
class EveryNTimesteps(EventCallback):
"""
Trigger a callback every `n_steps` timesteps
Trigger a callback every ``n_steps`` timesteps
:param n_steps: (int) Number of timesteps between two triggers.
:param callback: (BaseCallback) Callback that will be called

View file

@ -1,18 +1,21 @@
"""
Common aliases for type hint
"""
from typing import Union, Dict, Any, NamedTuple, Optional
import typing
from typing import Union, Dict, Any, NamedTuple, Optional, List, Callable
import numpy as np
import torch as th
import gym
from torchy_baselines.common.vec_env import VecEnv
from torchy_baselines.common.callbacks import BaseCallback
GymEnv = Union[gym.Env, VecEnv]
TensorDict = Dict[str, th.Tensor]
OptimizerStateDict = Dict[str, Any]
MaybeCallback = Union[None, Callable, List[BaseCallback], BaseCallback]
class RolloutBufferSamples(NamedTuple):

View file

@ -1,4 +1,6 @@
# flake8: noqa F401
import typing
from typing import Optional
from copy import deepcopy
from torchy_baselines.common.vec_env.base_vec_env import AlreadySteppingError, NotSteppingError,\
@ -8,8 +10,12 @@ from torchy_baselines.common.vec_env.subproc_vec_env import SubprocVecEnv
from torchy_baselines.common.vec_env.vec_frame_stack import VecFrameStack
from torchy_baselines.common.vec_env.vec_normalize import VecNormalize
# Avoid circular import
if typing.TYPE_CHECKING:
from torchy_baselines.common.type_aliases import GymEnv
def unwrap_vec_normalize(env):
def unwrap_vec_normalize(env: 'GymEnv') -> Optional[VecNormalize]:
"""
:param env: (gym.Env)
:return: (VecNormalize)
@ -23,16 +29,17 @@ def unwrap_vec_normalize(env):
# Define here to avoid circular import
def sync_envs_normalization(env, eval_env):
def sync_envs_normalization(env: 'GymEnv', eval_env: 'GymEnv') -> None:
"""
Sync eval env and train env when using VecNormalize
:param env: (gym.Env)
:param eval_env: (gym.Env)
:param env: (GymEnv)
:param eval_env: (GymEnv)
"""
env_tmp, eval_env_tmp = env, eval_env
while isinstance(env_tmp, VecEnvWrapper):
if isinstance(env_tmp, VecNormalize):
eval_env_tmp.obs_rms = deepcopy(env_tmp.obs_rms)
eval_env_tmp.ret_rms = deepcopy(env_tmp.ret_rms)
env_tmp = env_tmp.venv
eval_env_tmp = eval_env_tmp.venv

View file

@ -16,7 +16,7 @@ import numpy as np
from torchy_baselines.common import logger
from torchy_baselines.common.base_class import BaseRLModel
from torchy_baselines.common.type_aliases import GymEnv
from torchy_baselines.common.type_aliases import GymEnv, MaybeCallback
from torchy_baselines.common.buffers import RolloutBuffer
from torchy_baselines.common.utils import explained_variance, get_schedule_fn
from torchy_baselines.common.vec_env import VecEnv
@ -163,10 +163,6 @@ class PPO(BaseRLModel):
while n_steps < n_rollout_steps:
if callback() is False:
continue_training = False
return None, continue_training
if self.use_sde and self.sde_sample_freq > 0 and n_steps % self.sde_sample_freq == 0:
# Sample a new noise matrix
self.policy.reset_noise(env.num_envs)
@ -182,6 +178,10 @@ class PPO(BaseRLModel):
clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high)
new_obs, rewards, dones, infos = env.step(clipped_actions)
if callback.on_step() is False:
continue_training = False
return None, continue_training
self._update_info_buffer(infos)
n_steps += 1
self.num_timesteps += env.num_envs
@ -286,7 +286,7 @@ class PPO(BaseRLModel):
def learn(self,
total_timesteps: int,
callback: Optional[BaseCallback] = None,
callback: MaybeCallback = None,
log_interval: int = 1,
eval_env: Optional[GymEnv] = None,
eval_freq: int = -1,

View file

@ -7,9 +7,8 @@ import numpy as np
from torchy_baselines.common import logger
from torchy_baselines.common.base_class import OffPolicyRLModel
from torchy_baselines.common.buffers import ReplayBuffer
from torchy_baselines.common.type_aliases import GymEnv
from torchy_baselines.common.type_aliases import GymEnv, MaybeCallback
from torchy_baselines.common.noise import ActionNoise
from torchy_baselines.common.callbacks import BaseCallback
from torchy_baselines.sac.policies import SACPolicy
@ -253,7 +252,7 @@ class SAC(OffPolicyRLModel):
def learn(self,
total_timesteps: int,
callback: Optional[BaseCallback] = None,
callback: MaybeCallback = None,
log_interval: int = 4,
eval_env: Optional[GymEnv] = None,
eval_freq: int = -1,

View file

@ -4,9 +4,8 @@ from typing import List, Tuple, Type, Union, Callable, Optional, Dict, Any
from torchy_baselines.common.base_class import OffPolicyRLModel
from torchy_baselines.common.buffers import ReplayBuffer
from torchy_baselines.common.callbacks import BaseCallback
from torchy_baselines.common.noise import ActionNoise
from torchy_baselines.common.type_aliases import ReplayBufferSamples, GymEnv
from torchy_baselines.common.type_aliases import ReplayBufferSamples, GymEnv, MaybeCallback
from torchy_baselines.td3.policies import TD3Policy
@ -264,7 +263,7 @@ class TD3(OffPolicyRLModel):
def learn(self,
total_timesteps: int,
callback: Optional[BaseCallback] = None,
callback: MaybeCallback = None,
log_interval: int = 4,
eval_env: Optional[GymEnv] = None,
eval_freq: int = -1,