mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-25 22:35:14 +00:00
Doc update: custom envs, IsaacLab, Brax and dm_control (#2072)
* Add note about start!=0 for Discrete spaces
* Update doc for IsaacLab and dm_control
* Fix test due to rounding error
parent d055a2e2af
commit f8ea2995cb
6 changed files with 51 additions and 29 deletions
@@ -24,6 +24,24 @@ That is to say, your environment must implement the following methods (and inher
 Under the hood, when a channel-last image is passed, SB3 uses a ``VecTransposeImage`` wrapper to re-order the channels.
 
 
+.. note::
+
+  SB3 doesn't support ``Discrete`` and ``MultiDiscrete`` spaces with ``start!=0``. However, you can update your environment or use a wrapper to make your env compatible with SB3:
+
+  .. code-block:: python
+
+    import gymnasium as gym
+
+    class ShiftWrapper(gym.Wrapper):
+        """Allow to use Discrete() action spaces with start!=0"""
+        def __init__(self, env: gym.Env) -> None:
+            super().__init__(env)
+            assert isinstance(env.action_space, gym.spaces.Discrete)
+            self.action_space = gym.spaces.Discrete(env.action_space.n, start=0)
+
+        def step(self, action: int):
+            return self.env.step(action + self.env.action_space.start)
+
 
 .. code-block:: python
 
@@ -735,41 +735,40 @@ A2C policy gradient updates on the model.
   print(f"Best fitness: {top_candidates[0][1]:.2f}")
 
 
-SB3 and ProcgenEnv
-------------------
+SB3 with Isaac Lab, Brax, Procgen, EnvPool
+------------------------------------------
 
-Some environments like `Procgen <https://github.com/openai/procgen>`_ already produce a vectorized
-environment (see discussion in `issue #314 <https://github.com/DLR-RM/stable-baselines3/issues/314>`_). In order to use it with SB3, you must wrap it in a ``VecMonitor`` wrapper which will also allow
-to keep track of the agent progress.
+Some massively parallel simulations such as `EnvPool <https://github.com/sail-sg/envpool>`_, `Isaac Lab <https://github.com/isaac-sim/IsaacLab>`_, `Brax <https://github.com/google/brax>`_ or `ProcGen <https://github.com/Farama-Foundation/Procgen2>`_ already produce a vectorized environment to speed up data collection (see discussion in `issue #314 <https://github.com/DLR-RM/stable-baselines3/issues/314>`_).
+
+To use SB3 with these tools, you need to wrap the env with tool-specific ``VecEnvWrapper`` that pre-processes the data for SB3,
+you can find links to some of these wrappers in `issue #772 <https://github.com/DLR-RM/stable-baselines3/issues/772#issuecomment-1048657002>`_.
+
+- Isaac Lab wrapper: `link <https://github.com/isaac-sim/IsaacLab/blob/main/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/utils/wrappers/sb3.py>`__
+- Brax: `link <https://gist.github.com/araffin/a7a576ec1453e74d9bb93120918ef7e7>`__
+- EnvPool: `link <https://github.com/sail-sg/envpool/blob/main/examples/sb3_examples/ppo.py>`__
+
+
+SB3 with DeepMind Control (dm_control)
+--------------------------------------
+
+If you want to use SB3 with `dm_control <https://github.com/google-deepmind/dm_control>`_, you need to use two wrappers (one from `shimmy <https://github.com/Farama-Foundation/Shimmy>`_, one pre-built one) to convert it to a Gymnasium compatible environment:
 
 .. code-block:: python
 
-  from procgen import ProcgenEnv
+  import shimmy
+  import stable_baselines3 as sb3
+  from dm_control import suite
+  from gymnasium.wrappers import FlattenObservation
 
-  from stable_baselines3 import PPO
-  from stable_baselines3.common.vec_env import VecExtractDictObs, VecMonitor
+  # Available envs:
+  # suite._DOMAINS and suite.dog.SUITE
 
-  # ProcgenEnv is already vectorized
-  venv = ProcgenEnv(num_envs=2, env_name="starpilot")
+  env = suite.load(domain_name="dog", task_name="run")
+  gym_env = FlattenObservation(shimmy.DmControlCompatibilityV0(env))
 
-  # To use only part of the observation:
-  # venv = VecExtractDictObs(venv, "rgb")
+  model = sb3.PPO("MlpPolicy", gym_env, verbose=1)
+  model.learn(10_000, progress_bar=True)
 
-  # Wrap with a VecMonitor to collect stats and avoid errors
-  venv = VecMonitor(venv=venv)
-
-  model = PPO("MultiInputPolicy", venv, verbose=1)
-  model.learn(10_000)
-
-
-SB3 with EnvPool or Isaac Gym
------------------------------
-
-Just like Procgen (see above), `EnvPool <https://github.com/sail-sg/envpool>`_ and `Isaac Gym <https://github.com/NVIDIA-Omniverse/IsaacGymEnvs>`_ accelerate the environment by
-already providing a vectorized implementation.
-
-To use SB3 with those tools, you must wrap the env with tool's specific ``VecEnvWrapper`` that will pre-process the data for SB3,
-you can find links to those wrappers in `issue #772 <https://github.com/DLR-RM/stable-baselines3/issues/772#issuecomment-1048657002>`_.
 
 
 Record a Video
@@ -18,6 +18,7 @@ Implemented algorithms:
 - Twin Delayed DDPG (TD3)
 - Deep Deterministic Policy Gradient (DDPG)
 - Batch Normalization in Deep Reinforcement Learning (CrossQ)
+- Simplicity Bias for Scaling Up Parameters in Deep Reinforcement Learning (SimBa)
 
 
 As SBX follows SB3 API, it is also compatible with the `RL Zoo <https://github.com/DLR-RM/rl-baselines3-zoo>`_.
@@ -42,6 +42,9 @@ Documentation:
 - Add FootstepNet Envs to the project page (@cgaspard3333)
 - Added FRASA to the project page (@MarcDcls)
 - Fixed atari example (@chrisgao99)
+- Add a note about ``Discrete`` action spaces with ``start!=0``
+- Update doc for massively parallel simulators (Isaac Lab, Brax, ...)
+- Add dm_control example
 
 Release 2.4.1 (2024-12-20)
 --------------------------
@@ -37,7 +37,8 @@ def _check_non_zero_start(space: spaces.Space, space_type: str = "observation",
         warnings.warn(
             f"{type(space).__name__} {space_type} space {maybe_key} with a non-zero start (start={space.start}) "
             "is not supported by Stable-Baselines3. "
-            f"You can use a wrapper or update your {space_type} space."
+            "You can use a wrapper (see https://stable-baselines3.readthedocs.io/en/master/guide/custom_env.html) "
+            f"or update your {space_type} space."
         )
 
 
@@ -315,7 +315,7 @@ def test_get_original():
     assert not np.array_equal(orig_obs, obs)
     assert not np.array_equal(orig_rewards, rewards)
     np.testing.assert_allclose(venv.normalize_obs(orig_obs), obs)
-    np.testing.assert_allclose(venv.normalize_reward(orig_rewards), rewards)
+    np.testing.assert_allclose(venv.normalize_reward(orig_rewards), rewards, atol=1e-6)
 
 
 def test_get_original_dict():
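
The ``atol=1e-6`` added in the test hunk above matters because ``np.testing.assert_allclose`` defaults to ``rtol=1e-7`` and ``atol=0``, so any nonzero difference from an expected value of exactly 0 fails. The sketch below uses made-up numbers (the values that actually failed in the SB3 test are not shown in this commit) to illustrate the effect.

```python
import numpy as np

# Made-up values: an expected reward of exactly 0 and a small float32 rounding error.
desired = np.array([0.0, 0.5], dtype=np.float32)
actual = desired + np.float32(1e-7)

# Default tolerances: |actual - desired| <= 0 + 1e-7 * |desired|.
# For desired == 0 the right-hand side is 0, so any rounding error fails.
try:
    np.testing.assert_allclose(actual, desired)
    strict_ok = True
except AssertionError:
    strict_ok = False

# With an absolute tolerance the same arrays compare as close.
np.testing.assert_allclose(actual, desired, atol=1e-6)
tolerant_ok = True
```

An absolute tolerance is the usual remedy when compared values can be zero or very close to it, since a purely relative tolerance shrinks to nothing there.
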