diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index a3116c2..8c40ef8 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,7 +3,7 @@ Changelog ========== -Release 2.0.0a12 (WIP) +Release 2.0.0a13 (WIP) -------------------------- **Gymnasium support** @@ -64,7 +64,7 @@ Others: - Updated env checker to reflect what subset of Gymnasium is supported and improve GoalEnv checks - Improve type annotation of wrappers - Tests envs are now checked too -- Added render test for ``VecEnv`` +- Added render test for ``VecEnv`` and ``VecEnvWrapper`` - Update issue templates and env info saved with the model - Changed ``seed()`` method return type from ``List`` to ``Sequence`` - Updated env checker doc and requirements for tuple spaces/goal envs diff --git a/stable_baselines3/common/vec_env/base_vec_env.py b/stable_baselines3/common/vec_env/base_vec_env.py index 7d7cfc2..0c93f63 100644 --- a/stable_baselines3/common/vec_env/base_vec_env.py +++ b/stable_baselines3/common/vec_env/base_vec_env.py @@ -61,16 +61,24 @@ class VecEnv(ABC): num_envs: int, observation_space: spaces.Space, action_space: spaces.Space, - render_mode: Optional[str] = None, ): self.num_envs = num_envs self.observation_space = observation_space self.action_space = action_space - self.render_mode = render_mode # store info returned by the reset method self.reset_infos: List[Dict[str, Any]] = [{} for _ in range(num_envs)] # seeds to be used in the next call to env.reset() self._seeds: List[Optional[int]] = [None for _ in range(num_envs)] + try: + render_modes = self.get_attr("render_mode") + except AttributeError: + warnings.warn("The `render_mode` attribute is not defined in your environment. It will be set to None.") + render_modes = [None for _ in range(num_envs)] + + assert all( + render_mode == render_modes[0] for render_mode in render_modes + ), "render_mode mode should be the same for all environments" + self.render_mode = render_modes[0] def _reset_seeds(self) -> None: """ @@ -313,15 +321,13 @@ class VecEnvWrapper(VecEnv): venv: VecEnv, observation_space: Optional[spaces.Space] = None, action_space: Optional[spaces.Space] = None, - render_mode: Optional[str] = None, ): self.venv = venv - VecEnv.__init__( - self, + + super().__init__( num_envs=venv.num_envs, observation_space=observation_space or venv.observation_space, action_space=action_space or venv.action_space, - render_mode=render_mode, ) self.class_attributes = dict(inspect.getmembers(self.__class__)) diff --git a/stable_baselines3/common/vec_env/dummy_vec_env.py b/stable_baselines3/common/vec_env/dummy_vec_env.py index 29b4d63..2908981 100644 --- a/stable_baselines3/common/vec_env/dummy_vec_env.py +++ b/stable_baselines3/common/vec_env/dummy_vec_env.py @@ -39,7 +39,7 @@ class DummyVecEnv(VecEnv): "Please read https://github.com/DLR-RM/stable-baselines3/issues/1151 for more information." ) env = self.envs[0] - VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space, env.render_mode) + super().__init__(len(env_fns), env.observation_space, env.action_space) obs_space = env.observation_space self.keys, shapes, dtypes = obs_space_info(obs_space) diff --git a/stable_baselines3/common/vec_env/subproc_vec_env.py b/stable_baselines3/common/vec_env/subproc_vec_env.py index 4d46954..cc8ffdb 100644 --- a/stable_baselines3/common/vec_env/subproc_vec_env.py +++ b/stable_baselines3/common/vec_env/subproc_vec_env.py @@ -119,9 +119,7 @@ class SubprocVecEnv(VecEnv): self.remotes[0].send(("get_spaces", None)) observation_space, action_space = self.remotes[0].recv() - self.remotes[0].send(("get_attr", "render_mode")) - render_mode = self.remotes[0].recv() - VecEnv.__init__(self, len(env_fns), observation_space, action_space, render_mode) + super().__init__(len(env_fns), observation_space, action_space) def step_async(self, actions: np.ndarray) -> None: for remote, action in zip(self.remotes, actions): diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index d8aacfa..a80d3c5 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -2.0.0a12 +2.0.0a13 \ No newline at end of file diff --git a/tests/test_vec_envs.py b/tests/test_vec_envs.py index 36d848a..61740c4 100644 --- a/tests/test_vec_envs.py +++ b/tests/test_vec_envs.py @@ -586,4 +586,13 @@ def test_render(vec_env_class): for _ in range(10): vec_env.step([vec_env.action_space.sample() for _ in range(n_envs)]) vec_env.render() + + # Check that it still works with vec env wrapper + vec_env = VecFrameStack(vec_env, 2) + vec_env.render() + assert vec_env.render_mode == "rgb_array" + vec_env = VecNormalize(vec_env) + assert vec_env.render_mode == "rgb_array" + vec_env.render() + vec_env.close() diff --git a/tests/test_vec_extract_dict_obs.py b/tests/test_vec_extract_dict_obs.py index 8c8dccd..e7be7e4 100644 --- a/tests/test_vec_extract_dict_obs.py +++ b/tests/test_vec_extract_dict_obs.py @@ -2,13 +2,13 @@ import numpy as np from gymnasium import spaces from stable_baselines3 import PPO -from stable_baselines3.common.vec_env import VecExtractDictObs, VecMonitor +from stable_baselines3.common.vec_env import VecEnv, VecExtractDictObs, VecMonitor -class DictObsVecEnv: +class DictObsVecEnv(VecEnv): """Custom Environment that produces observation in a dictionary like the procgen env""" - metadata = {"render.modes": ["human"]} + metadata = {"render_modes": ["human"]} def __init__(self): self.num_envs = 4 @@ -16,6 +16,7 @@ class DictObsVecEnv: self.observation_space = spaces.Dict({"rgb": spaces.Box(low=0.0, high=255.0, shape=(86, 86), dtype=np.float32)}) self.n_steps = 0 self.max_steps = 5 + self.render_mode = None def step_async(self, actions): self.actions = actions @@ -25,25 +26,42 @@ class DictObsVecEnv: done = self.n_steps >= self.max_steps if done: infos = [ - {"terminal_observation": {"rgb": np.zeros((86, 86))}, "TimeLimit.truncated": True} + {"terminal_observation": {"rgb": np.zeros((86, 86), dtype=np.float32)}, "TimeLimit.truncated": True} for _ in range(self.num_envs) ] else: infos = [] return ( - {"rgb": np.zeros((self.num_envs, 86, 86))}, - np.zeros((self.num_envs,)), + {"rgb": np.zeros((self.num_envs, 86, 86), dtype=np.float32)}, + np.zeros((self.num_envs,), dtype=np.float32), np.ones((self.num_envs,), dtype=bool) * done, infos, ) def reset(self): self.n_steps = 0 - return {"rgb": np.zeros((self.num_envs, 86, 86))} + return {"rgb": np.zeros((self.num_envs, 86, 86), dtype=np.float32)} - def render(self, close=False): + def render(self, mode=""): pass + def get_attr(self, attr_name, indices=None): + indices = range(self.num_envs) if indices is None else indices + return [getattr(self, attr_name) for _ in indices] + + def close(self): + pass + + def env_is_wrapped(self, wrapper_class, indices=None): + indices = range(self.num_envs) if indices is None else indices + return [False for _ in indices] + + def env_method(self): + raise NotImplementedError # not used in the test + + def set_attr(self, attr_name, value, indices=None) -> None: + raise NotImplementedError # not used in the test + def test_extract_dict_obs(): """Test VecExtractDictObs"""