mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-07-03 03:59:13 +00:00
Use Monitor episode reward/length for evaluate_policy (#220)
* Update evaluate_policy to use monitor data if available * Update documentation * Cleaning up * Remove unnecessary typing trickery * Update doc * Rename is_wrapped to clarify it is for vecenvs * Add is_wrapped for regular envs * Add is_wrapped call for subprocvecenv and update code for circular imports * Move new functions back to env_util and fix imports * Update changelog * Clarify evaluate_policy docs * Add tests for wrapped modifying episode lengths * Fix tests * Update changelog * Minor edits * Add warn switch to evaluate_policy and update tests Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
This commit is contained in:
parent
c74509ae9d
commit
18d10dbf42
19 changed files with 305 additions and 34 deletions
|
|
@ -79,6 +79,9 @@ In the following example, we will train, save and load a DQN model on the Lunar
|
|||
model = DQN.load("dqn_lunar")
|
||||
|
||||
# Evaluate the agent
|
||||
# NOTE: If you use wrappers with your environment that modify rewards,
|
||||
# this will be reflected here. To evaluate with original rewards,
|
||||
# wrap environment in a "Monitor" wrapper before other wrappers.
|
||||
mean_reward, std_reward = evaluate_policy(model, model.get_env(), n_eval_episodes=10)
|
||||
|
||||
# Enjoy trained agent
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ TL;DR
|
|||
|
||||
1. Read about RL and Stable Baselines3
|
||||
2. Do quantitative experiments and hyperparameter tuning if needed
|
||||
3. Evaluate the performance using a separate test environment
|
||||
3. Evaluate the performance using a separate test environment (remember to check wrappers!)
|
||||
4. For better performance, increase the training budget
|
||||
|
||||
|
||||
|
|
@ -68,18 +68,24 @@ Other method, like ``TRPO`` or ``PPO`` make use of a *trust region* to minimize
|
|||
How to evaluate an RL algorithm?
|
||||
--------------------------------
|
||||
|
||||
.. note::
|
||||
|
||||
Pay attention to environment wrappers when evaluating your agent and comparing results to others' results. Modifications to episode rewards
|
||||
or lengths may also affect evaluation results which may not be desirable. Check ``evaluate_policy`` helper function in :ref:`Evaluation Helper <eval>` section.
|
||||
|
||||
Because most algorithms use exploration noise during training, you need a separate test environment to evaluate the performance
|
||||
of your agent at a given time. It is recommended to periodically evaluate your agent for ``n`` test episodes (``n`` is usually between 5 and 20)
|
||||
and average the reward per episode to have a good estimate.
|
||||
|
||||
.. note::
|
||||
|
||||
We provide an ``EvalCallback`` for doing such evaluation. You can read more about it in the :ref:`Callbacks <callbacks>` section.
|
||||
|
||||
As some policy are stochastic by default (e.g. A2C or PPO), you should also try to set `deterministic=True` when calling the `.predict()` method,
|
||||
this frequently leads to better performance.
|
||||
Looking at the training curve (episode reward function of the timesteps) is a good proxy but underestimates the agent true performance.
|
||||
|
||||
|
||||
.. note::
|
||||
|
||||
We provide an ``EvalCallback`` for doing such evaluation. You can read more about it in the :ref:`Callbacks <callbacks>` section.
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -8,7 +8,10 @@ Pre-Release 0.11.0a0 (WIP)
|
|||
|
||||
Breaking Changes:
|
||||
^^^^^^^^^^^^^^^^^
|
||||
|
||||
- ``evaluate_policy`` now returns rewards/episode lengths from a ``Monitor`` wrapper if one is present,
|
||||
this allows to return the unnormalized reward in the case of Atari games for instance.
|
||||
- Renamed ``common.vec_env.is_wrapped`` to ``common.vec_env.is_vecenv_wrapped`` to avoid confusion
|
||||
with the new ``is_wrapped()`` helper
|
||||
|
||||
New Features:
|
||||
^^^^^^^^^^^^^
|
||||
|
|
@ -16,6 +19,10 @@ New Features:
|
|||
automatic check for image spaces.
|
||||
- ``VecFrameStack`` now has a ``channels_order`` argument to tell if observations should be stacked
|
||||
on the first or last observation dimension (originally always stacked on last).
|
||||
- Added ``common.env_util.is_wrapped`` and ``common.env_util.unwrap_wrapper`` functions for checking/unwrapping
|
||||
an environment for specific wrapper.
|
||||
- Added ``env_is_wrapped()`` method for ``VecEnv`` to check if its environments are wrapped
|
||||
with given Gym wrappers.
|
||||
|
||||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ from stable_baselines3.common.vec_env import (
|
|||
VecEnv,
|
||||
VecNormalize,
|
||||
VecTransposeImage,
|
||||
is_wrapped,
|
||||
is_vecenv_wrapped,
|
||||
unwrap_vec_normalize,
|
||||
)
|
||||
from stable_baselines3.common.vec_env.obs_dict_wrapper import ObsDictWrapper
|
||||
|
|
@ -178,7 +178,7 @@ class BaseAlgorithm(ABC):
|
|||
|
||||
if (
|
||||
is_image_space(env.observation_space)
|
||||
and not is_wrapped(env, VecTransposeImage)
|
||||
and not is_vecenv_wrapped(env, VecTransposeImage)
|
||||
and not is_image_space_channels_first(env.observation_space)
|
||||
):
|
||||
if verbose >= 1:
|
||||
|
|
|
|||
|
|
@ -276,6 +276,8 @@ class EvalCallback(EventCallback):
|
|||
:param deterministic: Whether to render or not the environment during evaluation
|
||||
:param render: Whether to render or not the environment during evaluation
|
||||
:param verbose:
|
||||
:param warn: Passed to ``evaluate_policy`` (warns if ``eval_env`` has not been
|
||||
wrapped with a Monitor wrapper)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
|
@ -289,6 +291,7 @@ class EvalCallback(EventCallback):
|
|||
deterministic: bool = True,
|
||||
render: bool = False,
|
||||
verbose: int = 1,
|
||||
warn: bool = True,
|
||||
):
|
||||
super(EvalCallback, self).__init__(callback_on_new_best, verbose=verbose)
|
||||
self.n_eval_episodes = n_eval_episodes
|
||||
|
|
@ -297,6 +300,7 @@ class EvalCallback(EventCallback):
|
|||
self.last_mean_reward = -np.inf
|
||||
self.deterministic = deterministic
|
||||
self.render = render
|
||||
self.warn = warn
|
||||
|
||||
# Convert to VecEnv for consistency
|
||||
if not isinstance(eval_env, VecEnv):
|
||||
|
|
@ -339,6 +343,7 @@ class EvalCallback(EventCallback):
|
|||
render=self.render,
|
||||
deterministic=self.deterministic,
|
||||
return_episode_rewards=True,
|
||||
warn=self.warn,
|
||||
)
|
||||
|
||||
if self.log_path is not None:
|
||||
|
|
|
|||
|
|
@ -8,6 +8,33 @@ from stable_baselines3.common.monitor import Monitor
|
|||
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecEnv
|
||||
|
||||
|
||||
def unwrap_wrapper(env: gym.Env, wrapper_class: Type[gym.Wrapper]) -> Optional[gym.Wrapper]:
|
||||
"""
|
||||
Retrieve a ``VecEnvWrapper`` object by recursively searching.
|
||||
|
||||
:param env: Environment to unwrap
|
||||
:param wrapper_class: Wrapper to look for
|
||||
:return: Environment unwrapped till ``wrapper_class`` if it has been wrapped with it
|
||||
"""
|
||||
env_tmp = env
|
||||
while isinstance(env_tmp, gym.Wrapper):
|
||||
if isinstance(env_tmp, wrapper_class):
|
||||
return env_tmp
|
||||
env_tmp = env_tmp.env
|
||||
return None
|
||||
|
||||
|
||||
def is_wrapped(env: Type[gym.Env], wrapper_class: Type[gym.Wrapper]) -> bool:
|
||||
"""
|
||||
Check if a given environment has been wrapped with a given wrapper.
|
||||
|
||||
:param env: Environment to check
|
||||
:param wrapper_class: Wrapper class to look for
|
||||
:return: True if environment has been wrapped with ``wrapper_class``.
|
||||
"""
|
||||
return unwrap_wrapper(env, wrapper_class) is not None
|
||||
|
||||
|
||||
def make_vec_env(
|
||||
env_id: Union[str, Type[gym.Env]],
|
||||
n_envs: int = 1,
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import warnings
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import gym
|
||||
|
|
@ -16,11 +17,20 @@ def evaluate_policy(
|
|||
callback: Optional[Callable[[Dict[str, Any], Dict[str, Any]], None]] = None,
|
||||
reward_threshold: Optional[float] = None,
|
||||
return_episode_rewards: bool = False,
|
||||
warn: bool = True,
|
||||
) -> Union[Tuple[float, float], Tuple[List[float], List[int]]]:
|
||||
"""
|
||||
Runs policy for ``n_eval_episodes`` episodes and returns average reward.
|
||||
This is made to work only with one env.
|
||||
|
||||
.. note::
|
||||
If environment has not been wrapped with ``Monitor`` wrapper, reward and
|
||||
episode lengths are counted as it appears with ``env.step`` calls. If
|
||||
the environment contains wrappers that modify rewards or episode lengths
|
||||
(e.g. reward scaling, early episode reset), these will affect the evaluation
|
||||
results as well. You can avoid this by wrapping environment with ``Monitor``
|
||||
wrapper before anything else.
|
||||
|
||||
:param model: The RL agent you want to evaluate.
|
||||
:param env: The gym environment. In the case of a ``VecEnv``
|
||||
this must contain only one environment.
|
||||
|
|
@ -31,33 +41,70 @@ def evaluate_policy(
|
|||
called after each step. Gets locals() and globals() passed as parameters.
|
||||
:param reward_threshold: Minimum expected reward per episode,
|
||||
this will raise an error if the performance is not met
|
||||
:param return_episode_rewards: If True, a list of reward per episode
|
||||
will be returned instead of the mean.
|
||||
:return: Mean reward per episode, std of reward per episode
|
||||
returns ([float], [int]) when ``return_episode_rewards`` is True
|
||||
:param return_episode_rewards: If True, a list of rewards and episde lengths
|
||||
per episode will be returned instead of the mean.
|
||||
:param warn: If True (default), warns user about lack of a Monitor wrapper in the
|
||||
evaluation environment.
|
||||
:return: Mean reward per episode, std of reward per episode.
|
||||
Returns ([float], [int]) when ``return_episode_rewards`` is True, first
|
||||
list containing per-episode rewards and second containing per-episode lengths
|
||||
(in number of steps).
|
||||
"""
|
||||
is_monitor_wrapped = False
|
||||
# Avoid circular import
|
||||
from stable_baselines3.common.env_util import is_wrapped
|
||||
from stable_baselines3.common.monitor import Monitor
|
||||
|
||||
if isinstance(env, VecEnv):
|
||||
assert env.num_envs == 1, "You must pass only one environment when using this function"
|
||||
is_monitor_wrapped = env.env_is_wrapped(Monitor)[0]
|
||||
else:
|
||||
is_monitor_wrapped = is_wrapped(env, Monitor)
|
||||
|
||||
if not is_monitor_wrapped and warn:
|
||||
warnings.warn(
|
||||
"Evaluation environment is not wrapped with a ``Monitor`` wrapper. "
|
||||
"This may result in reporting modified episode lengths and rewards, if other wrappers happen to modify these. "
|
||||
"Consider wrapping environment first with ``Monitor`` wrapper.",
|
||||
UserWarning,
|
||||
)
|
||||
|
||||
episode_rewards, episode_lengths = [], []
|
||||
for i in range(n_eval_episodes):
|
||||
# Avoid double reset, as VecEnv are reset automatically
|
||||
if not isinstance(env, VecEnv) or i == 0:
|
||||
not_reseted = True
|
||||
while len(episode_rewards) < n_eval_episodes:
|
||||
# Number of loops here might differ from true episodes
|
||||
# played, if underlying wrappers modify episode lengths.
|
||||
# Avoid double reset, as VecEnv are reset automatically.
|
||||
if not isinstance(env, VecEnv) or not_reseted:
|
||||
obs = env.reset()
|
||||
not_reseted = False
|
||||
done, state = False, None
|
||||
episode_reward = 0.0
|
||||
episode_length = 0
|
||||
while not done:
|
||||
action, state = model.predict(obs, state=state, deterministic=deterministic)
|
||||
obs, reward, done, _info = env.step(action)
|
||||
obs, reward, done, info = env.step(action)
|
||||
episode_reward += reward
|
||||
if callback is not None:
|
||||
callback(locals(), globals())
|
||||
episode_length += 1
|
||||
if render:
|
||||
env.render()
|
||||
episode_rewards.append(episode_reward)
|
||||
episode_lengths.append(episode_length)
|
||||
|
||||
if is_monitor_wrapped:
|
||||
# Do not trust "done" with episode endings.
|
||||
# Remove vecenv stacking (if any)
|
||||
if isinstance(env, VecEnv):
|
||||
info = info[0]
|
||||
if "episode" in info.keys():
|
||||
# Monitor wrapper includes "episode" key in info if environment
|
||||
# has been wrapped with it. Use those rewards instead.
|
||||
episode_rewards.append(info["episode"]["r"])
|
||||
episode_lengths.append(info["episode"]["l"])
|
||||
else:
|
||||
episode_rewards.append(episode_reward)
|
||||
episode_lengths.append(episode_length)
|
||||
|
||||
mean_reward = np.mean(episode_rewards)
|
||||
std_reward = np.std(episode_rewards)
|
||||
if reward_threshold is not None:
|
||||
|
|
|
|||
|
|
@ -6,10 +6,9 @@ import gym
|
|||
import numpy as np
|
||||
import torch as th
|
||||
|
||||
from stable_baselines3.common import callbacks
|
||||
from stable_baselines3.common.vec_env import VecEnv
|
||||
from stable_baselines3.common import callbacks, vec_env
|
||||
|
||||
GymEnv = Union[gym.Env, VecEnv]
|
||||
GymEnv = Union[gym.Env, vec_env.VecEnv]
|
||||
GymObs = Union[Tuple, Dict[str, Any], np.ndarray, int]
|
||||
GymStepReturn = Tuple[GymObs, float, bool, Dict]
|
||||
TensorDict = Dict[str, th.Tensor]
|
||||
|
|
|
|||
|
|
@ -41,7 +41,7 @@ def unwrap_vec_normalize(env: Union["GymEnv", VecEnv]) -> Optional[VecNormalize]
|
|||
return unwrap_vec_wrapper(env, VecNormalize) # pytype:disable=bad-return-type
|
||||
|
||||
|
||||
def is_wrapped(env: Union["GymEnv", VecEnv], vec_wrapper_class: Type[VecEnvWrapper]) -> bool:
|
||||
def is_vecenv_wrapped(env: Union["GymEnv", VecEnv], vec_wrapper_class: Type[VecEnvWrapper]) -> bool:
|
||||
"""
|
||||
Check if an environment is already wrapped by a given ``VecEnvWrapper``.
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import inspect
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
|
||||
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Type, Union
|
||||
|
||||
import cloudpickle
|
||||
import gym
|
||||
|
|
@ -139,6 +139,19 @@ class VecEnv(ABC):
|
|||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]:
|
||||
"""
|
||||
Check if environments are wrapped with a given wrapper.
|
||||
|
||||
:param method_name: The name of the environment method to invoke.
|
||||
:param indices: Indices of envs whose method to call
|
||||
:param method_args: Any positional arguments to provide in the call
|
||||
:param method_kwargs: Any keyword arguments to provide in the call
|
||||
:return: True if the env is wrapped, False otherwise, for each env queried.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def step(self, actions: np.ndarray) -> VecEnvStepReturn:
|
||||
"""
|
||||
Step the environments with the given action
|
||||
|
|
@ -280,6 +293,9 @@ class VecEnvWrapper(VecEnv):
|
|||
def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]:
|
||||
return self.venv.env_method(method_name, *method_args, indices=indices, **method_kwargs)
|
||||
|
||||
def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]:
|
||||
return self.venv.env_is_wrapped(wrapper_class, indices=indices)
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
"""Find attribute from wrapped venv(s) if this wrapper does not have it.
|
||||
Useful for accessing attributes from venvs which are wrapped with multiple wrappers
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from collections import OrderedDict
|
||||
from copy import deepcopy
|
||||
from typing import Any, Callable, List, Optional, Sequence, Union
|
||||
from typing import Any, Callable, List, Optional, Sequence, Type, Union
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
|
|
@ -112,6 +112,14 @@ class DummyVecEnv(VecEnv):
|
|||
target_envs = self._get_target_envs(indices)
|
||||
return [getattr(env_i, method_name)(*method_args, **method_kwargs) for env_i in target_envs]
|
||||
|
||||
def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]:
|
||||
"""Check if worker environments are wrapped with a given wrapper"""
|
||||
target_envs = self._get_target_envs(indices)
|
||||
# Import here to avoid a circular import
|
||||
from stable_baselines3.common import env_util
|
||||
|
||||
return [env_util.is_wrapped(env_i, wrapper_class) for env_i in target_envs]
|
||||
|
||||
def _get_target_envs(self, indices: VecEnvIndices) -> List[gym.Env]:
|
||||
indices = self._get_indices(indices)
|
||||
return [self.envs[i] for i in indices]
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import multiprocessing as mp
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union
|
||||
from typing import Any, Callable, List, Optional, Sequence, Tuple, Type, Union
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
|
|
@ -17,6 +17,9 @@ from stable_baselines3.common.vec_env.base_vec_env import (
|
|||
def _worker(
|
||||
remote: mp.connection.Connection, parent_remote: mp.connection.Connection, env_fn_wrapper: CloudpickleWrapper
|
||||
) -> None:
|
||||
# Import here to avoid a circular import
|
||||
from stable_baselines3.common.env_util import is_wrapped
|
||||
|
||||
parent_remote.close()
|
||||
env = env_fn_wrapper.var()
|
||||
while True:
|
||||
|
|
@ -49,6 +52,8 @@ def _worker(
|
|||
remote.send(getattr(env, data))
|
||||
elif cmd == "set_attr":
|
||||
remote.send(setattr(env, data[0], data[1]))
|
||||
elif cmd == "is_wrapped":
|
||||
remote.send(is_wrapped(env, data))
|
||||
else:
|
||||
raise NotImplementedError(f"`{cmd}` is not implemented in the worker")
|
||||
except EOFError:
|
||||
|
|
@ -170,6 +175,13 @@ class SubprocVecEnv(VecEnv):
|
|||
remote.send(("env_method", (method_name, method_args, method_kwargs)))
|
||||
return [remote.recv() for remote in target_remotes]
|
||||
|
||||
def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]:
|
||||
"""Check if worker environments are wrapped with a given wrapper"""
|
||||
target_remotes = self._get_target_remotes(indices)
|
||||
for remote in target_remotes:
|
||||
remote.send(("is_wrapped", wrapper_class))
|
||||
return [remote.recv() for remote in target_remotes]
|
||||
|
||||
def _get_target_remotes(self, indices: VecEnvIndices) -> List[Any]:
|
||||
"""
|
||||
Get the connection object needed to communicate with the wanted
|
||||
|
|
|
|||
|
|
@ -33,7 +33,12 @@ def test_callbacks(tmp_path, model_class):
|
|||
callback_on_best = StopTrainingOnRewardThreshold(reward_threshold=-1200, verbose=1)
|
||||
|
||||
eval_callback = EvalCallback(
|
||||
eval_env, callback_on_new_best=callback_on_best, best_model_save_path=log_folder, log_path=log_folder, eval_freq=100
|
||||
eval_env,
|
||||
callback_on_new_best=callback_on_best,
|
||||
best_model_save_path=log_folder,
|
||||
log_path=log_folder,
|
||||
eval_freq=100,
|
||||
warn=False,
|
||||
)
|
||||
# Equivalent to the `checkpoint_callback`
|
||||
# but here in an event-driven manner
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from stable_baselines3 import A2C, DQN, PPO, SAC, TD3
|
|||
from stable_baselines3.common.identity_env import FakeImageEnv
|
||||
from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first
|
||||
from stable_baselines3.common.utils import zip_strict
|
||||
from stable_baselines3.common.vec_env import VecTransposeImage, is_wrapped
|
||||
from stable_baselines3.common.vec_env import VecTransposeImage, is_vecenv_wrapped
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", [A2C, PPO, SAC, TD3, DQN])
|
||||
|
|
@ -29,7 +29,7 @@ def test_cnn(tmp_path, model_class):
|
|||
model = model_class("CnnPolicy", env, **kwargs).learn(250)
|
||||
|
||||
# FakeImageEnv is channel last by default and should be wrapped
|
||||
assert is_wrapped(model.get_env(), VecTransposeImage)
|
||||
assert is_vecenv_wrapped(model.get_env(), VecTransposeImage)
|
||||
|
||||
obs = env.reset()
|
||||
|
||||
|
|
@ -194,7 +194,7 @@ def test_channel_first_env(tmp_path):
|
|||
|
||||
model = A2C("CnnPolicy", env, n_steps=100).learn(250)
|
||||
|
||||
assert not is_wrapped(model.get_env(), VecTransposeImage)
|
||||
assert not is_vecenv_wrapped(model.get_env(), VecTransposeImage)
|
||||
|
||||
obs = env.reset()
|
||||
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ def test_discrete(model_class, env):
|
|||
|
||||
model = model_class("MlpPolicy", env_, gamma=0.4, seed=1, **kwargs).learn(n_steps)
|
||||
|
||||
evaluate_policy(model, env_, n_eval_episodes=20, reward_threshold=90)
|
||||
evaluate_policy(model, env_, n_eval_episodes=20, reward_threshold=90, warn=False)
|
||||
obs = env.reset()
|
||||
|
||||
assert np.shape(model.predict(obs)[0]) == np.shape(obs)
|
||||
|
|
@ -45,4 +45,4 @@ def test_continuous(model_class):
|
|||
|
||||
model = model_class("MlpPolicy", env, **kwargs).learn(n_steps)
|
||||
|
||||
evaluate_policy(model, env, n_eval_episodes=20, reward_threshold=90)
|
||||
evaluate_policy(model, env, n_eval_episodes=20, reward_threshold=90, warn=False)
|
||||
|
|
|
|||
|
|
@ -48,4 +48,4 @@ def test_identity_spaces(model_class, env):
|
|||
model = model_class("MlpPolicy", env, gamma=0.5, seed=1, policy_kwargs=dict(net_arch=[64]))
|
||||
model.learn(total_timesteps=500)
|
||||
|
||||
evaluate_policy(model, env, n_eval_episodes=5)
|
||||
evaluate_policy(model, env, n_eval_episodes=5, warn=False)
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ import torch as th
|
|||
|
||||
from stable_baselines3 import A2C
|
||||
from stable_baselines3.common.atari_wrappers import ClipRewardEnv
|
||||
from stable_baselines3.common.env_util import make_atari_env, make_vec_env
|
||||
from stable_baselines3.common.env_util import is_wrapped, make_atari_env, make_vec_env, unwrap_wrapper
|
||||
from stable_baselines3.common.evaluation import evaluate_policy
|
||||
from stable_baselines3.common.monitor import Monitor
|
||||
from stable_baselines3.common.noise import ActionNoise, OrnsteinUhlenbeckActionNoise, VectorizedActionNoise
|
||||
|
|
@ -127,6 +127,103 @@ def test_evaluate_policy():
|
|||
episode_rewards, _ = evaluate_policy(model, model.get_env(), n_eval_episodes, return_episode_rewards=True)
|
||||
assert len(episode_rewards) == n_eval_episodes
|
||||
|
||||
# Test that warning is given about no monitor
|
||||
eval_env = gym.make("Pendulum-v0")
|
||||
with pytest.warns(UserWarning):
|
||||
_ = evaluate_policy(model, eval_env, n_eval_episodes)
|
||||
|
||||
|
||||
class ZeroRewardWrapper(gym.RewardWrapper):
|
||||
def reward(self, reward):
|
||||
return reward * 0
|
||||
|
||||
|
||||
class AlwaysDoneWrapper(gym.Wrapper):
|
||||
# Pretends that environment only has single step for each
|
||||
# episode.
|
||||
def __init__(self, env):
|
||||
super(AlwaysDoneWrapper, self).__init__(env)
|
||||
self.last_obs = None
|
||||
self.needs_reset = True
|
||||
|
||||
def step(self, action):
|
||||
obs, reward, done, info = self.env.step(action)
|
||||
self.needs_reset = done
|
||||
self.last_obs = obs
|
||||
return obs, reward, True, info
|
||||
|
||||
def reset(self, **kwargs):
|
||||
if self.needs_reset:
|
||||
obs = self.env.reset(**kwargs)
|
||||
self.last_obs = obs
|
||||
self.needs_reset = False
|
||||
return self.last_obs
|
||||
|
||||
|
||||
@pytest.mark.parametrize("vec_env_class", [None, DummyVecEnv, SubprocVecEnv])
|
||||
def test_evaluate_policy_monitors(vec_env_class):
|
||||
# Test that results are correct with monitor environments.
|
||||
# Also test VecEnvs
|
||||
n_eval_episodes = 2
|
||||
env_id = "CartPole-v0"
|
||||
model = A2C("MlpPolicy", env_id, seed=0)
|
||||
|
||||
def make_eval_env(with_monitor, wrapper_class=gym.Wrapper):
|
||||
# Make eval environment with or without monitor in root,
|
||||
# and additionally wrapped with another wrapper (after Monitor).
|
||||
env = None
|
||||
if vec_env_class is None:
|
||||
# No vecenv, traditional env
|
||||
env = gym.make(env_id)
|
||||
if with_monitor:
|
||||
env = Monitor(env)
|
||||
env = wrapper_class(env)
|
||||
else:
|
||||
if with_monitor:
|
||||
env = vec_env_class([lambda: wrapper_class(Monitor(gym.make(env_id)))])
|
||||
else:
|
||||
env = vec_env_class([lambda: wrapper_class(gym.make(env_id))])
|
||||
return env
|
||||
|
||||
# Test that evaluation with VecEnvs works as expected
|
||||
eval_env = make_eval_env(with_monitor=True)
|
||||
_ = evaluate_policy(model, eval_env, n_eval_episodes)
|
||||
eval_env.close()
|
||||
|
||||
# Warning without Monitor
|
||||
eval_env = make_eval_env(with_monitor=False)
|
||||
with pytest.warns(UserWarning):
|
||||
_ = evaluate_policy(model, eval_env, n_eval_episodes)
|
||||
eval_env.close()
|
||||
|
||||
# Test that we gather correct reward with Monitor wrapper
|
||||
# Sanity check that we get zero-reward without Monitor
|
||||
eval_env = make_eval_env(with_monitor=False, wrapper_class=ZeroRewardWrapper)
|
||||
average_reward, _ = evaluate_policy(model, eval_env, n_eval_episodes, warn=False)
|
||||
assert average_reward == 0.0, "ZeroRewardWrapper wrapper for testing did not work"
|
||||
eval_env.close()
|
||||
|
||||
# Should get non-zero-rewards with Monitor (true reward)
|
||||
eval_env = make_eval_env(with_monitor=True, wrapper_class=ZeroRewardWrapper)
|
||||
average_reward, _ = evaluate_policy(model, eval_env, n_eval_episodes)
|
||||
assert average_reward > 0.0, "evaluate_policy did not get reward from Monitor"
|
||||
eval_env.close()
|
||||
|
||||
# Test that we also track correct episode dones, not the wrapped ones.
|
||||
# Sanity check that we get only one step per episode.
|
||||
eval_env = make_eval_env(with_monitor=False, wrapper_class=AlwaysDoneWrapper)
|
||||
episode_rewards, episode_lengths = evaluate_policy(
|
||||
model, eval_env, n_eval_episodes, return_episode_rewards=True, warn=False
|
||||
)
|
||||
assert all(map(lambda l: l == 1, episode_lengths)), "AlwaysDoneWrapper did not fix episode lengths to one"
|
||||
eval_env.close()
|
||||
|
||||
# Should get longer episodes with with Monitor (true episodes)
|
||||
eval_env = make_eval_env(with_monitor=True, wrapper_class=AlwaysDoneWrapper)
|
||||
episode_rewards, episode_lengths = evaluate_policy(model, eval_env, n_eval_episodes, return_episode_rewards=True)
|
||||
assert all(map(lambda l: l > 1, episode_lengths)), "evaluate_policy did not get episode lengths from Monitor"
|
||||
eval_env.close()
|
||||
|
||||
|
||||
def test_vec_noise():
|
||||
num_envs = 4
|
||||
|
|
@ -196,3 +293,16 @@ def test_cmd_util_rename():
|
|||
"""Test that importing cmd_util still works but raises warning"""
|
||||
with pytest.warns(FutureWarning):
|
||||
from stable_baselines3.common.cmd_util import make_vec_env # noqa: F401
|
||||
|
||||
|
||||
def test_is_wrapped():
|
||||
"""Test that is_wrapped correctly detects wraps"""
|
||||
env = gym.make("Pendulum-v0")
|
||||
env = gym.Wrapper(env)
|
||||
assert not is_wrapped(env, Monitor)
|
||||
monitor_env = Monitor(env)
|
||||
assert is_wrapped(monitor_env, Monitor)
|
||||
env = gym.Wrapper(monitor_env)
|
||||
assert is_wrapped(env, Monitor)
|
||||
# Test that unwrap works as expected
|
||||
assert unwrap_wrapper(env, Monitor) == monitor_env
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ import gym
|
|||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from stable_baselines3.common.monitor import Monitor
|
||||
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecFrameStack, VecNormalize
|
||||
|
||||
N_ENVS = 3
|
||||
|
|
@ -415,3 +416,27 @@ def test_framestack_vecenv():
|
|||
# Test that it works with non-image envs when no channels_order is given
|
||||
vec_env = DummyVecEnv([make_non_image_env for _ in range(N_ENVS)])
|
||||
vec_env = VecFrameStack(vec_env, n_stack=2)
|
||||
|
||||
|
||||
def test_vec_env_is_wrapped():
|
||||
# Test is_wrapped call of subproc workers
|
||||
def make_env():
|
||||
return CustomGymEnv(gym.spaces.Box(low=np.zeros(2), high=np.ones(2)))
|
||||
|
||||
def make_monitored_env():
|
||||
return Monitor(CustomGymEnv(gym.spaces.Box(low=np.zeros(2), high=np.ones(2))))
|
||||
|
||||
# One with monitor, one without
|
||||
vec_env = SubprocVecEnv([make_env, make_monitored_env])
|
||||
|
||||
assert vec_env.env_is_wrapped(Monitor) == [False, True]
|
||||
|
||||
vec_env.close()
|
||||
|
||||
# One with monitor, one without
|
||||
vec_env = DummyVecEnv([make_env, make_monitored_env])
|
||||
|
||||
assert vec_env.env_is_wrapped(Monitor) == [False, True]
|
||||
|
||||
vec_env = VecFrameStack(vec_env, n_stack=2)
|
||||
assert vec_env.env_is_wrapped(Monitor) == [False, True]
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import pytest
|
|||
from gym import spaces
|
||||
|
||||
from stable_baselines3 import HER, SAC, TD3
|
||||
from stable_baselines3.common.monitor import Monitor
|
||||
from stable_baselines3.common.running_mean_std import RunningMeanStd
|
||||
from stable_baselines3.common.vec_env import (
|
||||
DummyVecEnv,
|
||||
|
|
@ -61,11 +62,11 @@ def allclose(obs_1, obs_2):
|
|||
|
||||
|
||||
def make_env():
|
||||
return gym.make(ENV_ID)
|
||||
return Monitor(gym.make(ENV_ID))
|
||||
|
||||
|
||||
def make_dict_env():
|
||||
return DummyDictEnv()
|
||||
return Monitor(DummyDictEnv())
|
||||
|
||||
|
||||
def check_rms_equal(rmsa, rmsb):
|
||||
|
|
|
|||
Loading…
Reference in a new issue