From f8ea2995cb21fca196424849315216be31b9cb2b Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Sun, 26 Jan 2025 11:42:57 +0100 Subject: [PATCH] Doc update: custom envs, IsaacLab, Brax and dm_control (#2072) * Add note about start!=0 for Discrete spaces * Update doc for IsaacLab and dm_control * Fix test due to rounding error --- docs/guide/custom_env.rst | 18 +++++++++ docs/guide/examples.rst | 53 ++++++++++++------------- docs/guide/sbx.rst | 1 + docs/misc/changelog.rst | 3 ++ stable_baselines3/common/env_checker.py | 3 +- tests/test_vec_normalize.py | 2 +- 6 files changed, 51 insertions(+), 29 deletions(-) diff --git a/docs/guide/custom_env.rst b/docs/guide/custom_env.rst index e075627..6c25ee8 100644 --- a/docs/guide/custom_env.rst +++ b/docs/guide/custom_env.rst @@ -24,6 +24,24 @@ That is to say, your environment must implement the following methods (and inher Under the hood, when a channel-last image is passed, SB3 uses a ``VecTransposeImage`` wrapper to re-order the channels. +.. note:: + + SB3 doesn't support ``Discrete`` and ``MultiDiscrete`` spaces with ``start!=0``. However, you can update your environment or use a wrapper to make your env compatible with SB3: + + .. code-block:: python + + import gymnasium as gym + + class ShiftWrapper(gym.Wrapper): + """Allow to use Discrete() action spaces with start!=0""" + def __init__(self, env: gym.Env) -> None: + super().__init__(env) + assert isinstance(env.action_space, gym.spaces.Discrete) + self.action_space = gym.spaces.Discrete(env.action_space.n, start=0) + + def step(self, action: int): + return self.env.step(action + self.env.action_space.start) + .. code-block:: python diff --git a/docs/guide/examples.rst b/docs/guide/examples.rst index b60ff69..9b7391b 100644 --- a/docs/guide/examples.rst +++ b/docs/guide/examples.rst @@ -735,41 +735,40 @@ A2C policy gradient updates on the model. 
print(f"Best fitness: {top_candidates[0][1]:.2f}") -SB3 and ProcgenEnv ------------------- +SB3 with Isaac Lab, Brax, Procgen, EnvPool +------------------------------------------ -Some environments like `Procgen `_ already produce a vectorized -environment (see discussion in `issue #314 `_). In order to use it with SB3, you must wrap it in a ``VecMonitor`` wrapper which will also allow -to keep track of the agent progress. +Some massively parallel simulations such as `EnvPool `_, `Isaac Lab `_, `Brax `_ or `Procgen `_ already produce a vectorized environment to speed up data collection (see discussion in `issue #314 `_). + +To use SB3 with these tools, you need to wrap the env with a tool-specific ``VecEnvWrapper`` that pre-processes the data for SB3; +you can find links to some of these wrappers in `issue #772 `_. + +- Isaac Lab wrapper: `link `__ +- Brax: `link `__ +- EnvPool: `link `__ + + +SB3 with DeepMind Control (dm_control) +-------------------------------------- + +If you want to use SB3 with `dm_control `_, you need to use two wrappers (one from `shimmy `_, one built into Gymnasium) to convert it to a Gymnasium-compatible environment: .. 
code-block:: python - from procgen import ProcgenEnv + import shimmy + import stable_baselines3 as sb3 + from dm_control import suite + from gymnasium.wrappers import FlattenObservation - from stable_baselines3 import PPO - from stable_baselines3.common.vec_env import VecExtractDictObs, VecMonitor + # Available envs: + # suite._DOMAINS and suite.dog.SUITE - # ProcgenEnv is already vectorized - venv = ProcgenEnv(num_envs=2, env_name="starpilot") + env = suite.load(domain_name="dog", task_name="run") + gym_env = FlattenObservation(shimmy.DmControlCompatibilityV0(env)) - # To use only part of the observation: - # venv = VecExtractDictObs(venv, "rgb") + model = sb3.PPO("MlpPolicy", gym_env, verbose=1) + model.learn(10_000, progress_bar=True) - # Wrap with a VecMonitor to collect stats and avoid errors - venv = VecMonitor(venv=venv) - - model = PPO("MultiInputPolicy", venv, verbose=1) - model.learn(10_000) - - -SB3 with EnvPool or Isaac Gym ------------------------------ - -Just like Procgen (see above), `EnvPool `_ and `Isaac Gym `_ accelerate the environment by -already providing a vectorized implementation. - -To use SB3 with those tools, you must wrap the env with tool's specific ``VecEnvWrapper`` that will pre-process the data for SB3, -you can find links to those wrappers in `issue #772 `_. Record a Video diff --git a/docs/guide/sbx.rst b/docs/guide/sbx.rst index ed5369e..7d69fe2 100644 --- a/docs/guide/sbx.rst +++ b/docs/guide/sbx.rst @@ -18,6 +18,7 @@ Implemented algorithms: - Twin Delayed DDPG (TD3) - Deep Deterministic Policy Gradient (DDPG) - Batch Normalization in Deep Reinforcement Learning (CrossQ) +- Simplicity Bias for Scaling Up Parameters in Deep Reinforcement Learning (SimBa) As SBX follows SB3 API, it is also compatible with the `RL Zoo `_. 
diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 04cc296..a9eef94 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -42,6 +42,9 @@ Documentation: - Add FootstepNet Envs to the project page (@cgaspard3333) - Added FRASA to the project page (@MarcDcls) - Fixed atari example (@chrisgao99) +- Add a note about ``Discrete`` action spaces with ``start!=0`` +- Update doc for massively parallel simulators (Isaac Lab, Brax, ...) +- Add dm_control example Release 2.4.1 (2024-12-20) -------------------------- diff --git a/stable_baselines3/common/env_checker.py b/stable_baselines3/common/env_checker.py index 0310bcf..05dce2e 100644 --- a/stable_baselines3/common/env_checker.py +++ b/stable_baselines3/common/env_checker.py @@ -37,7 +37,8 @@ def _check_non_zero_start(space: spaces.Space, space_type: str = "observation", warnings.warn( f"{type(space).__name__} {space_type} space {maybe_key} with a non-zero start (start={space.start}) " "is not supported by Stable-Baselines3. " - f"You can use a wrapper or update your {space_type} space." + "You can use a wrapper (see https://stable-baselines3.readthedocs.io/en/master/guide/custom_env.html) " + f"or update your {space_type} space." ) diff --git a/tests/test_vec_normalize.py b/tests/test_vec_normalize.py index 3db5fcd..57b9816 100644 --- a/tests/test_vec_normalize.py +++ b/tests/test_vec_normalize.py @@ -315,7 +315,7 @@ def test_get_original(): assert not np.array_equal(orig_obs, obs) assert not np.array_equal(orig_rewards, rewards) np.testing.assert_allclose(venv.normalize_obs(orig_obs), obs) - np.testing.assert_allclose(venv.normalize_reward(orig_rewards), rewards) + np.testing.assert_allclose(venv.normalize_reward(orig_rewards), rewards, atol=1e-6) def test_get_original_dict():