Deprecate create_eval_env, eval_env and eval_freq parameter (#1082)

* Adds deprecation warning if `eval_env` or `eval_freq` parameters are used. See #925

* added changelog entry

* added missing backtick

* deprecating `create_eval_env` parameter as well and adding comments to explain the `stacklevel` parameter used

* Updated tests to ignore DeprecationWarnings

* Updated changelog entry

* - Removed the `create_eval_env` parameter from the examples in the docs
- Removed information about the `create_eval_env` parameter from the migration docs
- Added information about deprecation of the `create_eval_env` parameter in the docs

* Add alternative in docstring

* Update docstrings

* `eval_freq` warning in docstring

* Add deprecation comments in tests

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
Co-authored-by: Quentin GALLOUÉDEC <gallouedec.quentin@gmail.com>
This commit is contained in:
tobirohrer 2022-10-10 15:39:38 +02:00 committed by GitHub
parent 7c21b79188
commit d8a430e088
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
16 changed files with 139 additions and 91 deletions

View file

@ -532,19 +532,12 @@ linear and constant schedules.
Advanced Saving and Loading
---------------------------------
In this example, we show how to use some advanced features of Stable-Baselines3 (SB3):
how to easily create a test environment to evaluate an agent periodically,
use a policy independently from a model (and how to save it, load it) and save/load a replay buffer.
In this example, we show how to use a policy independently from a model (and how to save it, load it) and save/load a replay buffer.
By default, the replay buffer is not saved when calling ``model.save()``, in order to save space on the disk (a replay buffer can be up to several GB when using images).
However, SB3 provides a ``save_replay_buffer()`` and ``load_replay_buffer()`` method to save it separately.
Stable-Baselines3 automatic creation of an environment for evaluation.
For that, you only need to specify ``create_eval_env=True`` when passing the Gym ID of the environment while creating the agent.
Behind the scene, SB3 uses an :ref:`EvalCallback <callbacks>`.
.. note::
For training model after loading it, we recommend loading the replay buffer to ensure stable learning (for off-policy algorithms).
@ -562,14 +555,12 @@ Behind the scene, SB3 uses an :ref:`EvalCallback <callbacks>`.
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.sac.policies import MlpPolicy
# Create the model, the training environment
# and the test environment (for evaluation)
# Create the model and the training environment
model = SAC("MlpPolicy", "Pendulum-v1", verbose=1,
learning_rate=1e-3, create_eval_env=True)
learning_rate=1e-3)
# Evaluate the model every 1000 steps on 5 test episodes
# and save the evaluation to the "logs/" folder
model.learn(6000, eval_freq=1000, n_eval_episodes=5, eval_log_path="./logs/")
# train the model
model.learn(total_timesteps=6000)
# save the model
model.save("sac_pendulum")

View file

