Use Monitor episode reward/length for evaluate_policy (#220)

* Update evaluate_policy to use monitor data if available

* Update documentation

* Cleaning up

* Remove unnecessary typing trickery

* Update doc

* Rename is_wrapped to clarify it is for vecenvs

* Add is_wrapped for regular envs

* Add is_wrapped call for subprocvecenv and update code for circular imports

* Move new functions back to env_util and fix imports

* Update changelog

* Clarify evaluate_policy docs

* Add tests for wrapped modifying episode lengths

* Fix tests

* Update changelog

* Minor edits

* Add warn switch to evaluate_policy and update tests

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
This commit is contained in:
Anssi 2020-11-16 12:52:28 +02:00 committed by GitHub
parent c74509ae9d
commit 18d10dbf42
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
19 changed files with 305 additions and 34 deletions

View file

@ -79,6 +79,9 @@ In the following example, we will train, save and load a DQN model on the Lunar
model = DQN.load("dqn_lunar")
# Evaluate the agent
# NOTE: If you use wrappers with your environment that modify rewards,
# this will be reflected here. To evaluate with original rewards,
# wrap environment in a "Monitor" wrapper before other wrappers.
mean_reward, std_reward = evaluate_policy(model, model.get_env(), n_eval_episodes=10)
# Enjoy trained agent

View file

@ -17,7 +17,7 @@ TL;DR
1. Read about RL and Stable Baselines3
2. Do quantitative experiments and hyperparameter tuning if needed
3. Evaluate the performance using a separate test environment
3. Evaluate the performance using a separate test environment (remember to check wrappers!)
4. For better performance, increase the training budget
@ -68,18 +68,24 @@ Other method, like ``TRPO`` or ``PPO`` make use of a *trust region* to minimize
How to evaluate an RL algorithm?
--------------------------------
.. note::
Pay attention to environment wrappers when evaluating your agent and comparing results to others' results. Modifications to episode rewards
or lengths may also affect evaluation results, which may not be desirable. Check the ``evaluate_policy`` helper function in the :ref:`Evaluation Helper <eval>` section.
Because most algorithms use exploration noise during training, you need a separate test environment to evaluate the performance
of your agent at a given time. It is recommended to periodically evaluate your agent for ``n`` test episodes (``n`` is usually between 5 and 20)
and average the reward per episode to have a good estimate.
.. note::
We provide an ``EvalCallback`` for doing such evaluation. You can read more about it in the :ref:`Callbacks <callbacks>` section.
As some policies are stochastic by default (e.g. A2C or PPO), you should also try to set `deterministic=True` when calling the `.predict()` method;
this frequently leads to better performance.
Looking at the training curve (episode reward function of the timesteps) is a good proxy but underestimates the agent's true performance.
.. note::
We provide an ``EvalCallback`` for doing such evaluation. You can read more about it in the :ref:`Callbacks <callbacks>` section.

View file

@ -8,7 +8,10 @@ Pre-Release 0.11.0a0 (WIP)
Breaking Changes:
^^^^^^^^^^^^^^^^^
- ``evaluate_policy`` now returns rewards/episode lengths from a ``Monitor`` wrapper if one is present,
this makes it possible to return the unnormalized reward in the case of Atari games for instance.
- Renamed ``common.vec_env.is_wrapped`` to ``common.vec_env.is_vecenv_wrapped`` to avoid confusion
with the new ``is_wrapped()`` helper
New Features:
^^^^^^^^^^^^^
@ -16,6 +19,10 @@ New Features:
automatic check for image spaces.
- ``VecFrameStack`` now has a ``channels_order`` argument to tell if observations should be stacked
on the first or last observation dimension (originally always stacked on last).
- Added ``common.env_util.is_wrapped`` and ``common.env_util.unwrap_wrapper`` functions for checking/unwrapping
an environment for a specific wrapper.
- Added ``env_is_wrapped()`` method for ``VecEnv`` to check if its environments are wrapped
with given Gym wrappers.
Bug Fixes:
^^^^^^^^^^

View file

@ -31,7 +31,7 @@ from stable_baselines3.common.vec_env import (
VecEnv,
VecNormalize,
VecTransposeImage,
is_wrapped,
is_vecenv_wrapped,
unwrap_vec_normalize,
)
from stable_baselines3.common.vec_env.obs_dict_wrapper import ObsDictWrapper
@ -178,7 +178,7 @@ class BaseAlgorithm(ABC):
if (
is_image_space(env.observation_space)
and not is_wrapped(env, VecTransposeImage)
and not is_vecenv_wrapped(env, VecTransposeImage)
and not is_image_space_channels_first(env.observation_space)
):
if verbose >= 1:

View file

@ -276,6 +276,8 @@ class EvalCallback(EventCallback):
:param deterministic: Whether the evaluation should use stochastic or deterministic actions.
:param render: Whether to render or not the environment during evaluation
:param verbose:
:param warn: Passed to ``evaluate_policy`` (warns if ``eval_env`` has not been
wrapped with a Monitor wrapper)
"""
def __init__(
@ -289,6 +291,7 @@ class EvalCallback(EventCallback):
deterministic: bool = True,
render: bool = False,
verbose: int = 1,
warn: bool = True,
):
super(EvalCallback, self).__init__(callback_on_new_best, verbose=verbose)
self.n_eval_episodes = n_eval_episodes
@ -297,6 +300,7 @@ class EvalCallback(EventCallback):
self.last_mean_reward = -np.inf
self.deterministic = deterministic
self.render = render
self.warn = warn
# Convert to VecEnv for consistency
if not isinstance(eval_env, VecEnv):
@ -339,6 +343,7 @@ class EvalCallback(EventCallback):
render=self.render,
deterministic=self.deterministic,
return_episode_rewards=True,
warn=self.warn,
)
if self.log_path is not None:

View file

@ -8,6 +8,33 @@ from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecEnv
def unwrap_wrapper(env: gym.Env, wrapper_class: Type[gym.Wrapper]) -> Optional[gym.Wrapper]:
    """
    Walk down a chain of ``gym.Wrapper`` layers looking for a given wrapper class.

    :param env: Environment to unwrap
    :param wrapper_class: Wrapper class to look for
    :return: The outermost layer that is an instance of ``wrapper_class``,
        or ``None`` if no layer of the wrapper chain matches
    """
    current_layer = env
    # Only ``gym.Wrapper`` layers are inspected: the walk stops at the first
    # object that is not a wrapper, without testing it against ``wrapper_class``.
    while isinstance(current_layer, gym.Wrapper):
        if isinstance(current_layer, wrapper_class):
            return current_layer
        current_layer = current_layer.env
    return None
def is_wrapped(env: gym.Env, wrapper_class: Type[gym.Wrapper]) -> bool:
    """
    Check if a given environment has been wrapped with a given wrapper.

    :param env: Environment instance to check (not a class)
    :param wrapper_class: Wrapper class to look for
    :return: True if environment has been wrapped with ``wrapper_class``.
    """
    # ``unwrap_wrapper`` returns the matching wrapper layer, or None when there is none.
    return unwrap_wrapper(env, wrapper_class) is not None
def make_vec_env(
env_id: Union[str, Type[gym.Env]],
n_envs: int = 1,

View file

@ -1,3 +1,4 @@
import warnings
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import gym
@ -16,11 +17,20 @@ def evaluate_policy(
callback: Optional[Callable[[Dict[str, Any], Dict[str, Any]], None]] = None,
reward_threshold: Optional[float] = None,
return_episode_rewards: bool = False,
warn: bool = True,
) -> Union[Tuple[float, float], Tuple[List[float], List[int]]]:
"""
Runs policy for ``n_eval_episodes`` episodes and returns average reward.
This is made to work only with one env.
.. note::
If environment has not been wrapped with ``Monitor`` wrapper, reward and
episode lengths are counted as they appear with ``env.step`` calls. If
the environment contains wrappers that modify rewards or episode lengths
(e.g. reward scaling, early episode reset), these will affect the evaluation
results as well. You can avoid this by wrapping environment with ``Monitor``
wrapper before anything else.
:param model: The RL agent you want to evaluate.
:param env: The gym environment. In the case of a ``VecEnv``
this must contain only one environment.
@ -31,33 +41,70 @@ def evaluate_policy(
called after each step. Gets locals() and globals() passed as parameters.
:param reward_threshold: Minimum expected reward per episode,
this will raise an error if the performance is not met
:param return_episode_rewards: If True, a list of reward per episode
will be returned instead of the mean.
:return: Mean reward per episode, std of reward per episode
returns ([float], [int]) when ``return_episode_rewards`` is True
:param return_episode_rewards: If True, a list of rewards and episode lengths
per episode will be returned instead of the mean.
:param warn: If True (default), warns user about lack of a Monitor wrapper in the
evaluation environment.
:return: Mean reward per episode, std of reward per episode.
Returns ([float], [int]) when ``return_episode_rewards`` is True, first
list containing per-episode rewards and second containing per-episode lengths
(in number of steps).
"""
is_monitor_wrapped = False
# Avoid circular import
from stable_baselines3.common.env_util import is_wrapped
from stable_baselines3.common.monitor import Monitor
if isinstance(env, VecEnv):
assert env.num_envs == 1, "You must pass only one environment when using this function"
is_monitor_wrapped = env.env_is_wrapped(Monitor)[0]
else:
is_monitor_wrapped = is_wrapped(env, Monitor)
if not is_monitor_wrapped and warn:
warnings.warn(
"Evaluation environment is not wrapped with a ``Monitor`` wrapper. "
"This may result in reporting modified episode lengths and rewards, if other wrappers happen to modify these. "
"Consider wrapping environment first with ``Monitor`` wrapper.",
UserWarning,
)
episode_rewards, episode_lengths = [], []
for i in range(n_eval_episodes):
# Avoid double reset, as VecEnv are reset automatically
if not isinstance(env, VecEnv) or i == 0:
not_reseted = True
while len(episode_rewards) < n_eval_episodes:
# Number of loops here might differ from true episodes
# played, if underlying wrappers modify episode lengths.
# Avoid double reset, as VecEnv are reset automatically.
if not isinstance(env, VecEnv) or not_reseted:
obs = env.reset()
not_reseted = False
done, state = False, None
episode_reward = 0.0
episode_length = 0
while not done:
action, state = model.predict(obs, state=state, deterministic=deterministic)
obs, reward, done, _info = env.step(action)
obs, reward, done, info = env.step(action)
episode_reward += reward
if callback is not None:
callback(locals(), globals())
episode_length += 1
if render:
env.render()
episode_rewards.append(episode_reward)
episode_lengths.append(episode_length)
if is_monitor_wrapped:
# Do not trust "done" with episode endings.
# Remove vecenv stacking (if any)
if isinstance(env, VecEnv):
info = info[0]
if "episode" in info.keys():
# Monitor wrapper includes "episode" key in info if environment
# has been wrapped with it. Use those rewards instead.
episode_rewards.append(info["episode"]["r"])
episode_lengths.append(info["episode"]["l"])
else:
episode_rewards.append(episode_reward)
episode_lengths.append(episode_length)
mean_reward = np.mean(episode_rewards)
std_reward = np.std(episode_rewards)
if reward_threshold is not None:

View file

@ -6,10 +6,9 @@ import gym
import numpy as np
import torch as th
from stable_baselines3.common import callbacks
from stable_baselines3.common.vec_env import VecEnv
from stable_baselines3.common import callbacks, vec_env
GymEnv = Union[gym.Env, VecEnv]
GymEnv = Union[gym.Env, vec_env.VecEnv]
GymObs = Union[Tuple, Dict[str, Any], np.ndarray, int]
GymStepReturn = Tuple[GymObs, float, bool, Dict]
TensorDict = Dict[str, th.Tensor]

View file

@ -41,7 +41,7 @@ def unwrap_vec_normalize(env: Union["GymEnv", VecEnv]) -> Optional[VecNormalize]
return unwrap_vec_wrapper(env, VecNormalize) # pytype:disable=bad-return-type
def is_wrapped(env: Union["GymEnv", VecEnv], vec_wrapper_class: Type[VecEnvWrapper]) -> bool:
def is_vecenv_wrapped(env: Union["GymEnv", VecEnv], vec_wrapper_class: Type[VecEnvWrapper]) -> bool:
"""
Check if an environment is already wrapped by a given ``VecEnvWrapper``.

View file

@ -1,6 +1,6 @@
import inspect
from abc import ABC, abstractmethod
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Type, Union
import cloudpickle
import gym
@ -139,6 +139,19 @@ class VecEnv(ABC):
"""
raise NotImplementedError()
@abstractmethod
def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]:
    """
    Check if environments are wrapped with a given wrapper.

    :param wrapper_class: Gym wrapper class to look for
    :param indices: Indices of the envs to check
    :return: True if the env is wrapped, False otherwise, for each env queried.
    """
    raise NotImplementedError()
def step(self, actions: np.ndarray) -> VecEnvStepReturn:
"""
Step the environments with the given action
@ -280,6 +293,9 @@ class VecEnvWrapper(VecEnv):
def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]:
return self.venv.env_method(method_name, *method_args, indices=indices, **method_kwargs)
def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]:
    """Forward the wrapper check to the wrapped ``venv`` and return its answer unchanged."""
    wrapped_flags = self.venv.env_is_wrapped(wrapper_class, indices=indices)
    return wrapped_flags
