From 18d10dbf42dd6dff6d457b45b521fdf2a1169a7e Mon Sep 17 00:00:00 2001 From: Anssi Date: Mon, 16 Nov 2020 12:52:28 +0200 Subject: [PATCH] Use Monitor episode reward/length for `evaluate_policy` (#220) * Update evaluate_policy to use monitor data if available * Update documentation * Cleaning up * Remove unnecessary typing trickery * Update doc * Rename is_wrapped to clarify it is for vecenvs * Add is_wrapped for regular envs * Add is_wrapped call for subprocvecenv and update code for circular imports * Move new functions back to env_util and fix imports * Update changelog * Clarify evaluate_policy docs * Add tests for wrapped modifying episode lengths * Fix tests * Update changelog * Minor edits * Add warn switch to evaluate_policy and update tests Co-authored-by: Antonin RAFFIN --- docs/guide/examples.rst | 3 + docs/guide/rl_tips.rst | 14 ++- docs/misc/changelog.rst | 9 +- stable_baselines3/common/base_class.py | 4 +- stable_baselines3/common/callbacks.py | 5 + stable_baselines3/common/env_util.py | 27 +++++ stable_baselines3/common/evaluation.py | 67 +++++++++-- stable_baselines3/common/type_aliases.py | 5 +- stable_baselines3/common/vec_env/__init__.py | 2 +- .../common/vec_env/base_vec_env.py | 18 ++- .../common/vec_env/dummy_vec_env.py | 10 +- .../common/vec_env/subproc_vec_env.py | 14 ++- tests/test_callbacks.py | 7 +- tests/test_cnn.py | 6 +- tests/test_identity.py | 4 +- tests/test_spaces.py | 2 +- tests/test_utils.py | 112 +++++++++++++++++- tests/test_vec_envs.py | 25 ++++ tests/test_vec_normalize.py | 5 +- 19 files changed, 305 insertions(+), 34 deletions(-) diff --git a/docs/guide/examples.rst b/docs/guide/examples.rst index 2e9f2b2..4b85e02 100644 --- a/docs/guide/examples.rst +++ b/docs/guide/examples.rst @@ -79,6 +79,9 @@ In the following example, we will train, save and load a DQN model on the Lunar model = DQN.load("dqn_lunar") # Evaluate the agent + # NOTE: If you use wrappers with your environment that modify rewards, + # this will be reflected 
here. To evaluate with original rewards, + # wrap environment in a "Monitor" wrapper before other wrappers. mean_reward, std_reward = evaluate_policy(model, model.get_env(), n_eval_episodes=10) # Enjoy trained agent diff --git a/docs/guide/rl_tips.rst b/docs/guide/rl_tips.rst index 535d04a..29199a1 100644 --- a/docs/guide/rl_tips.rst +++ b/docs/guide/rl_tips.rst @@ -17,7 +17,7 @@ TL;DR 1. Read about RL and Stable Baselines3 2. Do quantitative experiments and hyperparameter tuning if needed -3. Evaluate the performance using a separate test environment +3. Evaluate the performance using a separate test environment (remember to check wrappers!) 4. For better performance, increase the training budget @@ -68,18 +68,24 @@ Other method, like ``TRPO`` or ``PPO`` make use of a *trust region* to minimize How to evaluate an RL algorithm? -------------------------------- +.. note:: + + Pay attention to environment wrappers when evaluating your agent and comparing results to others' results. Modifications to episode rewards + or lengths may also affect evaluation results which may not be desirable. Check ``evaluate_policy`` helper function in :ref:`Evaluation Helper ` section. + Because most algorithms use exploration noise during training, you need a separate test environment to evaluate the performance of your agent at a given time. It is recommended to periodically evaluate your agent for ``n`` test episodes (``n`` is usually between 5 and 20) and average the reward per episode to have a good estimate. +.. note:: + + We provide an ``EvalCallback`` for doing such evaluation. You can read more about it in the :ref:`Callbacks ` section. + As some policy are stochastic by default (e.g. A2C or PPO), you should also try to set `deterministic=True` when calling the `.predict()` method, this frequently leads to better performance. Looking at the training curve (episode reward function of the timesteps) is a good proxy but underestimates the agent true performance. -.. 
note:: - - We provide an ``EvalCallback`` for doing such evaluation. You can read more about it in the :ref:`Callbacks ` section. diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 4bc5907..6c4efa9 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -8,7 +8,10 @@ Pre-Release 0.11.0a0 (WIP) Breaking Changes: ^^^^^^^^^^^^^^^^^ - +- ``evaluate_policy`` now returns rewards/episode lengths from a ``Monitor`` wrapper if one is present, + which makes it possible to report the unnormalized reward (for instance in the case of Atari games). +- Renamed ``common.vec_env.is_wrapped`` to ``common.vec_env.is_vecenv_wrapped`` to avoid confusion + with the new ``is_wrapped()`` helper New Features: ^^^^^^^^^^^^^ @@ -16,6 +19,10 @@ New Features: automatic check for image spaces. - ``VecFrameStack`` now has a ``channels_order`` argument to tell if observations should be stacked on the first or last observation dimension (originally always stacked on last). +- Added ``common.env_util.is_wrapped`` and ``common.env_util.unwrap_wrapper`` functions for checking/unwrapping + an environment for a specific wrapper. +- Added ``env_is_wrapped()`` method for ``VecEnv`` to check if its environments are wrapped + with given Gym wrappers.
Bug Fixes: ^^^^^^^^^^ diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py index 58bdf88..c99e713 100644 --- a/stable_baselines3/common/base_class.py +++ b/stable_baselines3/common/base_class.py @@ -31,7 +31,7 @@ from stable_baselines3.common.vec_env import ( VecEnv, VecNormalize, VecTransposeImage, - is_wrapped, + is_vecenv_wrapped, unwrap_vec_normalize, ) from stable_baselines3.common.vec_env.obs_dict_wrapper import ObsDictWrapper @@ -178,7 +178,7 @@ class BaseAlgorithm(ABC): if ( is_image_space(env.observation_space) - and not is_wrapped(env, VecTransposeImage) + and not is_vecenv_wrapped(env, VecTransposeImage) and not is_image_space_channels_first(env.observation_space) ): if verbose >= 1: diff --git a/stable_baselines3/common/callbacks.py b/stable_baselines3/common/callbacks.py index 27d7748..c114a93 100644 --- a/stable_baselines3/common/callbacks.py +++ b/stable_baselines3/common/callbacks.py @@ -276,6 +276,8 @@ class EvalCallback(EventCallback): :param deterministic: Whether to render or not the environment during evaluation :param render: Whether to render or not the environment during evaluation :param verbose: + :param warn: Passed to ``evaluate_policy`` (warns if ``eval_env`` has not been + wrapped with a Monitor wrapper) """ def __init__( @@ -289,6 +291,7 @@ class EvalCallback(EventCallback): deterministic: bool = True, render: bool = False, verbose: int = 1, + warn: bool = True, ): super(EvalCallback, self).__init__(callback_on_new_best, verbose=verbose) self.n_eval_episodes = n_eval_episodes @@ -297,6 +300,7 @@ class EvalCallback(EventCallback): self.last_mean_reward = -np.inf self.deterministic = deterministic self.render = render + self.warn = warn # Convert to VecEnv for consistency if not isinstance(eval_env, VecEnv): @@ -339,6 +343,7 @@ class EvalCallback(EventCallback): render=self.render, deterministic=self.deterministic, return_episode_rewards=True, + warn=self.warn, ) if self.log_path is not None: diff 
--git a/stable_baselines3/common/env_util.py b/stable_baselines3/common/env_util.py index 0458b30..2b8d1f0 100644 --- a/stable_baselines3/common/env_util.py +++ b/stable_baselines3/common/env_util.py @@ -8,6 +8,33 @@ from stable_baselines3.common.monitor import Monitor from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecEnv +def unwrap_wrapper(env: gym.Env, wrapper_class: Type[gym.Wrapper]) -> Optional[gym.Wrapper]: + """ + Retrieve a ``gym.Wrapper`` object of the given class by searching through the wrapper chain. + + :param env: Environment to unwrap + :param wrapper_class: Wrapper to look for + :return: Environment unwrapped till ``wrapper_class`` if it has been wrapped with it + """ + env_tmp = env + while isinstance(env_tmp, gym.Wrapper): + if isinstance(env_tmp, wrapper_class): + return env_tmp + env_tmp = env_tmp.env + return None + + +def is_wrapped(env: Type[gym.Env], wrapper_class: Type[gym.Wrapper]) -> bool: + """ + Check if a given environment has been wrapped with a given wrapper. + + :param env: Environment to check + :param wrapper_class: Wrapper class to look for + :return: True if environment has been wrapped with ``wrapper_class``. + """ + return unwrap_wrapper(env, wrapper_class) is not None + + def make_vec_env( env_id: Union[str, Type[gym.Env]], n_envs: int = 1, diff --git a/stable_baselines3/common/evaluation.py b/stable_baselines3/common/evaluation.py index 0ed40e4..cb773bb 100644 --- a/stable_baselines3/common/evaluation.py +++ b/stable_baselines3/common/evaluation.py @@ -1,3 +1,4 @@ +import warnings from typing import Any, Callable, Dict, List, Optional, Tuple, Union import gym @@ -16,11 +17,20 @@ def evaluate_policy( callback: Optional[Callable[[Dict[str, Any], Dict[str, Any]], None]] = None, reward_threshold: Optional[float] = None, return_episode_rewards: bool = False, + warn: bool = True, ) -> Union[Tuple[float, float], Tuple[List[float], List[int]]]: """ Runs policy for ``n_eval_episodes`` episodes and returns average reward.
This is made to work only with one env. + .. note:: + If environment has not been wrapped with ``Monitor`` wrapper, reward and + episode lengths are counted as they appear with ``env.step`` calls. If + the environment contains wrappers that modify rewards or episode lengths + (e.g. reward scaling, early episode reset), these will affect the evaluation + results as well. You can avoid this by wrapping environment with ``Monitor`` + wrapper before anything else. + :param model: The RL agent you want to evaluate. :param env: The gym environment. In the case of a ``VecEnv`` this must contain only one environment. @@ -31,33 +41,70 @@ def evaluate_policy( called after each step. Gets locals() and globals() passed as parameters. :param reward_threshold: Minimum expected reward per episode, this will raise an error if the performance is not met - :param return_episode_rewards: If True, a list of reward per episode - will be returned instead of the mean. - :return: Mean reward per episode, std of reward per episode - returns ([float], [int]) when ``return_episode_rewards`` is True + :param return_episode_rewards: If True, a list of rewards and episode lengths + per episode will be returned instead of the mean. + :param warn: If True (default), warns user about lack of a Monitor wrapper in the + evaluation environment. + :return: Mean reward per episode, std of reward per episode. + Returns ([float], [int]) when ``return_episode_rewards`` is True, first + list containing per-episode rewards and second containing per-episode lengths + (in number of steps).
""" + is_monitor_wrapped = False + # Avoid circular import + from stable_baselines3.common.env_util import is_wrapped + from stable_baselines3.common.monitor import Monitor + if isinstance(env, VecEnv): assert env.num_envs == 1, "You must pass only one environment when using this function" + is_monitor_wrapped = env.env_is_wrapped(Monitor)[0] + else: + is_monitor_wrapped = is_wrapped(env, Monitor) + + if not is_monitor_wrapped and warn: + warnings.warn( + "Evaluation environment is not wrapped with a ``Monitor`` wrapper. " + "This may result in reporting modified episode lengths and rewards, if other wrappers happen to modify these. " + "Consider wrapping environment first with ``Monitor`` wrapper.", + UserWarning, + ) episode_rewards, episode_lengths = [], [] - for i in range(n_eval_episodes): - # Avoid double reset, as VecEnv are reset automatically - if not isinstance(env, VecEnv) or i == 0: + not_reseted = True + while len(episode_rewards) < n_eval_episodes: + # Number of loops here might differ from true episodes + # played, if underlying wrappers modify episode lengths. + # Avoid double reset, as VecEnv are reset automatically. + if not isinstance(env, VecEnv) or not_reseted: obs = env.reset() + not_reseted = False done, state = False, None episode_reward = 0.0 episode_length = 0 while not done: action, state = model.predict(obs, state=state, deterministic=deterministic) - obs, reward, done, _info = env.step(action) + obs, reward, done, info = env.step(action) episode_reward += reward if callback is not None: callback(locals(), globals()) episode_length += 1 if render: env.render() - episode_rewards.append(episode_reward) - episode_lengths.append(episode_length) + + if is_monitor_wrapped: + # Do not trust "done" with episode endings. + # Remove vecenv stacking (if any) + if isinstance(env, VecEnv): + info = info[0] + if "episode" in info.keys(): + # Monitor wrapper includes "episode" key in info if environment + # has been wrapped with it. 
Use those rewards instead. + episode_rewards.append(info["episode"]["r"]) + episode_lengths.append(info["episode"]["l"]) + else: + episode_rewards.append(episode_reward) + episode_lengths.append(episode_length) + mean_reward = np.mean(episode_rewards) std_reward = np.std(episode_rewards) if reward_threshold is not None: diff --git a/stable_baselines3/common/type_aliases.py b/stable_baselines3/common/type_aliases.py index 1577423..d189f5d 100644 --- a/stable_baselines3/common/type_aliases.py +++ b/stable_baselines3/common/type_aliases.py @@ -6,10 +6,9 @@ import gym import numpy as np import torch as th -from stable_baselines3.common import callbacks -from stable_baselines3.common.vec_env import VecEnv +from stable_baselines3.common import callbacks, vec_env -GymEnv = Union[gym.Env, VecEnv] +GymEnv = Union[gym.Env, vec_env.VecEnv] GymObs = Union[Tuple, Dict[str, Any], np.ndarray, int] GymStepReturn = Tuple[GymObs, float, bool, Dict] TensorDict = Dict[str, th.Tensor] diff --git a/stable_baselines3/common/vec_env/__init__.py b/stable_baselines3/common/vec_env/__init__.py index 0002788..42f08da 100644 --- a/stable_baselines3/common/vec_env/__init__.py +++ b/stable_baselines3/common/vec_env/__init__.py @@ -41,7 +41,7 @@ def unwrap_vec_normalize(env: Union["GymEnv", VecEnv]) -> Optional[VecNormalize] return unwrap_vec_wrapper(env, VecNormalize) # pytype:disable=bad-return-type -def is_wrapped(env: Union["GymEnv", VecEnv], vec_wrapper_class: Type[VecEnvWrapper]) -> bool: +def is_vecenv_wrapped(env: Union["GymEnv", VecEnv], vec_wrapper_class: Type[VecEnvWrapper]) -> bool: """ Check if an environment is already wrapped by a given ``VecEnvWrapper``. 
diff --git a/stable_baselines3/common/vec_env/base_vec_env.py b/stable_baselines3/common/vec_env/base_vec_env.py index d208e12..c7bd7ac 100644 --- a/stable_baselines3/common/vec_env/base_vec_env.py +++ b/stable_baselines3/common/vec_env/base_vec_env.py @@ -1,6 +1,6 @@ import inspect from abc import ABC, abstractmethod -from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Type, Union import cloudpickle import gym @@ -139,6 +139,19 @@ class VecEnv(ABC): """ raise NotImplementedError() + @abstractmethod + def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]: + """ + Check if environments are wrapped with a given wrapper. + + :param method_name: The name of the environment method to invoke. + :param indices: Indices of envs whose method to call + :param method_args: Any positional arguments to provide in the call + :param method_kwargs: Any keyword arguments to provide in the call + :return: True if the env is wrapped, False otherwise, for each env queried. + """ + raise NotImplementedError() + def step(self, actions: np.ndarray) -> VecEnvStepReturn: """ Step the environments with the given action @@ -280,6 +293,9 @@ class VecEnvWrapper(VecEnv): def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]: return self.venv.env_method(method_name, *method_args, indices=indices, **method_kwargs) + def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]: + return self.venv.env_is_wrapped(wrapper_class, indices=indices) + def __getattr__(self, name: str) -> Any: """Find attribute from wrapped venv(s) if this wrapper does not have it. 
Useful for accessing attributes from venvs which are wrapped with multiple wrappers diff --git a/stable_baselines3/common/vec_env/dummy_vec_env.py b/stable_baselines3/common/vec_env/dummy_vec_env.py index bd656e1..a1a2382 100644 --- a/stable_baselines3/common/vec_env/dummy_vec_env.py +++ b/stable_baselines3/common/vec_env/dummy_vec_env.py @@ -1,6 +1,6 @@ from collections import OrderedDict from copy import deepcopy -from typing import Any, Callable, List, Optional, Sequence, Union +from typing import Any, Callable, List, Optional, Sequence, Type, Union import gym import numpy as np @@ -112,6 +112,14 @@ class DummyVecEnv(VecEnv): target_envs = self._get_target_envs(indices) return [getattr(env_i, method_name)(*method_args, **method_kwargs) for env_i in target_envs] + def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]: + """Check if worker environments are wrapped with a given wrapper""" + target_envs = self._get_target_envs(indices) + # Import here to avoid a circular import + from stable_baselines3.common import env_util + + return [env_util.is_wrapped(env_i, wrapper_class) for env_i in target_envs] + def _get_target_envs(self, indices: VecEnvIndices) -> List[gym.Env]: indices = self._get_indices(indices) return [self.envs[i] for i in indices] diff --git a/stable_baselines3/common/vec_env/subproc_vec_env.py b/stable_baselines3/common/vec_env/subproc_vec_env.py index 4a9cdc5..1050f3e 100644 --- a/stable_baselines3/common/vec_env/subproc_vec_env.py +++ b/stable_baselines3/common/vec_env/subproc_vec_env.py @@ -1,6 +1,6 @@ import multiprocessing as mp from collections import OrderedDict -from typing import Any, Callable, List, Optional, Sequence, Tuple, Union +from typing import Any, Callable, List, Optional, Sequence, Tuple, Type, Union import gym import numpy as np @@ -17,6 +17,9 @@ from stable_baselines3.common.vec_env.base_vec_env import ( def _worker( remote: mp.connection.Connection, parent_remote: 
mp.connection.Connection, env_fn_wrapper: CloudpickleWrapper ) -> None: + # Import here to avoid a circular import + from stable_baselines3.common.env_util import is_wrapped + parent_remote.close() env = env_fn_wrapper.var() while True: @@ -49,6 +52,8 @@ def _worker( remote.send(getattr(env, data)) elif cmd == "set_attr": remote.send(setattr(env, data[0], data[1])) + elif cmd == "is_wrapped": + remote.send(is_wrapped(env, data)) else: raise NotImplementedError(f"`{cmd}` is not implemented in the worker") except EOFError: @@ -170,6 +175,13 @@ class SubprocVecEnv(VecEnv): remote.send(("env_method", (method_name, method_args, method_kwargs))) return [remote.recv() for remote in target_remotes] + def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]: + """Check if worker environments are wrapped with a given wrapper""" + target_remotes = self._get_target_remotes(indices) + for remote in target_remotes: + remote.send(("is_wrapped", wrapper_class)) + return [remote.recv() for remote in target_remotes] + def _get_target_remotes(self, indices: VecEnvIndices) -> List[Any]: """ Get the connection object needed to communicate with the wanted diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py index 2f5259b..144494a 100644 --- a/tests/test_callbacks.py +++ b/tests/test_callbacks.py @@ -33,7 +33,12 @@ def test_callbacks(tmp_path, model_class): callback_on_best = StopTrainingOnRewardThreshold(reward_threshold=-1200, verbose=1) eval_callback = EvalCallback( - eval_env, callback_on_new_best=callback_on_best, best_model_save_path=log_folder, log_path=log_folder, eval_freq=100 + eval_env, + callback_on_new_best=callback_on_best, + best_model_save_path=log_folder, + log_path=log_folder, + eval_freq=100, + warn=False, ) # Equivalent to the `checkpoint_callback` # but here in an event-driven manner diff --git a/tests/test_cnn.py b/tests/test_cnn.py index 7f85b75..b6dfd24 100644 --- a/tests/test_cnn.py +++ b/tests/test_cnn.py 
@@ -10,7 +10,7 @@ from stable_baselines3 import A2C, DQN, PPO, SAC, TD3 from stable_baselines3.common.identity_env import FakeImageEnv from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first from stable_baselines3.common.utils import zip_strict -from stable_baselines3.common.vec_env import VecTransposeImage, is_wrapped +from stable_baselines3.common.vec_env import VecTransposeImage, is_vecenv_wrapped @pytest.mark.parametrize("model_class", [A2C, PPO, SAC, TD3, DQN]) @@ -29,7 +29,7 @@ def test_cnn(tmp_path, model_class): model = model_class("CnnPolicy", env, **kwargs).learn(250) # FakeImageEnv is channel last by default and should be wrapped - assert is_wrapped(model.get_env(), VecTransposeImage) + assert is_vecenv_wrapped(model.get_env(), VecTransposeImage) obs = env.reset() @@ -194,7 +194,7 @@ def test_channel_first_env(tmp_path): model = A2C("CnnPolicy", env, n_steps=100).learn(250) - assert not is_wrapped(model.get_env(), VecTransposeImage) + assert not is_vecenv_wrapped(model.get_env(), VecTransposeImage) obs = env.reset() diff --git a/tests/test_identity.py b/tests/test_identity.py index 678f63c..fdde0d2 100644 --- a/tests/test_identity.py +++ b/tests/test_identity.py @@ -25,7 +25,7 @@ def test_discrete(model_class, env): model = model_class("MlpPolicy", env_, gamma=0.4, seed=1, **kwargs).learn(n_steps) - evaluate_policy(model, env_, n_eval_episodes=20, reward_threshold=90) + evaluate_policy(model, env_, n_eval_episodes=20, reward_threshold=90, warn=False) obs = env.reset() assert np.shape(model.predict(obs)[0]) == np.shape(obs) @@ -45,4 +45,4 @@ def test_continuous(model_class): model = model_class("MlpPolicy", env, **kwargs).learn(n_steps) - evaluate_policy(model, env, n_eval_episodes=20, reward_threshold=90) + evaluate_policy(model, env, n_eval_episodes=20, reward_threshold=90, warn=False) diff --git a/tests/test_spaces.py b/tests/test_spaces.py index 98a1953..8b1feb3 100644 --- a/tests/test_spaces.py +++ 
b/tests/test_spaces.py @@ -48,4 +48,4 @@ def test_identity_spaces(model_class, env): model = model_class("MlpPolicy", env, gamma=0.5, seed=1, policy_kwargs=dict(net_arch=[64])) model.learn(total_timesteps=500) - evaluate_policy(model, env, n_eval_episodes=5) + evaluate_policy(model, env, n_eval_episodes=5, warn=False) diff --git a/tests/test_utils.py b/tests/test_utils.py index c30cb98..88aa76a 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -8,7 +8,7 @@ import torch as th from stable_baselines3 import A2C from stable_baselines3.common.atari_wrappers import ClipRewardEnv -from stable_baselines3.common.env_util import make_atari_env, make_vec_env +from stable_baselines3.common.env_util import is_wrapped, make_atari_env, make_vec_env, unwrap_wrapper from stable_baselines3.common.evaluation import evaluate_policy from stable_baselines3.common.monitor import Monitor from stable_baselines3.common.noise import ActionNoise, OrnsteinUhlenbeckActionNoise, VectorizedActionNoise @@ -127,6 +127,103 @@ def test_evaluate_policy(): episode_rewards, _ = evaluate_policy(model, model.get_env(), n_eval_episodes, return_episode_rewards=True) assert len(episode_rewards) == n_eval_episodes + # Test that warning is given about no monitor + eval_env = gym.make("Pendulum-v0") + with pytest.warns(UserWarning): + _ = evaluate_policy(model, eval_env, n_eval_episodes) + + +class ZeroRewardWrapper(gym.RewardWrapper): + def reward(self, reward): + return reward * 0 + + +class AlwaysDoneWrapper(gym.Wrapper): + # Pretends that environment only has single step for each + # episode. 
+ def __init__(self, env): + super(AlwaysDoneWrapper, self).__init__(env) + self.last_obs = None + self.needs_reset = True + + def step(self, action): + obs, reward, done, info = self.env.step(action) + self.needs_reset = done + self.last_obs = obs + return obs, reward, True, info + + def reset(self, **kwargs): + if self.needs_reset: + obs = self.env.reset(**kwargs) + self.last_obs = obs + self.needs_reset = False + return self.last_obs + + +@pytest.mark.parametrize("vec_env_class", [None, DummyVecEnv, SubprocVecEnv]) +def test_evaluate_policy_monitors(vec_env_class): + # Test that results are correct with monitor environments. + # Also test VecEnvs + n_eval_episodes = 2 + env_id = "CartPole-v0" + model = A2C("MlpPolicy", env_id, seed=0) + + def make_eval_env(with_monitor, wrapper_class=gym.Wrapper): + # Make eval environment with or without monitor in root, + # and additionally wrapped with another wrapper (after Monitor). + env = None + if vec_env_class is None: + # No vecenv, traditional env + env = gym.make(env_id) + if with_monitor: + env = Monitor(env) + env = wrapper_class(env) + else: + if with_monitor: + env = vec_env_class([lambda: wrapper_class(Monitor(gym.make(env_id)))]) + else: + env = vec_env_class([lambda: wrapper_class(gym.make(env_id))]) + return env + + # Test that evaluation with VecEnvs works as expected + eval_env = make_eval_env(with_monitor=True) + _ = evaluate_policy(model, eval_env, n_eval_episodes) + eval_env.close() + + # Warning without Monitor + eval_env = make_eval_env(with_monitor=False) + with pytest.warns(UserWarning): + _ = evaluate_policy(model, eval_env, n_eval_episodes) + eval_env.close() + + # Test that we gather correct reward with Monitor wrapper + # Sanity check that we get zero-reward without Monitor + eval_env = make_eval_env(with_monitor=False, wrapper_class=ZeroRewardWrapper) + average_reward, _ = evaluate_policy(model, eval_env, n_eval_episodes, warn=False) + assert average_reward == 0.0, "ZeroRewardWrapper wrapper for 
testing did not work" + eval_env.close() + + # Should get non-zero-rewards with Monitor (true reward) + eval_env = make_eval_env(with_monitor=True, wrapper_class=ZeroRewardWrapper) + average_reward, _ = evaluate_policy(model, eval_env, n_eval_episodes) + assert average_reward > 0.0, "evaluate_policy did not get reward from Monitor" + eval_env.close() + + # Test that we also track correct episode dones, not the wrapped ones. + # Sanity check that we get only one step per episode. + eval_env = make_eval_env(with_monitor=False, wrapper_class=AlwaysDoneWrapper) + episode_rewards, episode_lengths = evaluate_policy( + model, eval_env, n_eval_episodes, return_episode_rewards=True, warn=False + ) + assert all(map(lambda l: l == 1, episode_lengths)), "AlwaysDoneWrapper did not fix episode lengths to one" + eval_env.close() + + # Should get longer episodes with Monitor (true episodes) + eval_env = make_eval_env(with_monitor=True, wrapper_class=AlwaysDoneWrapper) + episode_rewards, episode_lengths = evaluate_policy(model, eval_env, n_eval_episodes, return_episode_rewards=True) + assert all(map(lambda l: l > 1, episode_lengths)), "evaluate_policy did not get episode lengths from Monitor" + eval_env.close() + def test_vec_noise(): num_envs = 4 @@ -196,3 +293,16 @@ def test_cmd_util_rename(): """Test that importing cmd_util still works but raises warning""" with pytest.warns(FutureWarning): from stable_baselines3.common.cmd_util import make_vec_env # noqa: F401 + + +def test_is_wrapped(): + """Test that is_wrapped correctly detects wraps""" + env = gym.make("Pendulum-v0") + env = gym.Wrapper(env) + assert not is_wrapped(env, Monitor) + monitor_env = Monitor(env) + assert is_wrapped(monitor_env, Monitor) + env = gym.Wrapper(monitor_env) + assert is_wrapped(env, Monitor) + # Test that unwrap works as expected + assert unwrap_wrapper(env, Monitor) == monitor_env diff --git a/tests/test_vec_envs.py b/tests/test_vec_envs.py index 5545ff4..9a4c118 100644 ---
a/tests/test_vec_envs.py +++ b/tests/test_vec_envs.py @@ -7,6 +7,7 @@ import gym import numpy as np import pytest +from stable_baselines3.common.monitor import Monitor from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecFrameStack, VecNormalize N_ENVS = 3 @@ -415,3 +416,27 @@ def test_framestack_vecenv(): # Test that it works with non-image envs when no channels_order is given vec_env = DummyVecEnv([make_non_image_env for _ in range(N_ENVS)]) vec_env = VecFrameStack(vec_env, n_stack=2) + + +def test_vec_env_is_wrapped(): + # Test is_wrapped call of subproc workers + def make_env(): + return CustomGymEnv(gym.spaces.Box(low=np.zeros(2), high=np.ones(2))) + + def make_monitored_env(): + return Monitor(CustomGymEnv(gym.spaces.Box(low=np.zeros(2), high=np.ones(2)))) + + # One with monitor, one without + vec_env = SubprocVecEnv([make_env, make_monitored_env]) + + assert vec_env.env_is_wrapped(Monitor) == [False, True] + + vec_env.close() + + # One with monitor, one without + vec_env = DummyVecEnv([make_env, make_monitored_env]) + + assert vec_env.env_is_wrapped(Monitor) == [False, True] + + vec_env = VecFrameStack(vec_env, n_stack=2) + assert vec_env.env_is_wrapped(Monitor) == [False, True] diff --git a/tests/test_vec_normalize.py b/tests/test_vec_normalize.py index a68e1b2..46c0c44 100644 --- a/tests/test_vec_normalize.py +++ b/tests/test_vec_normalize.py @@ -4,6 +4,7 @@ import pytest from gym import spaces from stable_baselines3 import HER, SAC, TD3 +from stable_baselines3.common.monitor import Monitor from stable_baselines3.common.running_mean_std import RunningMeanStd from stable_baselines3.common.vec_env import ( DummyVecEnv, @@ -61,11 +62,11 @@ def allclose(obs_1, obs_2): def make_env(): - return gym.make(ENV_ID) + return Monitor(gym.make(ENV_ID)) def make_dict_env(): - return DummyDictEnv() + return Monitor(DummyDictEnv()) def check_rms_equal(rmsa, rmsb):