mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-14 20:58:03 +00:00
Fix render bug for vec env wrappers (#1525)
* Fix render bug for vec env wrappers
* Fix tests and update changelog
* Better fix, backward compatible
* remove render_mode from VecEnv init
* Make DictObsVecEnv inherit from VecEnv
* format
* Fix env_is_wrapped
* try/except getting render mode (https://github.com/DLR-RM/stable-baselines3/pull/1525#discussion_r1206888921)
* update version
* Fix env_is_wrapped in test_vec_extract_dict
---------
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin GALLOUÉDEC <gallouedec.quentin@gmail.com>
This commit is contained in:
parent
32778ddc94
commit
ffe26ccf95
7 changed files with 52 additions and 21 deletions
|
|
@ -3,7 +3,7 @@
|
|||
Changelog
|
||||
==========
|
||||
|
||||
Release 2.0.0a12 (WIP)
|
||||
Release 2.0.0a13 (WIP)
|
||||
--------------------------
|
||||
|
||||
**Gymnasium support**
|
||||
|
|
@ -64,7 +64,7 @@ Others:
|
|||
- Updated env checker to reflect what subset of Gymnasium is supported and improve GoalEnv checks
|
||||
- Improve type annotation of wrappers
|
||||
- Tests envs are now checked too
|
||||
- Added render test for ``VecEnv``
|
||||
- Added render test for ``VecEnv`` and ``VecEnvWrapper``
|
||||
- Update issue templates and env info saved with the model
|
||||
- Changed ``seed()`` method return type from ``List`` to ``Sequence``
|
||||
- Updated env checker doc and requirements for tuple spaces/goal envs
|
||||
|
|
|
|||
|
|
@ -61,16 +61,24 @@ class VecEnv(ABC):
|
|||
num_envs: int,
|
||||
observation_space: spaces.Space,
|
||||
action_space: spaces.Space,
|
||||
render_mode: Optional[str] = None,
|
||||
):
|
||||
self.num_envs = num_envs
|
||||
self.observation_space = observation_space
|
||||
self.action_space = action_space
|
||||
self.render_mode = render_mode
|
||||
# store info returned by the reset method
|
||||
self.reset_infos: List[Dict[str, Any]] = [{} for _ in range(num_envs)]
|
||||
# seeds to be used in the next call to env.reset()
|
||||
self._seeds: List[Optional[int]] = [None for _ in range(num_envs)]
|
||||
try:
|
||||
render_modes = self.get_attr("render_mode")
|
||||
except AttributeError:
|
||||
warnings.warn("The `render_mode` attribute is not defined in your environment. It will be set to None.")
|
||||
render_modes = [None for _ in range(num_envs)]
|
||||
|
||||
assert all(
|
||||
render_mode == render_modes[0] for render_mode in render_modes
|
||||
), "render_mode mode should be the same for all environments"
|
||||
self.render_mode = render_modes[0]
|
||||
|
||||
def _reset_seeds(self) -> None:
|
||||
"""
|
||||
|
|
@ -313,15 +321,13 @@ class VecEnvWrapper(VecEnv):
|
|||
venv: VecEnv,
|
||||
observation_space: Optional[spaces.Space] = None,
|
||||
action_space: Optional[spaces.Space] = None,
|
||||
render_mode: Optional[str] = None,
|
||||
):
|
||||
self.venv = venv
|
||||
VecEnv.__init__(
|
||||
self,
|
||||
|
||||
super().__init__(
|
||||
num_envs=venv.num_envs,
|
||||
observation_space=observation_space or venv.observation_space,
|
||||
action_space=action_space or venv.action_space,
|
||||
render_mode=render_mode,
|
||||
)
|
||||
self.class_attributes = dict(inspect.getmembers(self.__class__))
|
||||
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ class DummyVecEnv(VecEnv):
|
|||
"Please read https://github.com/DLR-RM/stable-baselines3/issues/1151 for more information."
|
||||
)
|
||||
env = self.envs[0]
|
||||
VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space, env.render_mode)
|
||||
super().__init__(len(env_fns), env.observation_space, env.action_space)
|
||||
obs_space = env.observation_space
|
||||
self.keys, shapes, dtypes = obs_space_info(obs_space)
|
||||
|
||||
|
|
|
|||
|
|
@ -119,9 +119,7 @@ class SubprocVecEnv(VecEnv):
|
|||
self.remotes[0].send(("get_spaces", None))
|
||||
observation_space, action_space = self.remotes[0].recv()
|
||||
|
||||
self.remotes[0].send(("get_attr", "render_mode"))
|
||||
render_mode = self.remotes[0].recv()
|
||||
VecEnv.__init__(self, len(env_fns), observation_space, action_space, render_mode)
|
||||
super().__init__(len(env_fns), observation_space, action_space)
|
||||
|
||||
def step_async(self, actions: np.ndarray) -> None:
|
||||
for remote, action in zip(self.remotes, actions):
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
2.0.0a12
|
||||
2.0.0a13
|
||||
|
|
@ -586,4 +586,13 @@ def test_render(vec_env_class):
|
|||
for _ in range(10):
|
||||
vec_env.step([vec_env.action_space.sample() for _ in range(n_envs)])
|
||||
vec_env.render()
|
||||
|
||||
# Check that it still works with vec env wrapper
|
||||
vec_env = VecFrameStack(vec_env, 2)
|
||||
vec_env.render()
|
||||
assert vec_env.render_mode == "rgb_array"
|
||||
vec_env = VecNormalize(vec_env)
|
||||
assert vec_env.render_mode == "rgb_array"
|
||||
vec_env.render()
|
||||
|
||||
vec_env.close()
|
||||
|
|
|
|||
|
|
@ -2,13 +2,13 @@ import numpy as np
|
|||
from gymnasium import spaces
|
||||
|
||||
from stable_baselines3 import PPO
|
||||
from stable_baselines3.common.vec_env import VecExtractDictObs, VecMonitor
|
||||
from stable_baselines3.common.vec_env import VecEnv, VecExtractDictObs, VecMonitor
|
||||
|
||||
|
||||
class DictObsVecEnv:
|
||||
class DictObsVecEnv(VecEnv):
|
||||
"""Custom Environment that produces observation in a dictionary like the procgen env"""
|
||||
|
||||
metadata = {"render.modes": ["human"]}
|
||||
metadata = {"render_modes": ["human"]}
|
||||
|
||||
def __init__(self):
|
||||
self.num_envs = 4
|
||||
|
|
@ -16,6 +16,7 @@ class DictObsVecEnv:
|
|||
self.observation_space = spaces.Dict({"rgb": spaces.Box(low=0.0, high=255.0, shape=(86, 86), dtype=np.float32)})
|
||||
self.n_steps = 0
|
||||
self.max_steps = 5
|
||||
self.render_mode = None
|
||||
|
||||
def step_async(self, actions):
|
||||
self.actions = actions
|
||||
|
|
@ -25,25 +26,42 @@ class DictObsVecEnv:
|
|||
done = self.n_steps >= self.max_steps
|
||||
if done:
|
||||
infos = [
|
||||
{"terminal_observation": {"rgb": np.zeros((86, 86))}, "TimeLimit.truncated": True}
|
||||
{"terminal_observation": {"rgb": np.zeros((86, 86), dtype=np.float32)}, "TimeLimit.truncated": True}
|
||||
for _ in range(self.num_envs)
|
||||
]
|
||||
else:
|
||||
infos = []
|
||||
return (
|
||||
{"rgb": np.zeros((self.num_envs, 86, 86))},
|
||||
np.zeros((self.num_envs,)),
|
||||
{"rgb": np.zeros((self.num_envs, 86, 86), dtype=np.float32)},
|
||||
np.zeros((self.num_envs,), dtype=np.float32),
|
||||
np.ones((self.num_envs,), dtype=bool) * done,
|
||||
infos,
|
||||
)
|
||||
|
||||
def reset(self):
|
||||
self.n_steps = 0
|
||||
return {"rgb": np.zeros((self.num_envs, 86, 86))}
|
||||
return {"rgb": np.zeros((self.num_envs, 86, 86), dtype=np.float32)}
|
||||
|
||||
def render(self, close=False):
|
||||
def render(self, mode=""):
|
||||
pass
|
||||
|
||||
def get_attr(self, attr_name, indices=None):
|
||||
indices = range(self.num_envs) if indices is None else indices
|
||||
return [getattr(self, attr_name) for _ in indices]
|
||||
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
def env_is_wrapped(self, wrapper_class, indices=None):
|
||||
indices = range(self.num_envs) if indices is None else indices
|
||||
return [False for _ in indices]
|
||||
|
||||
def env_method(self):
|
||||
raise NotImplementedError # not used in the test
|
||||
|
||||
def set_attr(self, attr_name, value, indices=None) -> None:
|
||||
raise NotImplementedError # not used in the test
|
||||
|
||||
|
||||
def test_extract_dict_obs():
|
||||
"""Test VecExtractDictObs"""
|
||||
|
|
|
|||
Loading…
Reference in a new issue