def __getattr__(self, name: str) -> Any:
"""Find attribute from wrapped venv(s) if this wrapper does not have it.
Useful for accessing attributes from venvs which are wrapped with multiple wrappers

View file

@ -1,6 +1,6 @@
from collections import OrderedDict
from copy import deepcopy
from typing import Any, Callable, List, Optional, Sequence, Union
from typing import Any, Callable, List, Optional, Sequence, Type, Union
import gym
import numpy as np
@ -112,6 +112,14 @@ class DummyVecEnv(VecEnv):
target_envs = self._get_target_envs(indices)
return [getattr(env_i, method_name)(*method_args, **method_kwargs) for env_i in target_envs]
def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]:
    """
    Check if worker environments are wrapped with a given wrapper.

    :param wrapper_class: Gym wrapper class to look for
    :param indices: Indices of the envs to check
    :return: One boolean per selected env, True if that env is wrapped with ``wrapper_class``
    """
    selected_envs = self._get_target_envs(indices)
    # Imported lazily, inside the method, to avoid a circular import
    from stable_baselines3.common.env_util import is_wrapped

    return [is_wrapped(single_env, wrapper_class) for single_env in selected_envs]
def _get_target_envs(self, indices: VecEnvIndices) -> List[gym.Env]:
indices = self._get_indices(indices)
return [self.envs[i] for i in indices]

View file

@ -1,6 +1,6 @@
import multiprocessing as mp
from collections import OrderedDict
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union
from typing import Any, Callable, List, Optional, Sequence, Tuple, Type, Union
import gym
import numpy as np
@ -17,6 +17,9 @@ from stable_baselines3.common.vec_env.base_vec_env import (
def _worker(
remote: mp.connection.Connection, parent_remote: mp.connection.Connection, env_fn_wrapper: CloudpickleWrapper
) -> None:
# Import here to avoid a circular import
from stable_baselines3.common.env_util import is_wrapped
parent_remote.close()
env = env_fn_wrapper.var()
while True:
@ -49,6 +52,8 @@ def _worker(
remote.send(getattr(env, data))
elif cmd == "set_attr":
remote.send(setattr(env, data[0], data[1]))
elif cmd == "is_wrapped":
remote.send(is_wrapped(env, data))
else:
raise NotImplementedError(f"`{cmd}` is not implemented in the worker")
except EOFError:
@ -170,6 +175,13 @@ class SubprocVecEnv(VecEnv):
remote.send(("env_method", (method_name, method_args, method_kwargs)))
return [remote.recv() for remote in target_remotes]
def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]:
    """
    Check if worker environments are wrapped with a given wrapper.

    :param wrapper_class: Gym wrapper class to look for
    :param indices: Indices of the workers to query
    :return: The reply received from each queried worker, in the order of the target remotes
    """
    pipes = self._get_target_remotes(indices)
    request = ("is_wrapped", wrapper_class)
    # Send every request first, then collect the replies.
    for pipe in pipes:
        pipe.send(request)
    return [pipe.recv() for pipe in pipes]
def _get_target_remotes(self, indices: VecEnvIndices) -> List[Any]:
"""
Get the connection object needed to communicate with the wanted

