stable-baselines3/tests/test_vec_extract_dict_obs.py
Antonin RAFFIN ffe26ccf95
Fix render bug for vec env wrappers (#1525)
* Fix render bug for vec env wrappers

* Fix tests and update changelog

* Better fix, backward compatible

* remove render_mode from VecEnv init

* Make DictObsVecEnv inherit from VecEnv

* format

* Fix env_is_wrapped

* try/except getting render mode (https://github.com/DLR-RM/stable-baselines3/pull/1525#discussion_r1206888921)

* update version

* Fix env_is_wrapped in test_vec_extract_dict

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin GALLOUÉDEC <gallouedec.quentin@gmail.com>
2023-06-07 16:20:40 +02:00

92 lines
2.9 KiB
Python

import numpy as np
from gymnasium import spaces
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import VecEnv, VecExtractDictObs, VecMonitor
class DictObsVecEnv(VecEnv):
    """Custom Environment that produces observation in a dictionary like the procgen env.

    A minimal hand-written ``VecEnv`` with 4 sub-environments. Observations are
    ``{"rgb": float32 zeros of shape (num_envs, 86, 86)}`` and rewards are always 0.
    Every ``max_steps`` calls to ``step_wait`` all sub-environments report ``done``
    together. The step counter is then reset (emulating the automatic reset of a
    vectorized env), and each info dict carries a ``"terminal_observation"`` and
    ``"TimeLimit.truncated": True``.
    """

    metadata = {"render_modes": ["human"]}

    def __init__(self):
        # Attributes are assigned directly instead of going through VecEnv.__init__.
        self.num_envs = 4
        self.action_space = spaces.Discrete(2)
        self.observation_space = spaces.Dict({"rgb": spaces.Box(low=0.0, high=255.0, shape=(86, 86), dtype=np.float32)})
        self.n_steps = 0
        self.max_steps = 5
        self.render_mode = None
        # Last actions passed to step_async (stored but otherwise unused).
        self.actions = None

    def _batch_obs(self):
        """Return an all-zeros batched observation dict for every sub-environment."""
        return {"rgb": np.zeros((self.num_envs, 86, 86), dtype=np.float32)}

    def step_async(self, actions):
        """Store the actions; they do not influence the produced transitions."""
        self.actions = actions

    def step_wait(self):
        """Advance one step for all sub-environments.

        :return: ``(obs_dict, rewards, dones, infos)`` where ``infos`` always has
            exactly ``num_envs`` entries.
        """
        self.n_steps += 1
        done = self.n_steps >= self.max_steps
        if done:
            # Emulate the VecEnv auto-reset: start a new episode and expose the
            # final observation of the finished one through the info dict.
            self.n_steps = 0
            infos = [
                {"terminal_observation": {"rgb": np.zeros((86, 86), dtype=np.float32)}, "TimeLimit.truncated": True}
                for _ in range(self.num_envs)
            ]
        else:
            # The VecEnv API expects one info dict per sub-environment, even if empty.
            infos = [{} for _ in range(self.num_envs)]
        return (
            self._batch_obs(),
            np.zeros((self.num_envs,), dtype=np.float32),
            np.full((self.num_envs,), done, dtype=bool),
            infos,
        )

    def reset(self):
        """Restart the episode counter and return the initial batched observation dict."""
        self.n_steps = 0
        return self._batch_obs()

    def render(self, mode=""):
        """Rendering is a no-op for this test environment."""
        pass

    def get_attr(self, attr_name, indices=None):
        """Return this object's ``attr_name`` once per requested sub-environment index."""
        indices = range(self.num_envs) if indices is None else indices
        return [getattr(self, attr_name) for _ in indices]

    def close(self):
        """Nothing to release."""
        pass

    def env_is_wrapped(self, wrapper_class, indices=None):
        """Report that no sub-environment is wrapped."""
        indices = range(self.num_envs) if indices is None else indices
        return [False for _ in indices]

    def env_method(self, *args, **kwargs):
        """Not supported; accepts any arguments so callers get NotImplementedError, not TypeError."""
        raise NotImplementedError  # not used in the test

    def set_attr(self, attr_name, value, indices=None) -> None:
        """Not supported by this test environment."""
        raise NotImplementedError  # not used in the test
def test_extract_dict_obs():
    """Check that VecExtractDictObs unwraps the "rgb" entry of step and terminal observations."""
    vec_env = VecExtractDictObs(DictObsVecEnv(), "rgb")
    batch_shape = (4, 86, 86)

    assert vec_env.reset().shape == batch_shape

    for _step in range(10):
        batch_actions = [vec_env.action_space.sample() for _env_idx in range(vec_env.num_envs)]
        observations, _rewards, done_flags, info_list = vec_env.step(batch_actions)
        assert observations.shape == batch_shape

        for env_idx, env_info in enumerate(info_list):
            # Non-terminal transitions must not be flagged as done.
            if "terminal_observation" not in env_info:
                assert not done_flags[env_idx]
                continue
            # Terminal transitions carry the single-env (unbatched) extracted observation.
            assert done_flags[env_idx]
            assert env_info["terminal_observation"].shape == (86, 86)
def test_vec_with_ppo():
    """
    Smoke-test PPO training on DictObsVecEnv after extracting the "rgb" key
    with `VecExtractDictObs` and adding episode statistics via `VecMonitor`.
    """
    monitored_env = VecMonitor(VecExtractDictObs(DictObsVecEnv(), "rgb"))
    ppo_model = PPO("MlpPolicy", monitored_env, verbose=1, n_steps=64, device="cpu")
    ppo_model.learn(total_timesteps=250)