mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-06-03 23:49:57 +00:00
Sync callbacks
This commit is contained in:
parent
18f38f8cf5
commit
b64873ffff
11 changed files with 81 additions and 48 deletions
|
|
@ -17,6 +17,7 @@ New Features:
|
|||
|
||||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
- Synced callbacks with Stable-Baselines
|
||||
|
||||
Deprecations:
|
||||
^^^^^^^^^^^^^
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ def test_cemrl():
|
|||
@pytest.mark.parametrize("model_class", [A2C, PPO])
|
||||
@pytest.mark.parametrize("env_id", ['CartPole-v1', 'Pendulum-v0'])
|
||||
def test_onpolicy(model_class, env_id):
|
||||
model = model_class('MlpPolicy', env_id, policy_kwargs=dict(net_arch=[16]), verbose=1, create_eval_env=True)
|
||||
model = model_class('MlpPolicy', env_id, seed=0, policy_kwargs=dict(net_arch=[16]), verbose=1, create_eval_env=True)
|
||||
model.learn(total_timesteps=1000, eval_freq=500)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -4,8 +4,7 @@ from gym import spaces
|
|||
from typing import Type, Union, Callable, Optional, Dict, Any
|
||||
|
||||
from torchy_baselines.common import logger
|
||||
from torchy_baselines.common.callbacks import BaseCallback
|
||||
from torchy_baselines.common.type_aliases import GymEnv
|
||||
from torchy_baselines.common.type_aliases import GymEnv, MaybeCallback
|
||||
from torchy_baselines.common.utils import explained_variance
|
||||
from torchy_baselines.ppo.policies import PPOPolicy
|
||||
from torchy_baselines.ppo.ppo import PPO
|
||||
|
|
@ -154,7 +153,7 @@ class A2C(PPO):
|
|||
|
||||
def learn(self,
|
||||
total_timesteps: int,
|
||||
callback: Optional[BaseCallback] = None,
|
||||
callback: MaybeCallback = None,
|
||||
log_interval: int = 100,
|
||||
eval_env: Optional[GymEnv] = None,
|
||||
eval_freq: int = -1,
|
||||
|
|
|
|||
|
|
@ -3,8 +3,7 @@ from typing import Type, Union, Callable, Optional, Dict, Any
|
|||
import torch as th
|
||||
|
||||
from torchy_baselines.common.base_class import OffPolicyRLModel
|
||||
from torchy_baselines.common.callbacks import BaseCallback
|
||||
from torchy_baselines.common.type_aliases import GymEnv
|
||||
from torchy_baselines.common.type_aliases import GymEnv, MaybeCallback
|
||||
from torchy_baselines.common.noise import ActionNoise
|
||||
from torchy_baselines.td3.td3 import TD3, TD3Policy
|
||||
from torchy_baselines.cem_rl.cem import CEM
|
||||
|
|
@ -121,7 +120,7 @@ class CEMRL(TD3):
|
|||
|
||||
def learn(self,
|
||||
total_timesteps: int,
|
||||
callback: Optional[BaseCallback] = None,
|
||||
callback: MaybeCallback = None,
|
||||
log_interval: int = 4,
|
||||
eval_env: Optional[GymEnv] = None,
|
||||
eval_freq: int = -1,
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ from torchy_baselines.common.policies import BasePolicy, get_policy_from_name
|
|||
from torchy_baselines.common.utils import set_random_seed, get_schedule_fn, update_learning_rate
|
||||
from torchy_baselines.common.vec_env import DummyVecEnv, VecEnv, unwrap_vec_normalize, VecNormalize
|
||||
from torchy_baselines.common.save_util import data_to_json, json_to_data, recursive_getattr, recursive_setattr
|
||||
from torchy_baselines.common.type_aliases import GymEnv, TensorDict, RolloutReturn
|
||||
from torchy_baselines.common.type_aliases import GymEnv, TensorDict, RolloutReturn, MaybeCallback
|
||||
from torchy_baselines.common.callbacks import BaseCallback, CallbackList, ConvertCallback, EvalCallback
|
||||
from torchy_baselines.common.monitor import Monitor
|
||||
from torchy_baselines.common.noise import ActionNoise
|
||||
|
|
@ -281,7 +281,7 @@ class BaseRLModel(ABC):
|
|||
|
||||
@abstractmethod
|
||||
def learn(self, total_timesteps: int,
|
||||
callback: Union[None, Callable, List[BaseCallback], BaseCallback] = None,
|
||||
callback: MaybeCallback = None,
|
||||
log_interval: int = 100,
|
||||
tb_log_name: str = "run",
|
||||
eval_env: Optional[GymEnv] = None,
|
||||
|
|
@ -877,10 +877,6 @@ class OffPolicyRLModel(BaseRLModel):
|
|||
|
||||
while not done:
|
||||
|
||||
# Only stop training if return value is False, not when it is None.
|
||||
if callback() is False:
|
||||
return RolloutReturn(0.0, total_steps, total_episodes, None, continue_training=False)
|
||||
|
||||
if self.use_sde and self.sde_sample_freq > 0 and n_steps % self.sde_sample_freq == 0:
|
||||
# Sample a new noise matrix
|
||||
self.actor.reset_noise()
|
||||
|
|
@ -913,6 +909,10 @@ class OffPolicyRLModel(BaseRLModel):
|
|||
# Rescale and perform action
|
||||
new_obs, reward, done, infos = env.step(self.unscale_action(clipped_action))
|
||||
|
||||
# Only stop training if return value is False, not when it is None.
|
||||
if callback.on_step() is False:
|
||||
return RolloutReturn(0.0, total_steps, total_episodes, None, continue_training=False)
|
||||
|
||||
episode_reward += reward
|
||||
|
||||
# Retrieve reward and episode length if using Monitor wrapper
|
||||
|
|
|
|||
|
|
@ -1,12 +1,13 @@
|
|||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
import warnings
|
||||
import typing
|
||||
from typing import Union, List, Dict, Any, Optional
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
|
||||
from torchy_baselines.common.vec_env import VecEnv, sync_envs_normalization
|
||||
from torchy_baselines.common.vec_env import DummyVecEnv, VecEnv, sync_envs_normalization
|
||||
from torchy_baselines.common.evaluation import evaluate_policy
|
||||
from torchy_baselines.common.logger import Logger
|
||||
|
||||
|
|
@ -22,9 +23,13 @@ class BaseCallback(ABC):
|
|||
"""
|
||||
def __init__(self, verbose: int = 0):
|
||||
super(BaseCallback, self).__init__()
|
||||
# The RL model
|
||||
self.model = None # type: Optional[BaseRLModel]
|
||||
# An alias for self.model.get_env(), the environment used for training
|
||||
self.training_env = None # type: Union[gym.Env, VecEnv, None]
|
||||
# Number of times the callback was called
|
||||
self.n_calls = 0 # type: int
|
||||
# n_envs * n times env.step() was called
|
||||
self.num_timesteps = 0 # type: int
|
||||
self.verbose = verbose
|
||||
self.locals = None # type: Optional[Dict[str, Any]]
|
||||
|
|
@ -70,9 +75,13 @@ class BaseCallback(ABC):
|
|||
"""
|
||||
return True
|
||||
|
||||
def __call__(self) -> bool:
|
||||
def on_step(self) -> bool:
|
||||
"""
|
||||
This method will be called by the model. This is the equivalent to the callback function.
|
||||
This method will be called by the model after each call to ``env.step()``.
|
||||
|
||||
For child callback (of an ``EventCallback``), this will be called
|
||||
when the event is triggered.
|
||||
|
||||
:return: (bool) If the callback returns False, training is aborted early.
|
||||
"""
|
||||
self.n_calls += 1
|
||||
|
|
@ -128,6 +137,12 @@ class EventCallback(BaseCallback):
|
|||
|
||||
|
||||
class CallbackList(BaseCallback):
|
||||
"""
|
||||
Class for chaining callbacks.
|
||||
|
||||
:param callbacks: (List[BaseCallback]) A list of callbacks that will be called
|
||||
sequentially.
|
||||
"""
|
||||
def __init__(self, callbacks: List[BaseCallback]):
|
||||
super(CallbackList, self).__init__()
|
||||
assert isinstance(callbacks, list)
|
||||
|
|
@ -141,16 +156,21 @@ class CallbackList(BaseCallback):
|
|||
for callback in self.callbacks:
|
||||
callback.on_training_start(self.locals, self.globals)
|
||||
|
||||
def _on_rollout_start(self) -> None:
|
||||
for callback in self.callbacks:
|
||||
callback.on_rollout_start()
|
||||
|
||||
def _on_step(self) -> bool:
|
||||
continue_training = True
|
||||
for callback in self.callbacks:
|
||||
# # Update variables
|
||||
# callback.num_timesteps = self.num_timesteps
|
||||
# callback.n_calls = self.n_calls
|
||||
# Return False (stop training) if at least one callback returns False
|
||||
continue_training = callback() and continue_training
|
||||
continue_training = callback.on_step() and continue_training
|
||||
return continue_training
|
||||
|
||||
def _on_rollout_end(self) -> None:
|
||||
for callback in self.callbacks:
|
||||
callback.on_rollout_end()
|
||||
|
||||
def _on_training_end(self) -> None:
|
||||
for callback in self.callbacks:
|
||||
callback.on_training_end()
|
||||
|
|
@ -158,7 +178,7 @@ class CallbackList(BaseCallback):
|
|||
|
||||
class CheckpointCallback(BaseCallback):
|
||||
"""
|
||||
Callback for saving a model every `save_freq` steps
|
||||
Callback for saving a model every ``save_freq`` steps
|
||||
|
||||
:param save_freq: (int)
|
||||
:param save_path: (str) Path to the folder where the model will be saved.
|
||||
|
|
@ -207,16 +227,17 @@ class EvalCallback(EventCallback):
|
|||
|
||||
:param eval_env: (Union[gym.Env, VecEnv]) The environment used for initialization
|
||||
:param callback_on_new_best: (Optional[BaseCallback]) Callback to trigger
|
||||
when there is a new best model according to the `mean_reward`
|
||||
when there is a new best model according to the ``mean_reward``
|
||||
:param n_eval_episodes: (int) The number of episodes to test the agent
|
||||
:param eval_freq: (int) Evaluate the agent every ``eval_freq`` calls of the callback.
|
||||
:param log_path: (str) Path to a folder where the evaluations (`evaluations.npz`)
|
||||
:param log_path: (str) Path to a folder where the evaluations (``evaluations.npz``)
|
||||
will be saved. It will be updated at each evaluation.
|
||||
:param best_model_save_path: (str) Path to a folder where the best model
|
||||
according to performance on the eval env will be saved.
|
||||
:param deterministic: (bool) Whether the evaluation should
|
||||
use stochastic or deterministic actions.
|
||||
:param deterministic: (bool) Whether to render or not the environment during evaluation
|
||||
:param render: (bool) Whether to render or not the environment during evaluation
|
||||
:param verbose: (int)
|
||||
"""
|
||||
def __init__(self, eval_env: Union[gym.Env, VecEnv],
|
||||
|
|
@ -236,12 +257,16 @@ class EvalCallback(EventCallback):
|
|||
self.deterministic = deterministic
|
||||
self.render = render
|
||||
|
||||
# Convert to VecEnv for consistency
|
||||
if not isinstance(eval_env, VecEnv):
|
||||
eval_env = DummyVecEnv([lambda: eval_env])
|
||||
|
||||
if isinstance(eval_env, VecEnv):
|
||||
assert eval_env.num_envs == 1, "You must pass only one environment for evaluation"
|
||||
|
||||
self.eval_env = eval_env
|
||||
self.best_model_save_path = best_model_save_path
|
||||
# Logs will be written in `evaluations.npz`
|
||||
# Logs will be written in ``evaluations.npz``
|
||||
if log_path is not None:
|
||||
log_path = os.path.join(log_path, 'evaluations')
|
||||
self.log_path = log_path
|
||||
|
|
@ -250,9 +275,10 @@ class EvalCallback(EventCallback):
|
|||
self.evaluations_length = []
|
||||
|
||||
def _init_callback(self):
|
||||
# Does not work when eval_env is a gym.Env and training_env is a VecEnv
|
||||
# assert type(self.training_env) is type(self.eval_env), ("training and eval env are not of the same type",
|
||||
# "{} != {}".format(self.training_env, self.eval_env))
|
||||
# Does not work in some corner cases, where the wrapper is not the same
|
||||
if not type(self.training_env) is type(self.eval_env):
|
||||
warnings.warn("Training and eval env are not of the same type"
|
||||
f"{self.training_env} != {self.eval_env}")
|
||||
|
||||
# Create folders if needed
|
||||
if self.best_model_save_path is not None:
|
||||
|
|
@ -306,7 +332,7 @@ class StopTrainingOnRewardThreshold(BaseCallback):
|
|||
Stop the training once a threshold in episodic reward
|
||||
has been reached (i.e. when the model is good enough).
|
||||
|
||||
It must be used with the `EvalCallback`.
|
||||
It must be used with the ``EvalCallback``.
|
||||
|
||||
:param reward_threshold: (float) Minimum expected reward per episode
|
||||
to stop training.
|
||||
|
|
@ -317,8 +343,8 @@ class StopTrainingOnRewardThreshold(BaseCallback):
|
|||
self.reward_threshold = reward_threshold
|
||||
|
||||
def _on_step(self) -> bool:
|
||||
assert self.parent is not None, ("`StopTrainingOnMinimumReward` callback must be used "
|
||||
"with an `EvalCallback`")
|
||||
assert self.parent is not None, ("``StopTrainingOnMinimumReward`` callback must be used "
|
||||
"with an ``EvalCallback``")
|
||||
# Convert np.bool to bool, otherwise callback() is False won't work
|
||||
continue_training = bool(self.parent.best_mean_reward < self.reward_threshold)
|
||||
if self.verbose > 0 and not continue_training:
|
||||
|
|
@ -329,7 +355,7 @@ class StopTrainingOnRewardThreshold(BaseCallback):
|
|||
|
||||
class EveryNTimesteps(EventCallback):
|
||||
"""
|
||||
Trigger a callback every `n_steps` timesteps
|
||||
Trigger a callback every ``n_steps`` timesteps
|
||||
|
||||
:param n_steps: (int) Number of timesteps between two triggers.
|
||||
:param callback: (BaseCallback) Callback that will be called
|
||||
|
|
|
|||
|
|
@ -1,18 +1,21 @@
|
|||
"""
|
||||
Common aliases for type hint
|
||||
"""
|
||||
from typing import Union, Dict, Any, NamedTuple, Optional
|
||||
import typing
|
||||
from typing import Union, Dict, Any, NamedTuple, Optional, List, Callable
|
||||
|
||||
import numpy as np
|
||||
import torch as th
|
||||
import gym
|
||||
|
||||
from torchy_baselines.common.vec_env import VecEnv
|
||||
from torchy_baselines.common.callbacks import BaseCallback
|
||||
|
||||
|
||||
GymEnv = Union[gym.Env, VecEnv]
|
||||
TensorDict = Dict[str, th.Tensor]
|
||||
OptimizerStateDict = Dict[str, Any]
|
||||
MaybeCallback = Union[None, Callable, List[BaseCallback], BaseCallback]
|
||||
|
||||
|
||||
class RolloutBufferSamples(NamedTuple):
|
||||
|
|
|
|||
|
|
@ -1,4 +1,6 @@
|
|||
# flake8: noqa F401
|
||||
import typing
|
||||
from typing import Optional
|
||||
from copy import deepcopy
|
||||
|
||||
from torchy_baselines.common.vec_env.base_vec_env import AlreadySteppingError, NotSteppingError,\
|
||||
|
|
@ -8,8 +10,12 @@ from torchy_baselines.common.vec_env.subproc_vec_env import SubprocVecEnv
|
|||
from torchy_baselines.common.vec_env.vec_frame_stack import VecFrameStack
|
||||
from torchy_baselines.common.vec_env.vec_normalize import VecNormalize
|
||||
|
||||
# Avoid circular import
|
||||
if typing.TYPE_CHECKING:
|
||||
from torchy_baselines.common.type_aliases import GymEnv
|
||||
|
||||
def unwrap_vec_normalize(env):
|
||||
|
||||
def unwrap_vec_normalize(env: 'GymEnv') -> Optional[VecNormalize]:
|
||||
"""
|
||||
:param env: (gym.Env)
|
||||
:return: (VecNormalize)
|
||||
|
|
@ -23,16 +29,17 @@ def unwrap_vec_normalize(env):
|
|||
|
||||
|
||||
# Define here to avoid circular import
|
||||
def sync_envs_normalization(env, eval_env):
|
||||
def sync_envs_normalization(env: 'GymEnv', eval_env: 'GymEnv') -> None:
|
||||
"""
|
||||
Sync eval env and train env when using VecNormalize
|
||||
|
||||
:param env: (gym.Env)
|
||||
:param eval_env: (gym.Env)
|
||||
:param env: (GymEnv)
|
||||
:param eval_env: (GymEnv)
|
||||
"""
|
||||
env_tmp, eval_env_tmp = env, eval_env
|
||||
while isinstance(env_tmp, VecEnvWrapper):
|
||||
if isinstance(env_tmp, VecNormalize):
|
||||
eval_env_tmp.obs_rms = deepcopy(env_tmp.obs_rms)
|
||||
eval_env_tmp.ret_rms = deepcopy(env_tmp.ret_rms)
|
||||
env_tmp = env_tmp.venv
|
||||
eval_env_tmp = eval_env_tmp.venv
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ import numpy as np
|
|||
|
||||
from torchy_baselines.common import logger
|
||||
from torchy_baselines.common.base_class import BaseRLModel
|
||||
from torchy_baselines.common.type_aliases import GymEnv
|
||||
from torchy_baselines.common.type_aliases import GymEnv, MaybeCallback
|
||||
from torchy_baselines.common.buffers import RolloutBuffer
|
||||
from torchy_baselines.common.utils import explained_variance, get_schedule_fn
|
||||
from torchy_baselines.common.vec_env import VecEnv
|
||||
|
|
@ -163,10 +163,6 @@ class PPO(BaseRLModel):
|
|||
|
||||
while n_steps < n_rollout_steps:
|
||||
|
||||
if callback() is False:
|
||||
continue_training = False
|
||||
return None, continue_training
|
||||
|
||||
if self.use_sde and self.sde_sample_freq > 0 and n_steps % self.sde_sample_freq == 0:
|
||||
# Sample a new noise matrix
|
||||
self.policy.reset_noise(env.num_envs)
|
||||
|
|
@ -182,6 +178,10 @@ class PPO(BaseRLModel):
|
|||
clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high)
|
||||
new_obs, rewards, dones, infos = env.step(clipped_actions)
|
||||
|
||||
if callback.on_step() is False:
|
||||
continue_training = False
|
||||
return None, continue_training
|
||||
|
||||
self._update_info_buffer(infos)
|
||||
n_steps += 1
|
||||
self.num_timesteps += env.num_envs
|
||||
|
|
@ -286,7 +286,7 @@ class PPO(BaseRLModel):
|
|||
|
||||
def learn(self,
|
||||
total_timesteps: int,
|
||||
callback: Optional[BaseCallback] = None,
|
||||
callback: MaybeCallback = None,
|
||||
log_interval: int = 1,
|
||||
eval_env: Optional[GymEnv] = None,
|
||||
eval_freq: int = -1,
|
||||
|
|
|
|||
|
|
@ -7,9 +7,8 @@ import numpy as np
|
|||
from torchy_baselines.common import logger
|
||||
from torchy_baselines.common.base_class import OffPolicyRLModel
|
||||
from torchy_baselines.common.buffers import ReplayBuffer
|
||||
from torchy_baselines.common.type_aliases import GymEnv
|
||||
from torchy_baselines.common.type_aliases import GymEnv, MaybeCallback
|
||||
from torchy_baselines.common.noise import ActionNoise
|
||||
from torchy_baselines.common.callbacks import BaseCallback
|
||||
from torchy_baselines.sac.policies import SACPolicy
|
||||
|
||||
|
||||
|
|
@ -253,7 +252,7 @@ class SAC(OffPolicyRLModel):
|
|||
|
||||
def learn(self,
|
||||
total_timesteps: int,
|
||||
callback: Optional[BaseCallback] = None,
|
||||
callback: MaybeCallback = None,
|
||||
log_interval: int = 4,
|
||||
eval_env: Optional[GymEnv] = None,
|
||||
eval_freq: int = -1,
|
||||
|
|
|
|||
|
|
@ -4,9 +4,8 @@ from typing import List, Tuple, Type, Union, Callable, Optional, Dict, Any
|
|||
|
||||
from torchy_baselines.common.base_class import OffPolicyRLModel
|
||||
from torchy_baselines.common.buffers import ReplayBuffer
|
||||
from torchy_baselines.common.callbacks import BaseCallback
|
||||
from torchy_baselines.common.noise import ActionNoise
|
||||
from torchy_baselines.common.type_aliases import ReplayBufferSamples, GymEnv
|
||||
from torchy_baselines.common.type_aliases import ReplayBufferSamples, GymEnv, MaybeCallback
|
||||
from torchy_baselines.td3.policies import TD3Policy
|
||||
|
||||
|
||||
|
|
@ -264,7 +263,7 @@ class TD3(OffPolicyRLModel):
|
|||
|
||||
def learn(self,
|
||||
total_timesteps: int,
|
||||
callback: Optional[BaseCallback] = None,
|
||||
callback: MaybeCallback = None,
|
||||
log_interval: int = 4,
|
||||
eval_env: Optional[GymEnv] = None,
|
||||
eval_freq: int = -1,
|
||||
|
|
|
|||
Loading…
Reference in a new issue