mirror of https://github.com/saymrwulf/stable-baselines3.git, synced 2026-05-14 20:58:03 +00:00
Update outdated custom env doc (#1490)
* Update outdated custom env doc
* fix render_mode and term/trunc/reset_info
* gym -> gymnasium

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
This commit is contained in:
parent 9cebedc89f · commit fd0cd82339
14 changed files with 62 additions and 68 deletions
@@ -139,7 +139,7 @@ for i in range(1000):
 env.close()
 ```
 
-Or just train a model with a one liner if [the environment is registered in Gym](https://github.com/openai/gym/wiki/Environments) and if [the policy is registered](https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html):
+Or just train a model with a one liner if [the environment is registered in Gymnasium](https://gymnasium.farama.org/tutorials/gymnasium_basics/environment_creation/#registering-envs) and if [the policy is registered](https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html):
 
 ```python
 from stable_baselines3 import PPO
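The "one liner" mentioned in this hunk works by passing the registered environment id straight to the algorithm constructor. A minimal sketch of that pattern (`CartPole-v1` is a stand-in for any registered Gymnasium env, not taken from the diff):

```python
from stable_baselines3 import PPO

# SB3 calls gymnasium.make("CartPole-v1") internally because the id is registered
model = PPO("MlpPolicy", "CartPole-v1").learn(total_timesteps=10_000)
```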
@@ -3,18 +3,19 @@
 Using Custom Environments
 ==========================
 
-To use the RL baselines with custom environments, they just need to follow the *gym* interface.
-That is to say, your environment must implement the following methods (and inherits from OpenAI Gym Class):
+To use the RL baselines with custom environments, they just need to follow the *gymnasium* `interface <https://gymnasium.farama.org/tutorials/gymnasium_basics/environment_creation/#sphx-glr-tutorials-gymnasium-basics-environment-creation-py>`_.
+That is to say, your environment must implement the following methods (and inherits from Gym Class):
 
 
 .. note::
-    If you are using images as input, the observation must be of type ``np.uint8`` and be contained in [0, 255].
-    By default, the observation is normalized by SB3 pre-processing (dividing by 255 to have values in [0, 1]) when using CNN policies.
-    Images can be either channel-first or channel-last.
+    If you are using images as input, the observation must be of type ``np.uint8`` and be contained in [0, 255].
+    By default, the observation is normalized by SB3 pre-processing (dividing by 255 to have values in [0, 1]) when using CNN policies.
+    Images can be either channel-first or channel-last.
 
     If you want to use ``CnnPolicy`` or ``MultiInputPolicy`` with image-like observation (3D tensor) that are already normalized, you must pass ``normalize_images=False``
-    to the policy (using ``policy_kwargs`` parameter, ``policy_kwargs=dict(normalize_images=False)``)
-    and make sure your image is in the **channel-first** format.
+    to the policy (using ``policy_kwargs`` parameter, ``policy_kwargs=dict(normalize_images=False)``)
+    and make sure your image is in the **channel-first** format.
 
 
 .. note::
@@ -34,7 +35,7 @@ That is to say, your environment must implement the following methods (and inher
     class CustomEnv(gym.Env):
         """Custom Environment that follows gym interface."""
 
-        metadata = {"render.modes": ["human"]}
+        metadata = {"render_modes": ["human"], "render_fps": 30}
 
         def __init__(self, arg1, arg2, ...):
             super().__init__()
@@ -48,11 +49,11 @@ That is to say, your environment must implement the following methods (and inher
 
         def step(self, action):
             ...
-            return observation, reward, done, info
+            return observation, reward, terminated, truncated, info
 
-        def reset(self):
+        def reset(self, seed=None, options=None):
             ...
-            return observation  # reward, done, info can't be included
+            return observation, info
 
         def render(self):
             ...
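Assembled from the three hunks above, the updated template corresponds to a runnable Gymnasium env roughly like this sketch (the `Discrete`/`Box` spaces and dummy observation logic are illustrative assumptions, not part of the commit):

```python
import gymnasium as gym
import numpy as np
from gymnasium import spaces


class CustomEnv(gym.Env):
    """Custom Environment that follows the gymnasium interface."""

    metadata = {"render_modes": ["human"], "render_fps": 30}

    def __init__(self, render_mode=None):
        super().__init__()
        # Placeholder spaces, purely illustrative
        self.action_space = spaces.Discrete(2)
        self.observation_space = spaces.Box(low=-1.0, high=1.0, shape=(3,), dtype=np.float32)
        self.render_mode = render_mode

    def step(self, action):
        observation = self.observation_space.sample()
        reward = 0.0
        terminated = False  # natural episode end, e.g. goal reached
        truncated = False   # artificial cutoff, e.g. time limit
        info = {}
        return observation, reward, terminated, truncated, info

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)  # seeds self.np_random
        observation = self.observation_space.sample()
        info = {}
        return observation, info

    def render(self):
        pass
```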
@@ -81,11 +82,11 @@ To check that your environment follows the Gym interface that SB3 supports, plea
 # It will check your custom environment and output additional warnings if needed
 check_env(env)
 
-Gym also have its own `env checker <https://www.gymlibrary.ml/content/api/#checking-api-conformity>`_ but it checks a superset of what SB3 supports (SB3 does not support all Gym features).
+Gymnasium also have its own `env checker <https://gymnasium.farama.org/api/utils/#gymnasium.utils.env_checker.check_env>`_ but it checks a superset of what SB3 supports (SB3 does not support all Gym features).
 
-We have created a `colab notebook <https://colab.research.google.com/github/araffin/rl-tutorial-jnrr19/blob/master/5_custom_gym_env.ipynb>`_ for a concrete example on creating a custom environment along with an example of using it with Stable-Baselines3 interface.
+We have created a `colab notebook <https://colab.research.google.com/github/araffin/rl-tutorial-jnrr19/blob/sb3/5_custom_gym_env.ipynb>`_ for a concrete example on creating a custom environment along with an example of using it with Stable-Baselines3 interface.
 
-Alternatively, you may look at OpenAI Gym `built-in environments <https://www.gymlibrary.ml/>`_. However, the readers are cautioned as per OpenAI Gym `official wiki <https://github.com/openai/gym/wiki/FAQ>`_, its advised not to customize their built-in environments. It is better to copy and create new ones if you need to modify them.
+Alternatively, you may look at Gymnasium `built-in environments <https://gymnasium.farama.org>`_.
 
 Optionally, you can also register the environment with gym, that will allow you to create the RL agent in one line (and use ``gym.make()`` to instantiate the env):
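The registration mentioned in the final context line looks roughly like this under Gymnasium (the id and entry point below are hypothetical placeholders):

```python
import gymnasium as gym
from gymnasium.envs.registration import register

# Hypothetical id and module path, for illustration only
register(id="CustomEnv-v0", entry_point="my_package.envs:CustomEnv")

env = gym.make("CustomEnv-v0")
```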
@@ -71,7 +71,7 @@ In the following example, we will train, save and load a DQN model on the Lunar
 
     # Create environment
-    env = gym.make("LunarLander-v2")
+    env = gym.make("LunarLander-v2", render_mode="rgb_array")
 
     # Instantiate the agent
     model = DQN("MlpPolicy", env, verbose=1)
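The save/load cycle this example builds toward (outside the hunk) follows the standard SB3 pattern; a compact sketch, with CartPole standing in for LunarLander:

```python
from stable_baselines3 import DQN

model = DQN("MlpPolicy", "CartPole-v1", verbose=1)
model.learn(total_timesteps=1_000)
model.save("dqn_cartpole")  # writes dqn_cartpole.zip

del model  # demonstrate loading from disk
model = DQN.load("dqn_cartpole")  # an env can be attached via DQN.load(path, env=env)
```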
@@ -99,7 +99,7 @@ In the following example, we will train, save and load a DQN model on the Lunar
     for i in range(1000):
         action, _states = model.predict(obs, deterministic=True)
         obs, rewards, dones, info = vec_env.step(action)
-        vec_env.render()
+        vec_env.render("human")
 
 
 Multiprocessing: Unleashing the Power of Vectorized Environments
@@ -116,7 +116,6 @@ Multiprocessing: Unleashing the Power of Vectorized Environments
 .. code-block:: python
 
     import gymnasium as gym
-    import numpy as np
 
     from stable_baselines3 import PPO
     from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
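For context, the vectorized setup these imports feed into typically looks something like the following sketch (the env id and worker count are assumptions, not part of the diff):

```python
import gymnasium as gym

from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import SubprocVecEnv


def make_env():
    # Factory: each worker process builds its own env instance
    return gym.make("CartPole-v1")


if __name__ == "__main__":
    # 4 envs stepped in parallel, one per subprocess
    vec_env = SubprocVecEnv([make_env for _ in range(4)])
    model = PPO("MlpPolicy", vec_env, verbose=1)
    model.learn(total_timesteps=10_000)
```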
@@ -512,6 +511,7 @@ The parking env is a goal-conditioned continuous control task, in which the vehi
     # Load saved model
     # Because it needs access to `env.compute_reward()`
     # HER must be loaded with the env
+    env = gym.make("parking-v0", render_mode="human")  # Change the render mode
     model = SAC.load("her_sac_highway", env=env)
 
     obs, info = env.reset()
@@ -521,7 +521,6 @@ The parking env is a goal-conditioned continuous control task, in which the vehi
     for _ in range(100):
         action, _ = model.predict(obs, deterministic=True)
         obs, reward, terminated, truncated, info = env.step(action)
-        env.render()
         episode_reward += reward
         if terminated or truncated or info.get("is_success", False):
             print("Reward:", episode_reward, "Success?", info.get("is_success", False))
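The `env.render()` removal above follows from Gymnasium render semantics: with `render_mode="human"`, rendering happens automatically inside `env.step()` and `env.reset()`, so the explicit call is redundant. A small sketch of that behavior, using CartPole as a stand-in:

```python
import gymnasium as gym

env = gym.make("CartPole-v1", render_mode="human")  # window updates on reset/step
obs, info = env.reset()
for _ in range(50):
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
    if terminated or truncated:
        obs, info = env.reset()
env.close()
```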
@@ -20,7 +20,7 @@ Here is a quick example of how to train and run A2C on a CartPole environment:
     from stable_baselines3 import A2C
 
-    env = gym.make("CartPole-v1")
+    env = gym.make("CartPole-v1", render_mode="rgb_array")
 
     model = A2C("MlpPolicy", env, verbose=1)
    model.learn(total_timesteps=10_000)
@@ -30,7 +30,7 @@ Here is a quick example of how to train and run A2C on a CartPole environment:
     for i in range(1000):
         action, _state = model.predict(obs, deterministic=True)
         obs, reward, done, info = vec_env.step(action)
-        vec_env.render()
+        vec_env.render("human")
         # VecEnv resets automatically
         # if done:
         #    obs = vec_env.reset()
@@ -40,8 +40,8 @@ Here is a quick example of how to train and run A2C on a CartPole environment:
 You can find explanations about the logger output and names in the :ref:`Logger <logger>` section.
 
-Or just train a model with a one liner if
-`the environment is registered in Gym <https://github.com/openai/gym/wiki/Environments>`_ and if
+Or just train a model with a one line if
+`the environment is registered in Gymnasium <https://gymnasium.farama.org/tutorials/gymnasium_basics/environment_creation/#registering-envs>`_ and if
 the policy is registered:
 
 .. code-block:: python
@@ -210,14 +210,14 @@ If you want to quickly try a random agent on your environment, you can also do:
 .. code-block:: python
 
     env = YourEnv()
-    obs = env.reset()
+    obs, info = env.reset()
     n_steps = 10
     for _ in range(n_steps):
         # Random action
         action = env.action_space.sample()
-        obs, reward, done, info = env.step(action)
+        obs, reward, terminated, truncated, info = env.step(action)
         if done:
-            obs = env.reset()
+            obs, info = env.reset()
 
 
 **Why should I normalize the action space?**
@@ -70,6 +70,7 @@ Documentation:
 - Make it more explicit when using ``VecEnv`` vs Gym env
 - Added UAV_Navigation_DRL_AirSim to the project page (@heleidsn)
 - Added ``EvalCallback`` example (@sidney-tio)
+- Update custom env documentation
 
 
 Release 1.8.0 (2023-04-07)
@@ -53,15 +53,13 @@ Train a A2C agent on ``CartPole-v1`` using 4 environments.
 .. code-block:: python
 
-    import gymnasium as gym
-
     from stable_baselines3 import A2C
     from stable_baselines3.common.env_util import make_vec_env
 
     # Parallel environments
-    env = make_vec_env("CartPole-v1", n_envs=4)
+    vec_env = make_vec_env("CartPole-v1", n_envs=4)
 
-    model = A2C("MlpPolicy", env, verbose=1)
+    model = A2C("MlpPolicy", vec_env, verbose=1)
     model.learn(total_timesteps=25000)
     model.save("a2c_cartpole")
@@ -69,11 +67,11 @@ Train a A2C agent on ``CartPole-v1`` using 4 environments.
 
     model = A2C.load("a2c_cartpole")
 
-    obs = env.reset()
+    obs = vec_env.reset()
     while True:
         action, _states = model.predict(obs)
-        obs, rewards, dones, info = env.step(action)
-        env.render()
+        obs, rewards, dones, info = vec_env.step(action)
+        vec_env.render("human")
 
 
 .. note::
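Note that only raw Gymnasium envs moved to the 5-tuple step return; the SB3 ``VecEnv`` API shown in this hunk deliberately keeps the batched 4-tuple (``obs, rewards, dones, infos``) and resets sub-envs automatically. A sketch of the two conventions side by side (CartPole is an assumed stand-in):

```python
import gymnasium as gym
import numpy as np

from stable_baselines3.common.env_util import make_vec_env

# Raw Gymnasium env: 5-tuple step return
env = gym.make("CartPole-v1")
obs, info = env.reset()
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())

# SB3 VecEnv: batched 4-tuple step return, auto-reset on episode end
vec_env = make_vec_env("CartPole-v1", n_envs=4)
obs = vec_env.reset()
actions = np.array([vec_env.action_space.sample() for _ in range(4)])
obs, rewards, dones, infos = vec_env.step(actions)
```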
@@ -67,7 +67,7 @@ This example is only to demonstrate the use of the library and its functions, an
     from stable_baselines3 import DDPG
     from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise
 
-    env = gym.make("Pendulum-v1")
+    env = gym.make("Pendulum-v1", render_mode="rgb_array")
 
     # The noise objects for DDPG
     n_actions = env.action_space.shape[-1]
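The construction these noise imports lead to, just past the hunk's cut-off, usually follows this pattern (a sketch; the 0.1 sigma is an assumed value, not from the diff):

```python
import numpy as np

from stable_baselines3.common.noise import NormalActionNoise

n_actions = 1  # Pendulum-v1 has a single action dimension
# Zero-mean Gaussian noise added to the deterministic actor's actions for exploration
action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))
```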
@@ -76,17 +76,17 @@ This example is only to demonstrate the use of the library and its functions, an
     model = DDPG("MlpPolicy", env, action_noise=action_noise, verbose=1)
     model.learn(total_timesteps=10000, log_interval=10)
     model.save("ddpg_pendulum")
-    env = model.get_env()
+    vec_env = model.get_env()
 
     del model # remove to demonstrate saving and loading
 
     model = DDPG.load("ddpg_pendulum")
 
-    obs = env.reset()
+    obs = vec_env.reset()
     while True:
         action, _states = model.predict(obs)
-        obs, rewards, dones, info = env.step(action)
-        env.render()
+        obs, rewards, dones, info = vec_env.step(action)
+        env.render("human")
 
 Results
 -------
@@ -60,7 +60,7 @@ This example is only to demonstrate the use of the library and its functions, an
 
     from stable_baselines3 import DQN
 
-    env = gym.make("CartPole-v1")
+    env = gym.make("CartPole-v1", render_mode="human")
 
     model = DQN("MlpPolicy", env, verbose=1)
     model.learn(total_timesteps=10000, log_interval=4)
@@ -70,13 +70,12 @@ This example is only to demonstrate the use of the library and its functions, an
 
     model = DQN.load("dqn_cartpole")
 
-    obs = env.reset()
+    obs, info = env.reset()
     while True:
         action, _states = model.predict(obs, deterministic=True)
-        obs, reward, done, info = env.step(action)
-        env.render()
-        if done:
-            obs = env.reset()
+        obs, reward, terminated, truncated, info = env.step(action)
+        if terminated or truncated:
+            obs, info = env.reset()
 
 
 Results
@@ -65,7 +65,6 @@ This example is only to demonstrate the use of the library and its functions, an
     from stable_baselines3 import HerReplayBuffer, DDPG, DQN, SAC, TD3
     from stable_baselines3.her.goal_selection_strategy import GoalSelectionStrategy
     from stable_baselines3.common.envs import BitFlippingEnv
-    from stable_baselines3.common.vec_env import DummyVecEnv
 
     model_class = DQN # works also with SAC, DDPG and TD3
     N_BITS = 15
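To place these imports in context, wiring them into a HER model typically looks like this sketch (the replay-buffer kwargs are assumed example values, not from the diff):

```python
from stable_baselines3 import DQN, HerReplayBuffer
from stable_baselines3.common.envs import BitFlippingEnv

N_BITS = 15
env = BitFlippingEnv(n_bits=N_BITS, continuous=False, max_steps=N_BITS)

# HER relabels stored transitions with goals achieved later in the episode
model = DQN(
    "MultiInputPolicy",
    env,
    replay_buffer_class=HerReplayBuffer,
    replay_buffer_kwargs=dict(n_sampled_goal=4, goal_selection_strategy="future"),
    verbose=1,
)
model.learn(total_timesteps=1_000)
```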
@@ -96,13 +95,12 @@ This example is only to demonstrate the use of the library and its functions, an
     # HER must be loaded with the env
     model = model_class.load("./her_bit_env", env=env)
 
-    obs = env.reset()
+    obs, info = env.reset()
     for _ in range(100):
         action, _ = model.predict(obs, deterministic=True)
-        obs, reward, done, _ = env.step(action)
-
-        if done:
-            obs = env.reset()
+        obs, reward, terminated, truncated, _ = env.step(action)
+        if terminated or truncated:
+            obs, info = env.reset()
 
 
 Results
@@ -71,9 +71,9 @@ Train a PPO agent on ``CartPole-v1`` using 4 environments.
     from stable_baselines3.common.env_util import make_vec_env
 
     # Parallel environments
-    env = make_vec_env("CartPole-v1", n_envs=4)
+    vec_env = make_vec_env("CartPole-v1", n_envs=4)
 
-    model = PPO("MlpPolicy", env, verbose=1)
+    model = PPO("MlpPolicy", vec_env, verbose=1)
     model.learn(total_timesteps=25000)
     model.save("ppo_cartpole")
@@ -81,11 +81,11 @@ Train a PPO agent on ``CartPole-v1`` using 4 environments.
 
     model = PPO.load("ppo_cartpole")
 
-    obs = env.reset()
+    obs = vec_env.reset()
     while True:
         action, _states = model.predict(obs)
-        obs, rewards, dones, info = env.step(action)
-        env.render()
+        obs, rewards, dones, info = vec_env.step(action)
+        vec_env.render("human")
 
 
 Results
@@ -69,11 +69,10 @@ This example is only to demonstrate the use of the library and its functions, an
 .. code-block:: python
 
     import gymnasium as gym
-    import numpy as np
 
     from stable_baselines3 import SAC
 
-    env = gym.make("Pendulum-v1")
+    env = gym.make("Pendulum-v1", render_mode="human")
 
     model = SAC("MlpPolicy", env, verbose=1)
     model.learn(total_timesteps=10000, log_interval=4)
@@ -83,13 +82,12 @@ This example is only to demonstrate the use of the library and its functions, an
 
     model = SAC.load("sac_pendulum")
 
-    obs = env.reset()
+    obs, info = env.reset()
     while True:
         action, _states = model.predict(obs, deterministic=True)
-        obs, reward, done, info = env.step(action)
-        env.render()
-        if done:
-            obs = env.reset()
+        obs, reward, terminated, truncated, info = env.step(action)
+        if terminated or truncated:
+            obs, info = env.reset()
 
 
 Results
@@ -67,7 +67,7 @@ This example is only to demonstrate the use of the library and its functions, an
     from stable_baselines3 import TD3
     from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise
 
-    env = gym.make("Pendulum-v1")
+    env = gym.make("Pendulum-v1", render_mode="rgb_array")
 
     # The noise objects for TD3
     n_actions = env.action_space.shape[-1]
@@ -76,17 +76,17 @@ This example is only to demonstrate the use of the library and its functions, an
     model = TD3("MlpPolicy", env, action_noise=action_noise, verbose=1)
     model.learn(total_timesteps=10000, log_interval=10)
     model.save("td3_pendulum")
-    env = model.get_env()
+    vec_env = model.get_env()
 
     del model # remove to demonstrate saving and loading
 
     model = TD3.load("td3_pendulum")
 
-    obs = env.reset()
+    obs = vec_env.reset()
     while True:
         action, _states = model.predict(obs)
-        obs, rewards, dones, info = env.step(action)
-        env.render()
+        obs, rewards, dones, info = vec_env.step(action)
+        vec_env.render("human")
 
 Results
 -------
setup.py (2 changed lines)
@@ -149,7 +149,7 @@ setup(
     url="https://github.com/DLR-RM/stable-baselines3",
     author_email="antonin.raffin@dlr.de",
     keywords="reinforcement-learning-algorithms reinforcement-learning machine-learning "
-    "gym openai stable baselines toolbox python data-science",
+    "gymnasium gym openai stable baselines toolbox python data-science",
     license="MIT",
     long_description=long_description,
     long_description_content_type="text/markdown",