View file

@ -33,7 +33,12 @@ def test_callbacks(tmp_path, model_class):
callback_on_best = StopTrainingOnRewardThreshold(reward_threshold=-1200, verbose=1)
eval_callback = EvalCallback(
eval_env, callback_on_new_best=callback_on_best, best_model_save_path=log_folder, log_path=log_folder, eval_freq=100
eval_env,
callback_on_new_best=callback_on_best,
best_model_save_path=log_folder,
log_path=log_folder,
eval_freq=100,
warn=False,
)
# Equivalent to the `checkpoint_callback`
# but here in an event-driven manner

View file

@ -10,7 +10,7 @@ from stable_baselines3 import A2C, DQN, PPO, SAC, TD3
from stable_baselines3.common.identity_env import FakeImageEnv
from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first
from stable_baselines3.common.utils import zip_strict
from stable_baselines3.common.vec_env import VecTransposeImage, is_wrapped
from stable_baselines3.common.vec_env import VecTransposeImage, is_vecenv_wrapped
@pytest.mark.parametrize("model_class", [A2C, PPO, SAC, TD3, DQN])
@ -29,7 +29,7 @@ def test_cnn(tmp_path, model_class):
model = model_class("CnnPolicy", env, **kwargs).learn(250)
# FakeImageEnv is channel last by default and should be wrapped
assert is_wrapped(model.get_env(), VecTransposeImage)
assert is_vecenv_wrapped(model.get_env(), VecTransposeImage)
obs = env.reset()
@ -194,7 +194,7 @@ def test_channel_first_env(tmp_path):
model = A2C("CnnPolicy", env, n_steps=100).learn(250)
assert not is_wrapped(model.get_env(), VecTransposeImage)
assert not is_vecenv_wrapped(model.get_env(), VecTransposeImage)
obs = env.reset()

