mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-16 21:10:08 +00:00
* Fix render bug for vec env wrappers * Fix tests and update changelog * Better fix, backward compatible * remove render_mode from VecEnv init * Make DictObsVecEnv inherit from VecEnv * format * Fix env_is_wrapped * try/except getting render mode ( (https://github.com/DLR-RM/stable-baselines3/pull/1525#discussion_r1206888921) * update version * Fix env_is_wrapped in test_vec_extract_dict --------- Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> Co-authored-by: Quentin GALLOUÉDEC <gallouedec.quentin@gmail.com>
92 lines
2.9 KiB
Python
92 lines
2.9 KiB
Python
import numpy as np
|
|
from gymnasium import spaces
|
|
|
|
from stable_baselines3 import PPO
|
|
from stable_baselines3.common.vec_env import VecEnv, VecExtractDictObs, VecMonitor
|
|
|
|
|
|
class DictObsVecEnv(VecEnv):
    """Custom Environment that produces observation in a dictionary like the procgen env

    Minimal vectorized stub: 4 sub-environments, a ``Discrete(2)`` action space
    and a ``Dict`` observation space holding a single ``"rgb"`` entry of shape
    ``(86, 86)``. Observations and rewards are always zero, and all
    sub-environments finish together (flagged as truncated) every
    ``max_steps`` steps.
    """

    metadata = {"render_modes": ["human"]}

    def __init__(self):
        # ``VecEnv.__init__`` is not called here: the attributes used by the
        # wrappers and tests in this file are assigned directly.
        self.num_envs = 4
        self.action_space = spaces.Discrete(2)
        self.observation_space = spaces.Dict({"rgb": spaces.Box(low=0.0, high=255.0, shape=(86, 86), dtype=np.float32)})
        self.n_steps = 0
        self.max_steps = 5
        self.render_mode = None

    def _target_indices(self, indices):
        """Normalize ``indices`` (``None``, a single int, or an iterable of ints) to an iterable."""
        if indices is None:
            return range(self.num_envs)
        if isinstance(indices, int):
            # A bare int designates one sub-environment and is not iterable itself
            return [indices]
        return indices

    def step_async(self, actions):
        """Store the actions; they do not influence the (constant) outputs."""
        self.actions = actions

    def step_wait(self):
        """Advance every sub-environment by one step.

        :return: ``(obs, rewards, dones, infos)``: a dict with a zero ``"rgb"``
            batch of shape ``(num_envs, 86, 86)``, zero float32 rewards, a
            boolean array that is all ``True`` on the ``max_steps``-th step of
            an episode, and a list with one info dict per sub-environment.
        """
        self.n_steps += 1
        done = self.n_steps >= self.max_steps
        if done:
            infos = [
                {"terminal_observation": {"rgb": np.zeros((86, 86), dtype=np.float32)}, "TimeLimit.truncated": True}
                for _ in range(self.num_envs)
            ]
            # Start a new episode right away; otherwise the counter stays above
            # ``max_steps`` and every later step is reported as terminal.
            self.n_steps = 0
        else:
            # Always hand back one info dict per sub-environment so consumers
            # can index ``infos`` by env index on non-terminal steps too.
            infos = [{} for _ in range(self.num_envs)]
        return (
            {"rgb": np.zeros((self.num_envs, 86, 86), dtype=np.float32)},
            np.zeros((self.num_envs,), dtype=np.float32),
            np.full((self.num_envs,), done, dtype=bool),
            infos,
        )

    def reset(self):
        """Restart the step counter and return a zero ``"rgb"`` observation batch."""
        self.n_steps = 0
        return {"rgb": np.zeros((self.num_envs, 86, 86), dtype=np.float32)}

    def render(self, mode=""):
        """No-op: this stub has nothing to display."""
        pass

    def get_attr(self, attr_name, indices=None):
        """Return this object's ``attr_name`` once per selected sub-environment.

        There are no real sub-environments, so the attribute is read from the
        vectorized env itself.
        """
        return [getattr(self, attr_name) for _ in self._target_indices(indices)]

    def close(self):
        """No-op: no resources are held."""
        pass

    def env_is_wrapped(self, wrapper_class, indices=None):
        """Report ``False`` for every selected sub-environment (nothing is wrapped)."""
        return [False for _ in self._target_indices(indices)]

    def env_method(self, method_name=None, *method_args, indices=None, **method_kwargs):
        """Not supported by this stub.

        Accepts the arguments of ``VecEnv.env_method`` so that such a call
        reaches the ``NotImplementedError`` instead of failing on the argument
        count; calling it without arguments still works as before.
        """
        raise NotImplementedError  # not used in the test

    def set_attr(self, attr_name, value, indices=None) -> None:
        """Not supported by this stub."""
        raise NotImplementedError  # not used in the test
|
|
|
|
|
|
def test_extract_dict_obs():
    """Check that VecExtractDictObs returns only the "rgb" entry of the dict observation"""
    batch_shape = (4, 86, 86)
    vec_env = VecExtractDictObs(DictObsVecEnv(), "rgb")

    # The reset observation must already be the plain array, not the dict
    initial_obs = vec_env.reset()
    assert initial_obs.shape == batch_shape

    for _ in range(10):
        sampled_actions = [vec_env.action_space.sample() for _ in range(vec_env.num_envs)]
        obs, _, dones, infos = vec_env.step(sampled_actions)
        assert obs.shape == batch_shape

        for env_idx, info in enumerate(infos):
            if "terminal_observation" not in info:
                # No terminal observation means the episode is still running
                assert not dones[env_idx]
                continue
            # The terminal observation must have been extracted from its dict too
            assert dones[env_idx]
            assert info["terminal_observation"].shape == (86, 86)
|
|
|
|
|
|
def test_vec_with_ppo():
    """
    Smoke test: train PPO for a few steps on a monitored `VecExtractDictObs` env
    """
    extracted_env = VecExtractDictObs(DictObsVecEnv(), "rgb")
    # Wrap with VecMonitor so episode statistics are recorded during training
    wrapped_env = VecMonitor(extracted_env)
    agent = PPO(
        "MlpPolicy",
        wrapped_env,
        verbose=1,
        n_steps=64,
        device="cpu",
    )
    agent.learn(total_timesteps=250)
|