diff --git a/docs/guide/examples.rst b/docs/guide/examples.rst index 3433640..646eddf 100644 --- a/docs/guide/examples.rst +++ b/docs/guide/examples.rst @@ -532,19 +532,12 @@ linear and constant schedules. Advanced Saving and Loading --------------------------------- -In this example, we show how to use some advanced features of Stable-Baselines3 (SB3): -how to easily create a test environment to evaluate an agent periodically, -use a policy independently from a model (and how to save it, load it) and save/load a replay buffer. +In this example, we show how to use a policy independently from a model (and how to save it, load it) and save/load a replay buffer. By default, the replay buffer is not saved when calling ``model.save()``, in order to save space on the disk (a replay buffer can be up to several GB when using images). However, SB3 provides a ``save_replay_buffer()`` and ``load_replay_buffer()`` method to save it separately. -Stable-Baselines3 automatic creation of an environment for evaluation. -For that, you only need to specify ``create_eval_env=True`` when passing the Gym ID of the environment while creating the agent. -Behind the scene, SB3 uses an :ref:`EvalCallback `. - - .. note:: For training model after loading it, we recommend loading the replay buffer to ensure stable learning (for off-policy algorithms). @@ -562,14 +555,12 @@ Behind the scene, SB3 uses an :ref:`EvalCallback `. from stable_baselines3.common.evaluation import evaluate_policy from stable_baselines3.sac.policies import MlpPolicy - # Create the model, the training environment - # and the test environment (for evaluation) + # Create the model and the training environment model = SAC("MlpPolicy", "Pendulum-v1", verbose=1, - learning_rate=1e-3, create_eval_env=True) + learning_rate=1e-3) - # Evaluate the model every 1000 steps on 5 test episodes - # and save the evaluation to the "logs/" folder - model.learn(6000, eval_freq=1000, n_eval_episodes=5, eval_log_path="./logs/") + # train the model + model.learn(total_timesteps=6000) # save the model model.save("sac_pendulum") diff --git a/docs/guide/migration.rst b/docs/guide/migration.rst index ef26870..571cf45 100644 --- a/docs/guide/migration.rst +++ b/docs/guide/migration.rst @@ -206,8 +206,6 @@ New Features (SB3 vs SB2) - Independent saving/loading/predict for policies - A2C now supports Generalized Advantage Estimation (GAE) and advantage normalization (both are deactivated by default) - Generalized State-Dependent Exploration (gSDE) exploration is available for A2C/PPO/SAC. It allows to use RL directly on real robots (cf https://arxiv.org/abs/2005.05719) -- Proper evaluation (using separate env) is included in the base class (using ``EvalCallback``), - if you pass the environment as a string, you can pass ``create_eval_env=True`` to the algorithm constructor. - Better saving/loading: optimizers are now included in the saved parameters and there is two new methods ``save_replay_buffer`` and ``load_replay_buffer`` for the replay buffer when using off-policy algorithms (DQN/DDPG/SAC/TD3) - You can pass ``optimizer_class`` and ``optimizer_kwargs`` to ``policy_kwargs`` in order to easily customize optimizers diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 75d8df6..c2ac341 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -22,6 +22,7 @@ Bug Fixes: Deprecations: ^^^^^^^^^^^^^ +- Added deprecation warning if parameters ``eval_env``, ``eval_freq`` or ``create_eval_env`` are used (see #925) (@tobirohrer) Others: ^^^^^^^ @@ -1074,4 +1075,4 @@ And all the contributors: @simoninithomas @armandpl @manuel-delverme @Gautam-J @gianlucadecola @buoyancy99 @caburu @xy9485 @Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor @TibiGG @cool-RR @MWeltevrede @Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont @honglu2875 -@anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde +@anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer diff --git a/stable_baselines3/a2c/a2c.py b/stable_baselines3/a2c/a2c.py index 226f6bc..d59eebb 100644 --- a/stable_baselines3/a2c/a2c.py +++ b/stable_baselines3/a2c/a2c.py @@ -44,7 +44,9 @@ class A2C(OnPolicyAlgorithm): :param normalize_advantage: Whether to normalize or not the advantage :param tensorboard_log: the log location for tensorboard (if None, no logging) :param create_eval_env: Whether to create a second environment that will be - used for evaluating the agent periodically. (Only available when passing string for the environment) + used for evaluating the agent periodically (Only available when passing string for the environment). + Caution, this parameter is deprecated and will be removed in the future. + Please use `EvalCallback` or a custom Callback instead. :param policy_kwargs: additional arguments to be passed to the policy on creation :param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for debug messages diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py index c265b7a..42dca11 100644 --- a/stable_baselines3/common/base_class.py +++ b/stable_baselines3/common/base_class.py @@ -3,6 +3,7 @@ import io import pathlib import time +import warnings from abc import ABC, abstractmethod from collections import deque from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union @@ -75,7 +76,9 @@ class BaseAlgorithm(ABC): :param support_multi_env: Whether the algorithm supports training with multiple environments (as in A2C) :param create_eval_env: Whether to create a second environment that will be - used for evaluating the agent periodically. (Only available when passing string for the environment) + used for evaluating the agent periodically (Only available when passing string for the environment). + Caution, this parameter is deprecated and will be removed in the future. + Please use `EvalCallback` or a custom Callback instead. :param monitor_wrapper: When creating an environment, whether to wrap it or not in a Monitor wrapper. :param seed: Seed for the pseudo random generators @@ -161,6 +164,15 @@ class BaseAlgorithm(ABC): if env is not None: if isinstance(env, str): if create_eval_env: + warnings.warn( + "The parameter `create_eval_env` is deprecated and will be removed in the future. " + "Please use `EvalCallback` or a custom Callback instead.", + DeprecationWarning, + # By setting the `stacklevel` we refer to the initial caller of the deprecated feature. + # This causes the the `DepricationWarning` to not be ignored and to be shown to the user. See + # https://github.com/DLR-RM/stable-baselines3/pull/1082#discussion_r989842855 for more details. + stacklevel=4, + ) self.eval_env = maybe_make_env(env, self.verbose) env = maybe_make_env(env, self.verbose) @@ -376,6 +388,8 @@ class BaseAlgorithm(ABC): """ :param callback: Callback(s) called at every step with state of the algorithm. :param eval_freq: How many steps between evaluations; if None, do not evaluate. + Caution, this parameter is deprecated and will be removed in the future. + Please use `EvalCallback` or a custom Callback instead. :param n_eval_episodes: How many episodes to play per evaluation :param n_eval_episodes: Number of episodes to rollout during evaluation. :param log_path: Path to a folder where the evaluations will be saved @@ -426,8 +440,12 @@ class BaseAlgorithm(ABC): :param total_timesteps: The total number of samples (env steps) to train on :param eval_env: Environment to use for evaluation. + Caution, this parameter is deprecated and will be removed in the future. + Please use `EvalCallback` or a custom Callback instead. :param callback: Callback(s) called at every step with state of the algorithm. :param eval_freq: How many steps between evaluations + Caution, this parameter is deprecated and will be removed in the future. + Please use `EvalCallback` or a custom Callback instead. :param n_eval_episodes: How many episodes to play per evaluation :param log_path: Path to a folder where the evaluations will be saved :param reset_num_timesteps: Whether to reset or not the ``num_timesteps`` attribute @@ -435,6 +453,18 @@ class BaseAlgorithm(ABC): :param progress_bar: Display a progress bar using tqdm and rich. :return: Total timesteps and callback(s) """ + + if eval_env is not None or eval_freq != -1: + warnings.warn( + "Parameters `eval_env` and `eval_freq` are deprecated and will be removed in the future. " + "Please use `EvalCallback` or a custom Callback instead.", + DeprecationWarning, + # By setting the `stacklevel` we refer to the initial caller of the deprecated feature. + # This causes the the `DepricationWarning` to not be ignored and to be shown to the user. See + # https://github.com/DLR-RM/stable-baselines3/pull/1082#discussion_r989842855 for more details. + stacklevel=4, + ) + self.start_time = time.time_ns() if self.ep_info_buffer is None or reset_num_timesteps: @@ -567,8 +597,11 @@ class BaseAlgorithm(ABC): :param callback: callback(s) called at every step with state of the algorithm. :param log_interval: The number of timesteps before logging. :param tb_log_name: the name of the run for TensorBoard logging - :param eval_env: Environment that will be used to evaluate the agent - :param eval_freq: Evaluate the agent every ``eval_freq`` timesteps (this may vary a little) + :param eval_env: Environment that will be used to evaluate the agent. Caution, this parameter + is deprecated and will be removed in the future. Please use ``EvalCallback`` instead. + :param eval_freq: Evaluate the agent every ``eval_freq`` timesteps (this may vary a little). + Caution, this parameter is deprecated and will be removed in the future. + Please use `EvalCallback` or a custom Callback instead. :param n_eval_episodes: Number of episode to evaluate the agent :param eval_log_path: Path to a folder where the evaluations will be saved :param reset_num_timesteps: whether or not to reset the current timestep number (used in logging) diff --git a/stable_baselines3/common/off_policy_algorithm.py b/stable_baselines3/common/off_policy_algorithm.py index 99b5d22..d64d4e8 100644 --- a/stable_baselines3/common/off_policy_algorithm.py +++ b/stable_baselines3/common/off_policy_algorithm.py @@ -62,6 +62,8 @@ class OffPolicyAlgorithm(BaseAlgorithm): with multiple environments (as in A2C) :param create_eval_env: Whether to create a second environment that will be used for evaluating the agent periodically. (Only available when passing string for the environment) + Caution, this parameter is deprecated and will be removed in the future. + Please use `EvalCallback` or a custom Callback instead. :param monitor_wrapper: When creating an environment, whether to wrap it or not in a Monitor wrapper. :param seed: Seed for the pseudo random generators diff --git a/stable_baselines3/common/on_policy_algorithm.py b/stable_baselines3/common/on_policy_algorithm.py index a91301d..d19640c 100644 --- a/stable_baselines3/common/on_policy_algorithm.py +++ b/stable_baselines3/common/on_policy_algorithm.py @@ -40,6 +40,8 @@ class OnPolicyAlgorithm(BaseAlgorithm): :param tensorboard_log: the log location for tensorboard (if None, no logging) :param create_eval_env: Whether to create a second environment that will be used for evaluating the agent periodically. (Only available when passing string for the environment) + Caution, this parameter is deprecated and will be removed in the future. + Please use `EvalCallback` or a custom Callback instead. :param monitor_wrapper: When creating an environment, whether to wrap it or not in a Monitor wrapper. :param policy_kwargs: additional arguments to be passed to the policy on creation diff --git a/stable_baselines3/ddpg/ddpg.py b/stable_baselines3/ddpg/ddpg.py index 627a2e6..26e0745 100644 --- a/stable_baselines3/ddpg/ddpg.py +++ b/stable_baselines3/ddpg/ddpg.py @@ -45,7 +45,9 @@ class DDPG(TD3): at a cost of more complexity. See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195 :param create_eval_env: Whether to create a second environment that will be - used for evaluating the agent periodically. (Only available when passing string for the environment) + used for evaluating the agent periodically (Only available when passing string for the environment). + Caution, this parameter is deprecated and will be removed in the future. + Please use `EvalCallback` or a custom Callback instead. :param policy_kwargs: additional arguments to be passed to the policy on creation :param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for debug messages diff --git a/stable_baselines3/dqn/dqn.py b/stable_baselines3/dqn/dqn.py index 7263fdd..cc13ecf 100644 --- a/stable_baselines3/dqn/dqn.py +++ b/stable_baselines3/dqn/dqn.py @@ -53,7 +53,9 @@ class DQN(OffPolicyAlgorithm): :param max_grad_norm: The maximum value for the gradient clipping :param tensorboard_log: the log location for tensorboard (if None, no logging) :param create_eval_env: Whether to create a second environment that will be - used for evaluating the agent periodically. (Only available when passing string for the environment) + used for evaluating the agent periodically (Only available when passing string for the environment). + Caution, this parameter is deprecated and will be removed in the future. + Please use `EvalCallback` or a custom Callback instead. :param policy_kwargs: additional arguments to be passed to the policy on creation :param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for debug messages diff --git a/stable_baselines3/ppo/ppo.py b/stable_baselines3/ppo/ppo.py index 81094aa..b187ec8 100644 --- a/stable_baselines3/ppo/ppo.py +++ b/stable_baselines3/ppo/ppo.py @@ -58,7 +58,9 @@ class PPO(OnPolicyAlgorithm): By default, there is no limit on the kl div. :param tensorboard_log: the log location for tensorboard (if None, no logging) :param create_eval_env: Whether to create a second environment that will be - used for evaluating the agent periodically. (Only available when passing string for the environment) + used for evaluating the agent periodically (Only available when passing string for the environment). + Caution, this parameter is deprecated and will be removed in the future. + Please use `EvalCallback` or a custom Callback instead. :param policy_kwargs: additional arguments to be passed to the policy on creation :param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for debug messages diff --git a/stable_baselines3/sac/sac.py b/stable_baselines3/sac/sac.py index abd6879..6d7ad29 100644 --- a/stable_baselines3/sac/sac.py +++ b/stable_baselines3/sac/sac.py @@ -66,7 +66,9 @@ class SAC(OffPolicyAlgorithm): :param use_sde_at_warmup: Whether to use gSDE instead of uniform sampling during the warm up phase (before learning starts) :param create_eval_env: Whether to create a second environment that will be - used for evaluating the agent periodically. (Only available when passing string for the environment) + used for evaluating the agent periodically (Only available when passing string for the environment). + Caution, this parameter is deprecated and will be removed in the future. + Please use `EvalCallback` or a custom Callback instead. :param policy_kwargs: additional arguments to be passed to the policy on creation :param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for debug messages diff --git a/stable_baselines3/td3/td3.py b/stable_baselines3/td3/td3.py index 87d94b3..f7dd08b 100644 --- a/stable_baselines3/td3/td3.py +++ b/stable_baselines3/td3/td3.py @@ -54,7 +54,9 @@ class TD3(OffPolicyAlgorithm): (smoothing noise) :param target_noise_clip: Limit for absolute value of target policy smoothing noise. :param create_eval_env: Whether to create a second environment that will be - used for evaluating the agent periodically. (Only available when passing string for the environment) + used for evaluating the agent periodically (Only available when passing string for the environment). + Caution, this parameter is deprecated and will be removed in the future. + Please use `EvalCallback` or a custom Callback instead. :param policy_kwargs: additional arguments to be passed to the policy on creation :param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for debug messages diff --git a/tests/test_predict.py b/tests/test_predict.py index 89cdb09..abbf254 100644 --- a/tests/test_predict.py +++ b/tests/test_predict.py @@ -47,7 +47,10 @@ def test_auto_wrap(model_class): env = gym.make(env_name) eval_env = gym.make(env_name) model = model_class("MlpPolicy", env) - model.learn(100, eval_env=eval_env) + + # Catch DeprecationWarnings + with pytest.warns(DeprecationWarning): # `eval_env` is deprecated + model.learn(100, eval_env=eval_env) @pytest.mark.parametrize("model_class", MODEL_LIST) diff --git a/tests/test_run.py b/tests/test_run.py index 655182d..66b0ff8 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -18,23 +18,25 @@ def test_deterministic_pg(model_class, action_noise): """ Test for DDPG and variants (TD3). """ - model = model_class( - "MlpPolicy", - "Pendulum-v1", - policy_kwargs=dict(net_arch=[64, 64]), - learning_starts=100, - verbose=1, - create_eval_env=True, - buffer_size=250, - action_noise=action_noise, - ) - model.learn(total_timesteps=300, eval_freq=250) + with pytest.warns(DeprecationWarning): # `create_eval_env` and `eval_freq` are deprecated + model = model_class( + "MlpPolicy", + "Pendulum-v1", + policy_kwargs=dict(net_arch=[64, 64]), + learning_starts=100, + verbose=1, + create_eval_env=True, + buffer_size=250, + action_noise=action_noise, + ) + model.learn(total_timesteps=300, eval_freq=250) @pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v1"]) def test_a2c(env_id): - model = A2C("MlpPolicy", env_id, seed=0, policy_kwargs=dict(net_arch=[16]), verbose=1, create_eval_env=True) - model.learn(total_timesteps=1000, eval_freq=500) + with pytest.warns(DeprecationWarning): # `create_eval_env` and `eval_freq` are deprecated + model = A2C("MlpPolicy", env_id, seed=0, policy_kwargs=dict(net_arch=[16]), verbose=1, create_eval_env=True) + model.learn(total_timesteps=1000, eval_freq=500) @pytest.mark.parametrize("model_class", [A2C, PPO]) @@ -47,46 +49,48 @@ def test_advantage_normalization(model_class, normalize_advantage): @pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v1"]) @pytest.mark.parametrize("clip_range_vf", [None, 0.2, -0.2]) def test_ppo(env_id, clip_range_vf): - if clip_range_vf is not None and clip_range_vf < 0: - # Should throw an error - with pytest.raises(AssertionError): + with pytest.warns(DeprecationWarning): # `create_eval_env` and `eval_freq` are deprecated + if clip_range_vf is not None and clip_range_vf < 0: + # Should throw an error + with pytest.raises(AssertionError): + model = PPO( + "MlpPolicy", + env_id, + seed=0, + policy_kwargs=dict(net_arch=[16]), + verbose=1, + create_eval_env=True, + clip_range_vf=clip_range_vf, + ) + else: model = PPO( "MlpPolicy", env_id, + n_steps=512, seed=0, policy_kwargs=dict(net_arch=[16]), verbose=1, create_eval_env=True, clip_range_vf=clip_range_vf, ) - else: - model = PPO( - "MlpPolicy", - env_id, - n_steps=512, - seed=0, - policy_kwargs=dict(net_arch=[16]), - verbose=1, - create_eval_env=True, - clip_range_vf=clip_range_vf, - ) - model.learn(total_timesteps=1000, eval_freq=500) + model.learn(total_timesteps=1000, eval_freq=500) @pytest.mark.parametrize("ent_coef", ["auto", 0.01, "auto_0.01"]) def test_sac(ent_coef): - model = SAC( - "MlpPolicy", - "Pendulum-v1", - policy_kwargs=dict(net_arch=[64, 64]), - learning_starts=100, - verbose=1, - create_eval_env=True, - buffer_size=250, - ent_coef=ent_coef, - action_noise=NormalActionNoise(np.zeros(1), np.zeros(1)), - ) - model.learn(total_timesteps=300, eval_freq=250) + with pytest.warns(DeprecationWarning): # `create_eval_env` and `eval_freq` are deprecated + model = SAC( + "MlpPolicy", + "Pendulum-v1", + policy_kwargs=dict(net_arch=[64, 64]), + learning_starts=100, + verbose=1, + create_eval_env=True, + buffer_size=250, + ent_coef=ent_coef, + action_noise=NormalActionNoise(np.zeros(1), np.zeros(1)), + ) + model.learn(total_timesteps=300, eval_freq=250) @pytest.mark.parametrize("n_critics", [1, 3]) @@ -104,17 +108,18 @@ def test_n_critics(n_critics): def test_dqn(): - model = DQN( - "MlpPolicy", - "CartPole-v1", - policy_kwargs=dict(net_arch=[64, 64]), - learning_starts=100, - buffer_size=500, - learning_rate=3e-4, - verbose=1, - create_eval_env=True, - ) - model.learn(total_timesteps=500, eval_freq=250) + with pytest.warns(DeprecationWarning): # `create_eval_env` and `eval_freq` are deprecated + model = DQN( + "MlpPolicy", + "CartPole-v1", + policy_kwargs=dict(net_arch=[64, 64]), + learning_starts=100, + buffer_size=500, + learning_rate=3e-4, + verbose=1, + create_eval_env=True, + ) + model.learn(total_timesteps=500, eval_freq=250) @pytest.mark.parametrize("train_freq", [4, (4, "step"), (1, "episode")]) diff --git a/tests/test_sde.py b/tests/test_sde.py index 0a650a5..4348207 100644 --- a/tests/test_sde.py +++ b/tests/test_sde.py @@ -63,17 +63,18 @@ def test_sde_check(): @pytest.mark.parametrize("use_expln", [False, True]) def test_state_dependent_noise(model_class, use_expln): kwargs = {"learning_starts": 0} if model_class == SAC else {"n_steps": 64} - model = model_class( - "MlpPolicy", - "Pendulum-v1", - use_sde=True, - seed=None, - create_eval_env=True, - verbose=1, - policy_kwargs=dict(log_std_init=-2, use_expln=use_expln, net_arch=[64]), - **kwargs, - ) - model.learn(total_timesteps=255, eval_freq=250) + with pytest.warns(DeprecationWarning): # `create_eval_env` and `eval_freq` are deprecated + model = model_class( + "MlpPolicy", + "Pendulum-v1", + use_sde=True, + seed=None, + create_eval_env=True, + verbose=1, + policy_kwargs=dict(log_std_init=-2, use_expln=use_expln, net_arch=[64]), + **kwargs, + ) + model.learn(total_timesteps=255, eval_freq=250) model.policy.reset_noise() if model_class == SAC: model.policy.actor.get_std() diff --git a/tests/test_vec_normalize.py b/tests/test_vec_normalize.py index a363e40..0fef682 100644 --- a/tests/test_vec_normalize.py +++ b/tests/test_vec_normalize.py @@ -380,8 +380,8 @@ def test_offpolicy_normalization(model_class, online_sampling): assert model.get_vec_normalize_env() is eval_env model.learn(total_timesteps=10) model.set_env(env) - - model.learn(total_timesteps=150, eval_env=eval_env, eval_freq=75) + with pytest.warns(DeprecationWarning): # `eval_env` and `eval_freq` are deprecated + model.learn(total_timesteps=150, eval_env=eval_env, eval_freq=75) # Check getter assert isinstance(model.get_vec_normalize_env(), VecNormalize)