View file

@ -25,7 +25,7 @@ def test_discrete(model_class, env):
model = model_class("MlpPolicy", env_, gamma=0.4, seed=1, **kwargs).learn(n_steps)
evaluate_policy(model, env_, n_eval_episodes=20, reward_threshold=90)
evaluate_policy(model, env_, n_eval_episodes=20, reward_threshold=90, warn=False)
obs = env.reset()
assert np.shape(model.predict(obs)[0]) == np.shape(obs)
@ -45,4 +45,4 @@ def test_continuous(model_class):
model = model_class("MlpPolicy", env, **kwargs).learn(n_steps)
evaluate_policy(model, env, n_eval_episodes=20, reward_threshold=90)
evaluate_policy(model, env, n_eval_episodes=20, reward_threshold=90, warn=False)

View file

@ -48,4 +48,4 @@ def test_identity_spaces(model_class, env):
model = model_class("MlpPolicy", env, gamma=0.5, seed=1, policy_kwargs=dict(net_arch=[64]))
model.learn(total_timesteps=500)
evaluate_policy(model, env, n_eval_episodes=5)
evaluate_policy(model, env, n_eval_episodes=5, warn=False)

View file

@ -8,7 +8,7 @@ import torch as th
from stable_baselines3 import A2C
from stable_baselines3.common.atari_wrappers import ClipRewardEnv
from stable_baselines3.common.env_util import make_atari_env, make_vec_env
from stable_baselines3.common.env_util import is_wrapped, make_atari_env, make_vec_env, unwrap_wrapper
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.noise import ActionNoise, OrnsteinUhlenbeckActionNoise, VectorizedActionNoise
@ -127,6 +127,103 @@ def test_evaluate_policy():
episode_rewards, _ = evaluate_policy(model, model.get_env(), n_eval_episodes, return_episode_rewards=True)
assert len(episode_rewards) == n_eval_episodes
# Test that warning is given about no monitor
eval_env = gym.make("Pendulum-v0")
with pytest.warns(UserWarning):
_ = evaluate_policy(model, eval_env, n_eval_episodes)
class ZeroRewardWrapper(gym.RewardWrapper):
    """Test helper: reward wrapper that multiplies every reward by zero."""

    def reward(self, reward):
        # Whatever the wrapped environment produced is multiplied by 0
        return reward * 0
