mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-27 22:55:17 +00:00
Deprecate create_eval_env, eval_env and eval_freq parameter (#1082)
* Adds deprecation warning if `eval_env` or `eval_freq` parameters are used. See #925 * added changelog entry * added missing backtick * deprecating `create_eval_env` parameter as well and adding comments to explain the `stacklevel` parameter used * Updated tests to ignore DeprecationWarnings * Updated changelog entry * - Removed the `create_eval_env` parameter from the examples in the docs - Removed information about the `create_eval_env` parameter from the migration docs - Added information about deprecation of the `create_eval_env` parameter in the docs * Add alternative in docstring * Update docstrings * `eval_freq` warning in docstring * Add deprecation comments in tests Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org> Co-authored-by: Quentin GALLOUÉDEC <gallouedec.quentin@gmail.com>
This commit is contained in:
parent
7c21b79188
commit
d8a430e088
16 changed files with 139 additions and 91 deletions
|
|
@ -532,19 +532,12 @@ linear and constant schedules.
|
|||
Advanced Saving and Loading
|
||||
---------------------------------
|
||||
|
||||
In this example, we show how to use some advanced features of Stable-Baselines3 (SB3):
|
||||
how to easily create a test environment to evaluate an agent periodically,
|
||||
use a policy independently from a model (and how to save it, load it) and save/load a replay buffer.
|
||||
In this example, we show how to use a policy independently from a model (and how to save it, load it) and save/load a replay buffer.
|
||||
|
||||
By default, the replay buffer is not saved when calling ``model.save()``, in order to save space on the disk (a replay buffer can be up to several GB when using images).
|
||||
However, SB3 provides a ``save_replay_buffer()`` and ``load_replay_buffer()`` method to save it separately.
|
||||
|
||||
|
||||
Stable-Baselines3 automatic creation of an environment for evaluation.
|
||||
For that, you only need to specify ``create_eval_env=True`` when passing the Gym ID of the environment while creating the agent.
|
||||
Behind the scene, SB3 uses an :ref:`EvalCallback <callbacks>`.
|
||||
|
||||
|
||||
.. note::
|
||||
|
||||
To continue training a model after loading it, we recommend loading the replay buffer to ensure stable learning (for off-policy algorithms).
|
||||
|
|
@ -562,14 +555,12 @@ Behind the scene, SB3 uses an :ref:`EvalCallback <callbacks>`.
|
|||
from stable_baselines3.common.evaluation import evaluate_policy
|
||||
from stable_baselines3.sac.policies import MlpPolicy
|
||||
|
||||
# Create the model, the training environment
|
||||
# and the test environment (for evaluation)
|
||||
# Create the model and the training environment
|
||||
model = SAC("MlpPolicy", "Pendulum-v1", verbose=1,
|
||||
learning_rate=1e-3, create_eval_env=True)
|
||||
learning_rate=1e-3)
|
||||
|
||||
# Evaluate the model every 1000 steps on 5 test episodes
|
||||
# and save the evaluation to the "logs/" folder
|
||||
model.learn(6000, eval_freq=1000, n_eval_episodes=5, eval_log_path="./logs/")
|
||||
# train the model
|
||||
model.learn(total_timesteps=6000)
|
||||
|
||||
# save the model
|
||||
model.save("sac_pendulum")
|
||||
|
|
|
|||
|
|
@ -206,8 +206,6 @@ New Features (SB3 vs SB2)
|
|||
- Independent saving/loading/predict for policies
|
||||
- A2C now supports Generalized Advantage Estimation (GAE) and advantage normalization (both are deactivated by default)
|
||||
- Generalized State-Dependent Exploration (gSDE) exploration is available for A2C/PPO/SAC. It allows using RL directly on real robots (cf https://arxiv.org/abs/2005.05719)
|
||||
- Proper evaluation (using separate env) is included in the base class (using ``EvalCallback``),
|
||||
if you pass the environment as a string, you can pass ``create_eval_env=True`` to the algorithm constructor.
|
||||
- Better saving/loading: optimizers are now included in the saved parameters and there are two new methods ``save_replay_buffer`` and ``load_replay_buffer`` for the replay buffer when using off-policy algorithms (DQN/DDPG/SAC/TD3)
|
||||
- You can pass ``optimizer_class`` and ``optimizer_kwargs`` to ``policy_kwargs`` in order to easily
|
||||
customize optimizers
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ Bug Fixes:
|
|||
|
||||
Deprecations:
|
||||
^^^^^^^^^^^^^
|
||||
- Added deprecation warning if parameters ``eval_env``, ``eval_freq`` or ``create_eval_env`` are used (see #925) (@tobirohrer)
|
||||
|
||||
Others:
|
||||
^^^^^^^
|
||||
|
|
@ -1074,4 +1075,4 @@ And all the contributors:
|
|||
@simoninithomas @armandpl @manuel-delverme @Gautam-J @gianlucadecola @buoyancy99 @caburu @xy9485
|
||||
@Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor @TibiGG @cool-RR @MWeltevrede
|
||||
@Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont @honglu2875
|
||||
@anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde
|
||||
@anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer
|
||||
|
|
|
|||
|
|
@ -44,7 +44,9 @@ class A2C(OnPolicyAlgorithm):
|
|||
:param normalize_advantage: Whether to normalize or not the advantage
|
||||
:param tensorboard_log: the log location for tensorboard (if None, no logging)
|
||||
:param create_eval_env: Whether to create a second environment that will be
|
||||
used for evaluating the agent periodically. (Only available when passing string for the environment)
|
||||
used for evaluating the agent periodically (Only available when passing string for the environment).
|
||||
Caution, this parameter is deprecated and will be removed in the future.
|
||||
Please use `EvalCallback` or a custom Callback instead.
|
||||
:param policy_kwargs: additional arguments to be passed to the policy on creation
|
||||
:param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
|
||||
debug messages
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
import io
|
||||
import pathlib
|
||||
import time
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import deque
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union
|
||||
|
|
@ -75,7 +76,9 @@ class BaseAlgorithm(ABC):
|
|||
:param support_multi_env: Whether the algorithm supports training
|
||||
with multiple environments (as in A2C)
|
||||
:param create_eval_env: Whether to create a second environment that will be
|
||||
used for evaluating the agent periodically. (Only available when passing string for the environment)
|
||||
used for evaluating the agent periodically (Only available when passing string for the environment).
|
||||
Caution, this parameter is deprecated and will be removed in the future.
|
||||
Please use `EvalCallback` or a custom Callback instead.
|
||||
:param monitor_wrapper: When creating an environment, whether to wrap it
|
||||
or not in a Monitor wrapper.
|
||||
:param seed: Seed for the pseudo random generators
|
||||
|
|
@ -161,6 +164,15 @@ class BaseAlgorithm(ABC):
|
|||
if env is not None:
|
||||
if isinstance(env, str):
|
||||
if create_eval_env:
|
||||
warnings.warn(
|
||||
"The parameter `create_eval_env` is deprecated and will be removed in the future. "
|
||||
"Please use `EvalCallback` or a custom Callback instead.",
|
||||
DeprecationWarning,
|
||||
# By setting the `stacklevel` we refer to the initial caller of the deprecated feature.
|
||||
# This causes the `DeprecationWarning` to not be ignored and to be shown to the user. See
|
||||
# https://github.com/DLR-RM/stable-baselines3/pull/1082#discussion_r989842855 for more details.
|
||||
stacklevel=4,
|
||||
)
|
||||
self.eval_env = maybe_make_env(env, self.verbose)
|
||||
|
||||
env = maybe_make_env(env, self.verbose)
|
||||
|
|
@ -376,6 +388,8 @@ class BaseAlgorithm(ABC):
|
|||
"""
|
||||
:param callback: Callback(s) called at every step with state of the algorithm.
|
||||
:param eval_freq: How many steps between evaluations; if None, do not evaluate.
|
||||
Caution, this parameter is deprecated and will be removed in the future.
|
||||
Please use `EvalCallback` or a custom Callback instead.
|
||||
:param n_eval_episodes: How many episodes to play per evaluation
|
||||
:param n_eval_episodes: Number of episodes to rollout during evaluation.
|
||||
:param log_path: Path to a folder where the evaluations will be saved
|
||||
|
|
@ -426,8 +440,12 @@ class BaseAlgorithm(ABC):
|
|||
|
||||
:param total_timesteps: The total number of samples (env steps) to train on
|
||||
:param eval_env: Environment to use for evaluation.
|
||||
Caution, this parameter is deprecated and will be removed in the future.
|
||||
Please use `EvalCallback` or a custom Callback instead.
|
||||
:param callback: Callback(s) called at every step with state of the algorithm.
|
||||
:param eval_freq: How many steps between evaluations
|
||||
Caution, this parameter is deprecated and will be removed in the future.
|
||||
Please use `EvalCallback` or a custom Callback instead.
|
||||
:param n_eval_episodes: How many episodes to play per evaluation
|
||||
:param log_path: Path to a folder where the evaluations will be saved
|
||||
:param reset_num_timesteps: Whether to reset or not the ``num_timesteps`` attribute
|
||||
|
|
@ -435,6 +453,18 @@ class BaseAlgorithm(ABC):
|
|||
:param progress_bar: Display a progress bar using tqdm and rich.
|
||||
:return: Total timesteps and callback(s)
|
||||
"""
|
||||
|
||||
if eval_env is not None or eval_freq != -1:
|
||||
warnings.warn(
|
||||
"Parameters `eval_env` and `eval_freq` are deprecated and will be removed in the future. "
|
||||
"Please use `EvalCallback` or a custom Callback instead.",
|
||||
DeprecationWarning,
|
||||
# By setting the `stacklevel` we refer to the initial caller of the deprecated feature.
|
||||
# This causes the `DeprecationWarning` to not be ignored and to be shown to the user. See
|
||||
# https://github.com/DLR-RM/stable-baselines3/pull/1082#discussion_r989842855 for more details.
|
||||
stacklevel=4,
|
||||
)
|
||||
|
||||
self.start_time = time.time_ns()
|
||||
|
||||
if self.ep_info_buffer is None or reset_num_timesteps:
|
||||
|
|
@ -567,8 +597,11 @@ class BaseAlgorithm(ABC):
|
|||
:param callback: callback(s) called at every step with state of the algorithm.
|
||||
:param log_interval: The number of timesteps before logging.
|
||||
:param tb_log_name: the name of the run for TensorBoard logging
|
||||
:param eval_env: Environment that will be used to evaluate the agent
|
||||
:param eval_freq: Evaluate the agent every ``eval_freq`` timesteps (this may vary a little)
|
||||
:param eval_env: Environment that will be used to evaluate the agent. Caution, this parameter
|
||||
is deprecated and will be removed in the future. Please use ``EvalCallback`` instead.
|
||||
:param eval_freq: Evaluate the agent every ``eval_freq`` timesteps (this may vary a little).
|
||||
Caution, this parameter is deprecated and will be removed in the future.
|
||||
Please use `EvalCallback` or a custom Callback instead.
|
||||
:param n_eval_episodes: Number of episodes used to evaluate the agent
|
||||
:param eval_log_path: Path to a folder where the evaluations will be saved
|
||||
:param reset_num_timesteps: whether or not to reset the current timestep number (used in logging)
|
||||
|
|
|
|||
|
|
@ -62,6 +62,8 @@ class OffPolicyAlgorithm(BaseAlgorithm):
|
|||
with multiple environments (as in A2C)
|
||||
:param create_eval_env: Whether to create a second environment that will be
|
||||
used for evaluating the agent periodically. (Only available when passing string for the environment)
|
||||
Caution, this parameter is deprecated and will be removed in the future.
|
||||
Please use `EvalCallback` or a custom Callback instead.
|
||||
:param monitor_wrapper: When creating an environment, whether to wrap it
|
||||
or not in a Monitor wrapper.
|
||||
:param seed: Seed for the pseudo random generators
|
||||
|
|
|
|||
|
|
@ -40,6 +40,8 @@ class OnPolicyAlgorithm(BaseAlgorithm):
|
|||
:param tensorboard_log: the log location for tensorboard (if None, no logging)
|
||||
:param create_eval_env: Whether to create a second environment that will be
|
||||
used for evaluating the agent periodically. (Only available when passing string for the environment)
|
||||
Caution, this parameter is deprecated and will be removed in the future.
|
||||
Please use `EvalCallback` or a custom Callback instead.
|
||||
:param monitor_wrapper: When creating an environment, whether to wrap it
|
||||
or not in a Monitor wrapper.
|
||||
:param policy_kwargs: additional arguments to be passed to the policy on creation
|
||||
|
|
|
|||
|
|
@ -45,7 +45,9 @@ class DDPG(TD3):
|
|||
at a cost of more complexity.
|
||||
See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
|
||||
:param create_eval_env: Whether to create a second environment that will be
|
||||
used for evaluating the agent periodically. (Only available when passing string for the environment)
|
||||
used for evaluating the agent periodically (Only available when passing string for the environment).
|
||||
Caution, this parameter is deprecated and will be removed in the future.
|
||||
Please use `EvalCallback` or a custom Callback instead.
|
||||
:param policy_kwargs: additional arguments to be passed to the policy on creation
|
||||
:param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
|
||||
debug messages
|
||||
|
|
|
|||
|
|
@ -53,7 +53,9 @@ class DQN(OffPolicyAlgorithm):
|
|||
:param max_grad_norm: The maximum value for the gradient clipping
|
||||
:param tensorboard_log: the log location for tensorboard (if None, no logging)
|
||||
:param create_eval_env: Whether to create a second environment that will be
|
||||
used for evaluating the agent periodically. (Only available when passing string for the environment)
|
||||
used for evaluating the agent periodically (Only available when passing string for the environment).
|
||||
Caution, this parameter is deprecated and will be removed in the future.
|
||||
Please use `EvalCallback` or a custom Callback instead.
|
||||
:param policy_kwargs: additional arguments to be passed to the policy on creation
|
||||
:param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
|
||||
debug messages
|
||||
|
|
|
|||
|
|
@ -58,7 +58,9 @@ class PPO(OnPolicyAlgorithm):
|
|||
By default, there is no limit on the kl div.
|
||||
:param tensorboard_log: the log location for tensorboard (if None, no logging)
|
||||
:param create_eval_env: Whether to create a second environment that will be
|
||||
used for evaluating the agent periodically. (Only available when passing string for the environment)
|
||||
used for evaluating the agent periodically (Only available when passing string for the environment).
|
||||
Caution, this parameter is deprecated and will be removed in the future.
|
||||
Please use `EvalCallback` or a custom Callback instead.
|
||||
:param policy_kwargs: additional arguments to be passed to the policy on creation
|
||||
:param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
|
||||
debug messages
|
||||
|
|
|
|||
|
|
@ -66,7 +66,9 @@ class SAC(OffPolicyAlgorithm):
|
|||
:param use_sde_at_warmup: Whether to use gSDE instead of uniform sampling
|
||||
during the warm up phase (before learning starts)
|
||||
:param create_eval_env: Whether to create a second environment that will be
|
||||
used for evaluating the agent periodically. (Only available when passing string for the environment)
|
||||
used for evaluating the agent periodically (Only available when passing string for the environment).
|
||||
Caution, this parameter is deprecated and will be removed in the future.
|
||||
Please use `EvalCallback` or a custom Callback instead.
|
||||
:param policy_kwargs: additional arguments to be passed to the policy on creation
|
||||
:param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
|
||||
debug messages
|
||||
|
|
|
|||
|
|
@ -54,7 +54,9 @@ class TD3(OffPolicyAlgorithm):
|
|||
(smoothing noise)
|
||||
:param target_noise_clip: Limit for absolute value of target policy smoothing noise.
|
||||
:param create_eval_env: Whether to create a second environment that will be
|
||||
used for evaluating the agent periodically. (Only available when passing string for the environment)
|
||||
used for evaluating the agent periodically (Only available when passing string for the environment).
|
||||
Caution, this parameter is deprecated and will be removed in the future.
|
||||
Please use `EvalCallback` or a custom Callback instead.
|
||||
:param policy_kwargs: additional arguments to be passed to the policy on creation
|
||||
:param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
|
||||
debug messages
|
||||
|
|
|
|||
|
|
@ -47,7 +47,10 @@ def test_auto_wrap(model_class):
|
|||
env = gym.make(env_name)
|
||||
eval_env = gym.make(env_name)
|
||||
model = model_class("MlpPolicy", env)
|
||||
model.learn(100, eval_env=eval_env)
|
||||
|
||||
# Catch DeprecationWarnings
|
||||
with pytest.warns(DeprecationWarning): # `eval_env` is deprecated
|
||||
model.learn(100, eval_env=eval_env)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", MODEL_LIST)
|
||||
|
|
|
|||
|
|
@ -18,23 +18,25 @@ def test_deterministic_pg(model_class, action_noise):
|
|||
"""
|
||||
Test for DDPG and variants (TD3).
|
||||
"""
|
||||
model = model_class(
|
||||
"MlpPolicy",
|
||||
"Pendulum-v1",
|
||||
policy_kwargs=dict(net_arch=[64, 64]),
|
||||
learning_starts=100,
|
||||
verbose=1,
|
||||
create_eval_env=True,
|
||||
buffer_size=250,
|
||||
action_noise=action_noise,
|
||||
)
|
||||
model.learn(total_timesteps=300, eval_freq=250)
|
||||
with pytest.warns(DeprecationWarning): # `create_eval_env` and `eval_freq` are deprecated
|
||||
model = model_class(
|
||||
"MlpPolicy",
|
||||
"Pendulum-v1",
|
||||
policy_kwargs=dict(net_arch=[64, 64]),
|
||||
learning_starts=100,
|
||||
verbose=1,
|
||||
create_eval_env=True,
|
||||
buffer_size=250,
|
||||
action_noise=action_noise,
|
||||
)
|
||||
model.learn(total_timesteps=300, eval_freq=250)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v1"])
|
||||
def test_a2c(env_id):
|
||||
model = A2C("MlpPolicy", env_id, seed=0, policy_kwargs=dict(net_arch=[16]), verbose=1, create_eval_env=True)
|
||||
model.learn(total_timesteps=1000, eval_freq=500)
|
||||
with pytest.warns(DeprecationWarning): # `create_eval_env` and `eval_freq` are deprecated
|
||||
model = A2C("MlpPolicy", env_id, seed=0, policy_kwargs=dict(net_arch=[16]), verbose=1, create_eval_env=True)
|
||||
model.learn(total_timesteps=1000, eval_freq=500)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", [A2C, PPO])
|
||||
|
|
@ -47,46 +49,48 @@ def test_advantage_normalization(model_class, normalize_advantage):
|
|||
@pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v1"])
|
||||
@pytest.mark.parametrize("clip_range_vf", [None, 0.2, -0.2])
|
||||
def test_ppo(env_id, clip_range_vf):
|
||||
if clip_range_vf is not None and clip_range_vf < 0:
|
||||
# Should throw an error
|
||||
with pytest.raises(AssertionError):
|
||||
with pytest.warns(DeprecationWarning): # `create_eval_env` and `eval_freq` are deprecated
|
||||
if clip_range_vf is not None and clip_range_vf < 0:
|
||||
# Should throw an error
|
||||
with pytest.raises(AssertionError):
|
||||
model = PPO(
|
||||
"MlpPolicy",
|
||||
env_id,
|
||||
seed=0,
|
||||
policy_kwargs=dict(net_arch=[16]),
|
||||
verbose=1,
|
||||
create_eval_env=True,
|
||||
clip_range_vf=clip_range_vf,
|
||||
)
|
||||
else:
|
||||
model = PPO(
|
||||
"MlpPolicy",
|
||||
env_id,
|
||||
n_steps=512,
|
||||
seed=0,
|
||||
policy_kwargs=dict(net_arch=[16]),
|
||||
verbose=1,
|
||||
create_eval_env=True,
|
||||
clip_range_vf=clip_range_vf,
|
||||
)
|
||||
else:
|
||||
model = PPO(
|
||||
"MlpPolicy",
|
||||
env_id,
|
||||
n_steps=512,
|
||||
seed=0,
|
||||
policy_kwargs=dict(net_arch=[16]),
|
||||
verbose=1,
|
||||
create_eval_env=True,
|
||||
clip_range_vf=clip_range_vf,
|
||||
)
|
||||
model.learn(total_timesteps=1000, eval_freq=500)
|
||||
model.learn(total_timesteps=1000, eval_freq=500)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("ent_coef", ["auto", 0.01, "auto_0.01"])
|
||||
def test_sac(ent_coef):
|
||||
model = SAC(
|
||||
"MlpPolicy",
|
||||
"Pendulum-v1",
|
||||
policy_kwargs=dict(net_arch=[64, 64]),
|
||||
learning_starts=100,
|
||||
verbose=1,
|
||||
create_eval_env=True,
|
||||
buffer_size=250,
|
||||
ent_coef=ent_coef,
|
||||
action_noise=NormalActionNoise(np.zeros(1), np.zeros(1)),
|
||||
)
|
||||
model.learn(total_timesteps=300, eval_freq=250)
|
||||
with pytest.warns(DeprecationWarning): # `create_eval_env` and `eval_freq` are deprecated
|
||||
model = SAC(
|
||||
"MlpPolicy",
|
||||
"Pendulum-v1",
|
||||
policy_kwargs=dict(net_arch=[64, 64]),
|
||||
learning_starts=100,
|
||||
verbose=1,
|
||||
create_eval_env=True,
|
||||
buffer_size=250,
|
||||
ent_coef=ent_coef,
|
||||
action_noise=NormalActionNoise(np.zeros(1), np.zeros(1)),
|
||||
)
|
||||
model.learn(total_timesteps=300, eval_freq=250)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("n_critics", [1, 3])
|
||||
|
|
@ -104,17 +108,18 @@ def test_n_critics(n_critics):
|
|||
|
||||
|
||||
def test_dqn():
|
||||
model = DQN(
|
||||
"MlpPolicy",
|
||||
"CartPole-v1",
|
||||
policy_kwargs=dict(net_arch=[64, 64]),
|
||||
learning_starts=100,
|
||||
buffer_size=500,
|
||||
learning_rate=3e-4,
|
||||
verbose=1,
|
||||
create_eval_env=True,
|
||||
)
|
||||
model.learn(total_timesteps=500, eval_freq=250)
|
||||
with pytest.warns(DeprecationWarning): # `create_eval_env` and `eval_freq` are deprecated
|
||||
model = DQN(
|
||||
"MlpPolicy",
|
||||
"CartPole-v1",
|
||||
policy_kwargs=dict(net_arch=[64, 64]),
|
||||
learning_starts=100,
|
||||
buffer_size=500,
|
||||
learning_rate=3e-4,
|
||||
verbose=1,
|
||||
create_eval_env=True,
|
||||
)
|
||||
model.learn(total_timesteps=500, eval_freq=250)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("train_freq", [4, (4, "step"), (1, "episode")])
|
||||
|
|
|
|||
|
|
@ -63,17 +63,18 @@ def test_sde_check():
|
|||
@pytest.mark.parametrize("use_expln", [False, True])
|
||||
def test_state_dependent_noise(model_class, use_expln):
|
||||
kwargs = {"learning_starts": 0} if model_class == SAC else {"n_steps": 64}
|
||||
model = model_class(
|
||||
"MlpPolicy",
|
||||
"Pendulum-v1",
|
||||
use_sde=True,
|
||||
seed=None,
|
||||
create_eval_env=True,
|
||||
verbose=1,
|
||||
policy_kwargs=dict(log_std_init=-2, use_expln=use_expln, net_arch=[64]),
|
||||
**kwargs,
|
||||
)
|
||||
model.learn(total_timesteps=255, eval_freq=250)
|
||||
with pytest.warns(DeprecationWarning): # `create_eval_env` and `eval_freq` are deprecated
|
||||
model = model_class(
|
||||
"MlpPolicy",
|
||||
"Pendulum-v1",
|
||||
use_sde=True,
|
||||
seed=None,
|
||||
create_eval_env=True,
|
||||
verbose=1,
|
||||
policy_kwargs=dict(log_std_init=-2, use_expln=use_expln, net_arch=[64]),
|
||||
**kwargs,
|
||||
)
|
||||
model.learn(total_timesteps=255, eval_freq=250)
|
||||
model.policy.reset_noise()
|
||||
if model_class == SAC:
|
||||
model.policy.actor.get_std()
|
||||
|
|
|
|||
|
|
@ -380,8 +380,8 @@ def test_offpolicy_normalization(model_class, online_sampling):
|
|||
assert model.get_vec_normalize_env() is eval_env
|
||||
model.learn(total_timesteps=10)
|
||||
model.set_env(env)
|
||||
|
||||
model.learn(total_timesteps=150, eval_env=eval_env, eval_freq=75)
|
||||
with pytest.warns(DeprecationWarning): # `eval_env` and `eval_freq` are deprecated
|
||||
model.learn(total_timesteps=150, eval_env=eval_env, eval_freq=75)
|
||||
# Check getter
|
||||
assert isinstance(model.get_vec_normalize_env(), VecNormalize)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue