Fix render bug for vec env wrappers (#1525)

* Fix render bug for vec env wrappers

* Fix tests and update changelog

* Better fix, backward compatible

* remove render_mode from VecEnv init

* Make DictObsVecEnv inherit from VecEnv

* format

* Fix env_is_wrapped

* try/except getting render mode ( (https://github.com/DLR-RM/stable-baselines3/pull/1525#discussion_r1206888921)

* update version

* Fix env_is_wrapped in test_vec_extract_dict

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin GALLOUÉDEC <gallouedec.quentin@gmail.com>
This commit is contained in:
Antonin RAFFIN 2023-06-07 16:20:40 +02:00 committed by GitHub
parent 32778ddc94
commit ffe26ccf95
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 52 additions and 21 deletions

View file

@ -3,7 +3,7 @@
Changelog
==========
Release 2.0.0a12 (WIP)
Release 2.0.0a13 (WIP)
--------------------------
**Gymnasium support**
@ -64,7 +64,7 @@ Others:
- Updated env checker to reflect what subset of Gymnasium is supported and improve GoalEnv checks
- Improve type annotation of wrappers
- Tests envs are now checked too
- Added render test for ``VecEnv``
- Added render test for ``VecEnv`` and ``VecEnvWrapper``
- Update issue templates and env info saved with the model
- Changed ``seed()`` method return type from ``List`` to ``Sequence``
- Updated env checker doc and requirements for tuple spaces/goal envs

View file

@ -61,16 +61,24 @@ class VecEnv(ABC):
num_envs: int,
observation_space: spaces.Space,
action_space: spaces.Space,
render_mode: Optional[str] = None,
):
self.num_envs = num_envs
self.observation_space = observation_space
self.action_space = action_space
self.render_mode = render_mode
# store info returned by the reset method
self.reset_infos: List[Dict[str, Any]] = [{} for _ in range(num_envs)]
# seeds to be used in the next call to env.reset()
self._seeds: List[Optional[int]] = [None for _ in range(num_envs)]
try:
render_modes = self.get_attr("render_mode")
except AttributeError:
warnings.warn("The `render_mode` attribute is not defined in your environment. It will be set to None.")
render_modes = [None for _ in range(num_envs)]
assert all(
render_mode == render_modes[0] for render_mode in render_modes
), "render_mode mode should be the same for all environments"
self.render_mode = render_modes[0]
def _reset_seeds(self) -> None:
"""
@ -313,15 +321,13 @@ class VecEnvWrapper(VecEnv):
venv: VecEnv,
observation_space: Optional[spaces.Space] = None,
action_space: Optional[spaces.Space] = None,
render_mode: Optional[str] = None,
):
self.venv = venv
VecEnv.__init__(
self,
super().__init__(
num_envs=venv.num_envs,
observation_space=observation_space or venv.observation_space,
action_space=action_space or venv.action_space,
render_mode=render_mode,
)
self.class_attributes = dict(inspect.getmembers(self.__class__))

View file

@ -39,7 +39,7 @@ class DummyVecEnv(VecEnv):
"Please read https://github.com/DLR-RM/stable-baselines3/issues/1151 for more information."
)
env = self.envs[0]
VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space, env.render_mode)
super().__init__(len(env_fns), env.observation_space, env.action_space)
obs_space = env.observation_space
self.keys, shapes, dtypes = obs_space_info(obs_space)

View file

@ -119,9 +119,7 @@ class SubprocVecEnv(VecEnv):
self.remotes[0].send(("get_spaces", None))
observation_space, action_space = self.remotes[0].recv()
self.remotes[0].send(("get_attr", "render_mode"))
render_mode = self.remotes[0].recv()
VecEnv.__init__(self, len(env_fns), observation_space, action_space, render_mode)
super().__init__(len(env_fns), observation_space, action_space)
def step_async(self, actions: np.ndarray) -> None:
for remote, action in zip(self.remotes, actions):

View file

@ -1 +1 @@
2.0.0a12
2.0.0a13

View file

@ -586,4 +586,13 @@ def test_render(vec_env_class):
for _ in range(10):
vec_env.step([vec_env.action_space.sample() for _ in range(n_envs)])
vec_env.render()
# Check that it still works with vec env wrapper
vec_env = VecFrameStack(vec_env, 2)
vec_env.render()
assert vec_env.render_mode == "rgb_array"
vec_env = VecNormalize(vec_env)
assert vec_env.render_mode == "rgb_array"
vec_env.render()
vec_env.close()

View file

@ -2,13 +2,13 @@ import numpy as np
from gymnasium import spaces
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import VecExtractDictObs, VecMonitor
from stable_baselines3.common.vec_env import VecEnv, VecExtractDictObs, VecMonitor
class DictObsVecEnv:
class DictObsVecEnv(VecEnv):
"""Custom Environment that produces observation in a dictionary like the procgen env"""
metadata = {"render.modes": ["human"]}
metadata = {"render_modes": ["human"]}
def __init__(self):
self.num_envs = 4
@ -16,6 +16,7 @@ class DictObsVecEnv:
self.observation_space = spaces.Dict({"rgb": spaces.Box(low=0.0, high=255.0, shape=(86, 86), dtype=np.float32)})
self.n_steps = 0
self.max_steps = 5
self.render_mode = None
def step_async(self, actions):
self.actions = actions
@ -25,25 +26,42 @@ class DictObsVecEnv:
done = self.n_steps >= self.max_steps
if done:
infos = [
{"terminal_observation": {"rgb": np.zeros((86, 86))}, "TimeLimit.truncated": True}
{"terminal_observation": {"rgb": np.zeros((86, 86), dtype=np.float32)}, "TimeLimit.truncated": True}
for _ in range(self.num_envs)
]
else:
infos = []
return (
{"rgb": np.zeros((self.num_envs, 86, 86))},
np.zeros((self.num_envs,)),
{"rgb": np.zeros((self.num_envs, 86, 86), dtype=np.float32)},
np.zeros((self.num_envs,), dtype=np.float32),
np.ones((self.num_envs,), dtype=bool) * done,
infos,
)
def reset(self):
self.n_steps = 0
return {"rgb": np.zeros((self.num_envs, 86, 86))}
return {"rgb": np.zeros((self.num_envs, 86, 86), dtype=np.float32)}
def render(self, close=False):
def render(self, mode=""):
pass
def get_attr(self, attr_name, indices=None):
indices = range(self.num_envs) if indices is None else indices
return [getattr(self, attr_name) for _ in indices]
def close(self):
pass
def env_is_wrapped(self, wrapper_class, indices=None):
indices = range(self.num_envs) if indices is None else indices
return [False for _ in indices]
def env_method(self):
raise NotImplementedError # not used in the test
def set_attr(self, attr_name, value, indices=None) -> None:
raise NotImplementedError # not used in the test
def test_extract_dict_obs():
"""Test VecExtractDictObs"""