class AlwaysDoneWrapper(gym.Wrapper):
    """
    Test helper that reports ``done=True`` after every step, so that each
    step looks like a complete single-step episode to the caller.

    The wrapped environment is only really reset once it has itself
    returned ``done``; otherwise ``reset`` hands back the last observation.
    """

    def __init__(self, env):
        super().__init__(env)
        self.last_obs = None
        self.needs_reset = True

    def step(self, action):
        obs, reward, real_done, info = self.env.step(action)
        # Keep track of the true episode state so that reset() knows
        # whether the wrapped environment actually has to be reset.
        self.needs_reset, self.last_obs = real_done, obs
        # Always report the episode as finished to the caller
        return obs, reward, True, info

    def reset(self, **kwargs):
        if not self.needs_reset:
            # The real episode is still running: resume from where it stopped
            return self.last_obs
        self.last_obs = self.env.reset(**kwargs)
        self.needs_reset = False
        return self.last_obs
@pytest.mark.parametrize("vec_env_class", [None, DummyVecEnv, SubprocVecEnv])
def test_evaluate_policy_monitors(vec_env_class):
    """
    Check that ``evaluate_policy`` reports rewards/lengths from the ``Monitor``
    wrapper when there is one, for plain envs as well as for VecEnvs.
    """
    n_eval_episodes = 2
    env_id = "CartPole-v0"
    model = A2C("MlpPolicy", env_id, seed=0)

    def make_eval_env(with_monitor, wrapper_class=gym.Wrapper):
        # Build the eval environment with or without a Monitor at the root,
        # always wrapped with ``wrapper_class`` on top (after the Monitor).
        def _make_single_env():
            base_env = gym.make(env_id)
            if with_monitor:
                base_env = Monitor(base_env)
            return wrapper_class(base_env)

        if vec_env_class is None:
            # No vecenv, traditional env
            return _make_single_env()
        return vec_env_class([_make_single_env])

    # Evaluation must run with a monitored env (plain or vectorized)
    eval_env = make_eval_env(with_monitor=True)
    _ = evaluate_policy(model, eval_env, n_eval_episodes)
    eval_env.close()

    # A warning is expected when there is no Monitor
    eval_env = make_eval_env(with_monitor=False)
    with pytest.warns(UserWarning):
        _ = evaluate_policy(model, eval_env, n_eval_episodes)
    eval_env.close()

    # Rewards: sanity check that the zero-reward wrapper is effective
    # when there is no Monitor to report the true reward
    eval_env = make_eval_env(with_monitor=False, wrapper_class=ZeroRewardWrapper)
    average_reward, _ = evaluate_policy(model, eval_env, n_eval_episodes, warn=False)
    assert average_reward == 0.0, "ZeroRewardWrapper wrapper for testing did not work"
    eval_env.close()

    # With a Monitor, the true (non-zero) reward should be reported
    eval_env = make_eval_env(with_monitor=True, wrapper_class=ZeroRewardWrapper)
    average_reward, _ = evaluate_policy(model, eval_env, n_eval_episodes)
    assert average_reward > 0.0, "evaluate_policy did not get reward from Monitor"
    eval_env.close()

    # Episode ends: sanity check that, without Monitor, every episode
    # is reported as being one step long
    eval_env = make_eval_env(with_monitor=False, wrapper_class=AlwaysDoneWrapper)
    _, episode_lengths = evaluate_policy(model, eval_env, n_eval_episodes, return_episode_rewards=True, warn=False)
    assert all(length == 1 for length in episode_lengths), "AlwaysDoneWrapper did not fix episode lengths to one"
    eval_env.close()

    # With a Monitor, the true (longer) episodes should be reported
    eval_env = make_eval_env(with_monitor=True, wrapper_class=AlwaysDoneWrapper)
    _, episode_lengths = evaluate_policy(model, eval_env, n_eval_episodes, return_episode_rewards=True)
    assert all(length > 1 for length in episode_lengths), "evaluate_policy did not get episode lengths from Monitor"
    eval_env.close()
