mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-16 21:10:08 +00:00
* add vectorized monitor * auto format of the code * add documentation and VecExtractDictObs * refactor and add test cases * add test cases and format * avoid circular import and fix doc * fix type * fix type * oops * Update stable_baselines3/common/monitor.py Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org> * Update stable_baselines3/common/monitor.py Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org> * add test cases * update changelog * fix mutable argument * quick fix * Apply suggestions from code review * fix terminal observation for gym3 envs * delete comment * Update doc and bump version * Add warning when already using `Monitor` wrapper * Update vecmonitor tests * Fixes Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
52 lines
1.4 KiB
Python
52 lines
1.4 KiB
Python
import numpy as np
|
|
from gym import spaces
|
|
|
|
from stable_baselines3 import PPO
|
|
from stable_baselines3.common.vec_env import VecExtractDictObs, VecMonitor
|
|
|
|
|
|
class DictObsVecEnv:
    """Custom Environment that produces observation in a dictionary like the procgen env.

    Minimal vectorized-env stand-in used by the tests in this file: it exposes
    ``num_envs``, ``action_space`` and ``observation_space`` and implements
    ``step_async`` / ``step_wait`` / ``reset`` / ``render``.
    Every observation is an all-zero array stored under the ``"rgb"`` key,
    rewards are always zero and no episode ever terminates.
    """

    metadata = {"render.modes": ["human"]}

    # Shape of a single (per-environment) observation.
    _OBS_SHAPE = (86, 86)

    def __init__(self):
        self.num_envs = 4
        self.action_space = spaces.Discrete(2)
        self.observation_space = spaces.Dict(
            {"rgb": spaces.Box(low=0.0, high=255.0, shape=self._OBS_SHAPE, dtype=np.float32)}
        )
        # Last batch of actions received through `step_async`; defined here so
        # the attribute exists even before the first step.
        self.actions = None

    def _zero_obs(self):
        """Return a batched all-zero observation dict.

        The array is created as float32 so that it matches the dtype declared
        in ``observation_space`` (``np.zeros`` would default to float64).
        """
        return {"rgb": np.zeros((self.num_envs, *self._OBS_SHAPE), dtype=np.float32)}

    def step_async(self, actions):
        """Store ``actions``; they are not used to compute the next step."""
        self.actions = actions

    def step_wait(self):
        """Return ``(obs, rewards, dones, infos)`` for one step of all envs.

        :return: zero observations, zero rewards, all-``False`` dones and
            one empty info dict per environment.
        """
        return (
            self._zero_obs(),
            np.zeros((self.num_envs,)),
            np.zeros((self.num_envs,), dtype=bool),
            [{} for _ in range(self.num_envs)],
        )

    def reset(self):
        """Return the initial batched observation dict (all zeros)."""
        return self._zero_obs()

    def render(self, mode="human", close=False):
        """No-op: this dummy environment has nothing to render."""
        pass
|
|
|
|
|
|
def test_extract_dict_obs():
    """Wrap the dict-observation env with VecExtractDictObs and check the reset shape."""
    # Select the "rgb" entry of the dict observation.
    wrapped_env = VecExtractDictObs(DictObsVecEnv(), "rgb")
    initial_obs = wrapped_env.reset()
    # 4 environments, each producing an 86x86 observation.
    assert initial_obs.shape == (4, 86, 86)
|
|
|
|
|
|
def test_vec_with_ppo():
    """
    Run a short PPO training on the dict env wrapped by
    `VecExtractDictObs` and `VecMonitor`.
    """
    rgb_env = VecExtractDictObs(DictObsVecEnv(), "rgb")
    agent = PPO("MlpPolicy", VecMonitor(rgb_env), verbose=1, n_steps=64, device="cpu")
    agent.learn(total_timesteps=250)
|