@ -206,8 +206,6 @@ New Features (SB3 vs SB2)
- Independent saving/loading/predict for policies
- A2C now supports Generalized Advantage Estimation (GAE) and advantage normalization (both are deactivated by default)
- Generalized State-Dependent Exploration (gSDE) exploration is available for A2C/PPO/SAC. It allows to use RL directly on real robots (cf https://arxiv.org/abs/2005.05719)
- Proper evaluation (using separate env) is included in the base class (using ``EvalCallback``),
if you pass the environment as a string, you can pass ``create_eval_env=True`` to the algorithm constructor.
- Better saving/loading: optimizers are now included in the saved parameters and there is two new methods ``save_replay_buffer`` and ``load_replay_buffer`` for the replay buffer when using off-policy algorithms (DQN/DDPG/SAC/TD3)
- You can pass ``optimizer_class`` and ``optimizer_kwargs`` to ``policy_kwargs`` in order to easily
customize optimizers

View file

@ -22,6 +22,7 @@ Bug Fixes:
Deprecations:
^^^^^^^^^^^^^
- Added deprecation warning if parameters ``eval_env``, ``eval_freq`` or ``create_eval_env`` are used (see #925) (@tobirohrer)
Others:
^^^^^^^
@ -1074,4 +1075,4 @@ And all the contributors:
@simoninithomas @armandpl @manuel-delverme @Gautam-J @gianlucadecola @buoyancy99 @caburu @xy9485
@Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor @TibiGG @cool-RR @MWeltevrede
@Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont @honglu2875
@anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde
@anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer

View file

@ -44,7 +44,9 @@ class A2C(OnPolicyAlgorithm):
:param normalize_advantage: Whether to normalize or not the advantage
:param tensorboard_log: the log location for tensorboard (if None, no logging)
:param create_eval_env: Whether to create a second environment that will be
used for evaluating the agent periodically. (Only available when passing string for the environment)
used for evaluating the agent periodically (Only available when passing string for the environment).
Caution, this parameter is deprecated and will be removed in the future.
Please use `EvalCallback` or a custom Callback instead.
:param policy_kwargs: additional arguments to be passed to the policy on creation
:param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
debug messages

View file

@ -3,6 +3,7 @@
import io
import pathlib
import time
import warnings
from abc import ABC, abstractmethod
from collections import deque
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union
@ -75,7 +76,9 @@ class BaseAlgorithm(ABC):
:param support_multi_env: Whether the algorithm supports training
with multiple environments (as in A2C)
:param create_eval_env: Whether to create a second environment that will be
used for evaluating the agent periodically. (Only available when passing string for the environment)
used for evaluating the agent periodically (Only available when passing string for the environment).
Caution, this parameter is deprecated and will be removed in the future.
Please use `EvalCallback` or a custom Callback instead.
:param monitor_wrapper: When creating an environment, whether to wrap it
or not in a Monitor wrapper.
:param seed: Seed for the pseudo random generators
@ -161,6 +164,15 @@ class BaseAlgorithm(ABC):
if env is not None:
if isinstance(env, str):
if create_eval_env:
warnings.warn(
"The parameter `create_eval_env` is deprecated and will be removed in the future. "
"Please use `EvalCallback` or a custom Callback instead.",
DeprecationWarning,
# By setting the `stacklevel` we refer to the initial caller of the deprecated feature.
# This causes the the `DepricationWarning` to not be ignored and to be shown to the user. See
# https://github.com/DLR-RM/stable-baselines3/pull/1082#discussion_r989842855 for more details.
stacklevel=4,
)
self.eval_env = maybe_make_env(env, self.verbose)
env = maybe_make_env(env, self.verbose)
@ -376,6 +388,8 @@ class BaseAlgorithm(ABC):
"""
:param callback: Callback(s) called at every step with state of the algorithm.
:param eval_freq: How many steps between evaluations; if None, do not evaluate.
Caution, this parameter is deprecated and will be removed in the future.
Please use `EvalCallback` or a custom Callback instead.
:param n_eval_episodes: How many episodes to play per evaluation
:param n_eval_episodes: Number of episodes to rollout during evaluation.
:param log_path: Path to a folder where the evaluations will be saved
@ -426,8 +440,12 @@ class BaseAlgorithm(ABC):
:param total_timesteps: The total number of samples (env steps) to train on
:param eval_env: Environment to use for evaluation.
Caution, this parameter is deprecated and will be removed in the future.
Please use `EvalCallback` or a custom Callback instead.
:param callback: Callback(s) called at every step with state of the algorithm.
:param eval_freq: How many steps between evaluations
Caution, this parameter is deprecated and will be removed in the future.
Please use `EvalCallback` or a custom Callback instead.
:param n_eval_episodes: How many episodes to play per evaluation
:param log_path: Path to a folder where the evaluations will be saved
:param reset_num_timesteps: Whether to reset or not the ``num_timesteps`` attribute
@ -435,6 +453,18 @@ class BaseAlgorithm(ABC):
:param progress_bar: Display a progress bar using tqdm and rich.
:return: Total timesteps and callback(s)
"""
if eval_env is not None or eval_freq != -1:
warnings.warn(
"Parameters `eval_env` and `eval_freq` are deprecated and will be removed in the future. "
"Please use `EvalCallback` or a custom Callback instead.",
DeprecationWarning,
# By setting the `stacklevel` we refer to the initial caller of the deprecated feature.
# This causes the the `DepricationWarning` to not be ignored and to be shown to the user. See
# https://github.com/DLR-RM/stable-baselines3/pull/1082#discussion_r989842855 for more details.
stacklevel=4,
)
self.start_time = time.time_ns()
if self.ep_info_buffer is None or reset_num_timesteps:
@ -567,8 +597,11 @@ class BaseAlgorithm(ABC):
:param callback: callback(s) called at every step with state of the algorithm.
:param log_interval: The number of timesteps before logging.
:param tb_log_name: the name of the run for TensorBoard logging
:param eval_env: Environment that will be used to evaluate the agent
:param eval_freq: Evaluate the agent every ``eval_freq`` timesteps (this may vary a little)
:param eval_env: Environment that will be used to evaluate the agent. Caution, this parameter
is deprecated and will be removed in the future. Please use ``EvalCallback`` instead.
:param eval_freq: Evaluate the agent every ``eval_freq`` timesteps (this may vary a little).
Caution, this parameter is deprecated and will be removed in the future.
Please use `EvalCallback` or a custom Callback instead.
:param n_eval_episodes: Number of episode to evaluate the agent
:param eval_log_path: Path to a folder where the evaluations will be saved
:param reset_num_timesteps: whether or not to reset the current timestep number (used in logging)

View file

@ -62,6 +62,8 @@ class OffPolicyAlgorithm(BaseAlgorithm):
with multiple environments (as in A2C)
:param create_eval_env: Whether to create a second environment that will be
used for evaluating the agent periodically. (Only available when passing string for the environment)
Caution, this parameter is deprecated and will be removed in the future.
Please use `EvalCallback` or a custom Callback instead.
:param monitor_wrapper: When creating an environment, whether to wrap it
or not in a Monitor wrapper.
:param seed: Seed for the pseudo random generators

View file

@ -40,6 +40,8 @@ class OnPolicyAlgorithm(BaseAlgorithm):
:param tensorboard_log: the log location for tensorboard (if None, no logging)
:param create_eval_env: Whether to create a second environment that will be
used for evaluating the agent periodically. (Only available when passing string for the environment)
Caution, this parameter is deprecated and will be removed in the future.
Please use `EvalCallback` or a custom Callback instead.
:param monitor_wrapper: When creating an environment, whether to wrap it
or not in a Monitor wrapper.
:param policy_kwargs: additional arguments to be passed to the policy on creation

View file

@ -45,7 +45,9 @@ class DDPG(TD3):
at a cost of more complexity.
See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
:param create_eval_env: Whether to create a second environment that will be
used for evaluating the agent periodically. (Only available when passing string for the environment)
used for evaluating the agent periodically (Only available when passing string for the environment).
Caution, this parameter is deprecated and will be removed in the future.
Please use `EvalCallback` or a custom Callback instead.
:param policy_kwargs: additional arguments to be passed to the policy on creation
:param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
debug messages

View file

@ -53,7 +53,9 @@ class DQN(OffPolicyAlgorithm):
:param max_grad_norm: The maximum value for the gradient clipping
:param tensorboard_log: the log location for tensorboard (if None, no logging)
:param create_eval_env: Whether to create a second environment that will be
used for evaluating the agent periodically. (Only available when passing string for the environment)
used for evaluating the agent periodically (Only available when passing string for the environment).
Caution, this parameter is deprecated and will be removed in the future.
Please use `EvalCallback` or a custom Callback instead.
:param policy_kwargs: additional arguments to be passed to the policy on creation
:param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
debug messages

View file

@ -58,7 +58,9 @@ class PPO(OnPolicyAlgorithm):
By default, there is no limit on the kl div.
:param tensorboard_log: the log location for tensorboard (if None, no logging)
:param create_eval_env: Whether to create a second environment that will be
used for evaluating the agent periodically. (Only available when passing string for the environment)
used for evaluating the agent periodically (Only available when passing string for the environment).
Caution, this parameter is deprecated and will be removed in the future.
Please use `EvalCallback` or a custom Callback instead.
:param policy_kwargs: additional arguments to be passed to the policy on creation
:param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
debug messages

View file

@ -66,7 +66,9 @@ class SAC(OffPolicyAlgorithm):
:param use_sde_at_warmup: Whether to use gSDE instead of uniform sampling
during the warm up phase (before learning starts)
:param create_eval_env: Whether to create a second environment that will be
used for evaluating the agent periodically. (Only available when passing string for the environment)
used for evaluating the agent periodically (Only available when passing string for the environment).
Caution, this parameter is deprecated and will be removed in the future.
Please use `EvalCallback` or a custom Callback instead.
:param policy_kwargs: additional arguments to be passed to the policy on creation
:param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
debug messages

View file

@ -54,7 +54,9 @@ class TD3(OffPolicyAlgorithm):
(smoothing noise)
:param target_noise_clip: Limit for absolute value of target policy smoothing noise.
:param create_eval_env: Whether to create a second environment that will be
used for evaluating the agent periodically. (Only available when passing string for the environment)
used for evaluating the agent periodically (Only available when passing string for the environment).
Caution, this parameter is deprecated and will be removed in the future.
Please use `EvalCallback` or a custom Callback instead.
:param policy_kwargs: additional arguments to be passed to the policy on creation
:param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
debug messages

View file

@ -47,7 +47,10 @@ def test_auto_wrap(model_class):
env = gym.make(env_name)
eval_env = gym.make(env_name)
model = model_class("MlpPolicy", env)
model.learn(100, eval_env=eval_env)
# Catch DeprecationWarnings
with pytest.warns(DeprecationWarning): # `eval_env` is deprecated
model.learn(100, eval_env=eval_env)
@pytest.mark.parametrize("model_class", MODEL_LIST)

View file

@ -18,23 +18,25 @@ def test_deterministic_pg(model_class, action_noise):
"""
Test for DDPG and variants (TD3).
"""
model = model_class(
"MlpPolicy",
"Pendulum-v1",
policy_kwargs=dict(net_arch=[64, 64]),
learning_starts=100,
verbose=1,
create_eval_env=True,
buffer_size=250,
action_noise=action_noise,
)
model.learn(total_timesteps=300, eval_freq=250)
with pytest.warns(DeprecationWarning): # `create_eval_env` and `eval_freq` are deprecated
model = model_class(
"MlpPolicy",
"Pendulum-v1",
policy_kwargs=dict(net_arch=[64, 64]),
learning_starts=100,
verbose=1,
create_eval_env=True,
buffer_size=250,
action_noise=action_noise,
)
model.learn(total_timesteps=300, eval_freq=250)
@pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v1"])
def test_a2c(env_id):
model = A2C("MlpPolicy", env_id, seed=0, policy_kwargs=dict(net_arch=[16]), verbose=1, create_eval_env=True)
model.learn(total_timesteps=1000, eval_freq=500)
with pytest.warns(DeprecationWarning): # `create_eval_env` and `eval_freq` are deprecated
model = A2C("MlpPolicy", env_id, seed=0, policy_kwargs=dict(net_arch=[16]), verbose=1, create_eval_env=True)
model.learn(total_timesteps=1000, eval_freq=500)
@pytest.mark.parametrize("model_class", [A2C, PPO])
@ -47,46 +49,48 @@ def test_advantage_normalization(model_class, normalize_advantage):
@pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v1"])
@pytest.mark.parametrize("clip_range_vf", [None, 0.2, -0.2])
def test_ppo(env_id, clip_range_vf):
if clip_range_vf is not None and clip_range_vf < 0:
# Should throw an error
with pytest.raises(AssertionError):
with pytest.warns(DeprecationWarning): # `create_eval_env` and `eval_freq` are deprecated
if clip_range_vf is not None and clip_range_vf < 0:
# Should throw an error
with pytest.raises(AssertionError):
model = PPO(
"MlpPolicy",
env_id,
seed=0,
policy_kwargs=dict(net_arch=[16]),
verbose=1,
create_eval_env=True,
clip_range_vf=clip_range_vf,
)
else:
model = PPO(
"MlpPolicy",
env_id,
n_steps=512,
seed=0,
policy_kwargs=dict(net_arch=[16]),
verbose=1,
create_eval_env=True,
clip_range_vf=clip_range_vf,
)
else:
model = PPO(
"MlpPolicy",
env_id,
n_steps=512,
seed=0,
policy_kwargs=dict(net_arch=[16]),
verbose=1,
create_eval_env=True,
clip_range_vf=clip_range_vf,
)
model.learn(total_timesteps=1000, eval_freq=500)
model.learn(total_timesteps=1000, eval_freq=500)
@pytest.mark.parametrize("ent_coef", ["auto", 0.01, "auto_0.01"])
def test_sac(ent_coef):
model = SAC(
"MlpPolicy",
"Pendulum-v1",
policy_kwargs=dict(net_arch=[64, 64]),
learning_starts=100,
verbose=1,
create_eval_env=True,
buffer_size=250,
ent_coef=ent_coef,
action_noise=NormalActionNoise(np.zeros(1), np.zeros(1)),
)
model.learn(total_timesteps=300, eval_freq=250)
with pytest.warns(DeprecationWarning): # `create_eval_env` and `eval_freq` are deprecated
model = SAC(
"MlpPolicy",
"Pendulum-v1",
policy_kwargs=dict(net_arch=[64, 64]),
learning_starts=100,
verbose=1,
create_eval_env=True,
buffer_size=250,
ent_coef=ent_coef,
action_noise=NormalActionNoise(np.zeros(1), np.zeros(1)),
)
model.learn(total_timesteps=300, eval_freq=250)
@pytest.mark.parametrize("n_critics", [1, 3])
@ -104,17 +108,18 @@ def test_n_critics(n_critics):
def test_dqn():
model = DQN(
"MlpPolicy",
"CartPole-v1",
policy_kwargs=dict(net_arch=[64, 64]),
learning_starts=100,
buffer_size=500,
learning_rate=3e-4,
verbose=1,
create_eval_env=True,
)
model.learn(total_timesteps=500, eval_freq=250)
with pytest.warns(DeprecationWarning): # `create_eval_env` and `eval_freq` are deprecated
model = DQN(
"MlpPolicy",
"CartPole-v1",
policy_kwargs=dict(net_arch=[64, 64]),
learning_starts=100,
buffer_size=500,
learning_rate=3e-4,
verbose=1,
create_eval_env=True,
)
model.learn(total_timesteps=500, eval_freq=250)
@pytest.mark.parametrize("train_freq", [4, (4, "step"), (1, "episode")])

View file

@ -63,17 +63,18 @@ def test_sde_check():
@pytest.mark.parametrize("use_expln", [False, True])
def test_state_dependent_noise(model_class, use_expln):
kwargs = {"learning_starts": 0} if model_class == SAC else {"n_steps": 64}
model = model_class(
"MlpPolicy",
"Pendulum-v1",
use_sde=True,
seed=None,
create_eval_env=True,
verbose=1,
policy_kwargs=dict(log_std_init=-2, use_expln=use_expln, net_arch=[64]),
**kwargs,
)
model.learn(total_timesteps=255, eval_freq=250)
with pytest.warns(DeprecationWarning): # `create_eval_env` and `eval_freq` are deprecated
model = model_class(
"MlpPolicy",
"Pendulum-v1",
use_sde=True,
seed=None,
create_eval_env=True,
verbose=1,
policy_kwargs=dict(log_std_init=-2, use_expln=use_expln, net_arch=[64]),
**kwargs,
)
model.learn(total_timesteps=255, eval_freq=250)
model.policy.reset_noise()
if model_class == SAC:
model.policy.actor.get_std()

View file

@ -380,8 +380,8 @@ def test_offpolicy_normalization(model_class, online_sampling):
assert model.get_vec_normalize_env() is eval_env
model.learn(total_timesteps=10)
model.set_env(env)
model.learn(total_timesteps=150, eval_env=eval_env, eval_freq=75)
with pytest.warns(DeprecationWarning): # `eval_env` and `eval_freq` are deprecated
model.learn(total_timesteps=150, eval_env=eval_env, eval_freq=75)
# Check getter
assert isinstance(model.get_vec_normalize_env(), VecNormalize)