def test_vec_noise():
num_envs = 4
@ -196,3 +293,16 @@ def test_cmd_util_rename():
"""Test that importing cmd_util still works but raises warning"""
with pytest.warns(FutureWarning):
from stable_baselines3.common.cmd_util import make_vec_env # noqa: F401
def test_is_wrapped():
    """Check ``is_wrapped`` and ``unwrap_wrapper`` on wrapper chains with and without a Monitor."""
    plain_wrapped = gym.Wrapper(gym.make("Pendulum-v0"))
    # No Monitor anywhere in the chain
    assert not is_wrapped(plain_wrapped, Monitor)

    # Monitor as the outermost layer
    monitor_env = Monitor(plain_wrapped)
    assert is_wrapped(monitor_env, Monitor)

    # Monitor hidden below another wrapper
    outer_env = gym.Wrapper(monitor_env)
    assert is_wrapped(outer_env, Monitor)

    # Unwrapping must give back the Monitor layer itself
    assert unwrap_wrapper(outer_env, Monitor) == monitor_env

View file

@ -7,6 +7,7 @@ import gym
import numpy as np
import pytest
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecFrameStack, VecNormalize
N_ENVS = 3
@ -415,3 +416,27 @@ def test_framestack_vecenv():
# Test that it works with non-image envs when no channels_order is given
vec_env = DummyVecEnv([make_non_image_env for _ in range(N_ENVS)])
vec_env = VecFrameStack(vec_env, n_stack=2)
def test_vec_env_is_wrapped():
    """
    Check ``env_is_wrapped`` for ``SubprocVecEnv``, ``DummyVecEnv`` and through
    a ``VecEnvWrapper`` (``VecFrameStack``), using one monitored and one
    non-monitored environment.
    """

    def make_env():
        return CustomGymEnv(gym.spaces.Box(low=np.zeros(2), high=np.ones(2)))

    def make_monitored_env():
        return Monitor(CustomGymEnv(gym.spaces.Box(low=np.zeros(2), high=np.ones(2))))

    # One with monitor, one without (is_wrapped call of subproc workers)
    vec_env = SubprocVecEnv([make_env, make_monitored_env])
    try:
        assert vec_env.env_is_wrapped(Monitor) == [False, True]
    finally:
        # Always close, so that a failing assertion does not leave
        # the worker processes running.
        vec_env.close()

    # One with monitor, one without
    vec_env = DummyVecEnv([make_env, make_monitored_env])
    try:
        assert vec_env.env_is_wrapped(Monitor) == [False, True]

        # The check must also go through VecEnv wrappers
        vec_env = VecFrameStack(vec_env, n_stack=2)
        assert vec_env.env_is_wrapped(Monitor) == [False, True]
    finally:
        vec_env.close()

View file

@ -4,6 +4,7 @@ import pytest
from gym import spaces
from stable_baselines3 import HER, SAC, TD3
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.running_mean_std import RunningMeanStd
from stable_baselines3.common.vec_env import (
DummyVecEnv,
@ -61,11 +62,11 @@ def allclose(obs_1, obs_2):
def make_env():
return gym.make(ENV_ID)
return Monitor(gym.make(ENV_ID))
def make_dict_env():
return DummyDictEnv()
return Monitor(DummyDictEnv())
def check_rms_equal(rmsa, rmsb):