Update outdated custom env doc (#1490)

* Update outdated custom env doc

* fix render_mode and term/trunc/reset_info

* gym -> gymnasium

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
This commit is contained in:
Antonin RAFFIN 2023-05-08 13:48:26 +02:00 committed by GitHub
parent 9cebedc89f
commit fd0cd82339
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 62 additions and 68 deletions

View file

@ -139,7 +139,7 @@ for i in range(1000):
env.close()
```
Or just train a model with a one liner if [the environment is registered in Gym](https://github.com/openai/gym/wiki/Environments) and if [the policy is registered](https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html):
Or just train a model with a one liner if [the environment is registered in Gymnasium](https://gymnasium.farama.org/tutorials/gymnasium_basics/environment_creation/#registering-envs) and if [the policy is registered](https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html):
```python
from stable_baselines3 import PPO

View file

@ -3,18 +3,19 @@
Using Custom Environments
==========================
To use the RL baselines with custom environments, they just need to follow the *gym* interface.
That is to say, your environment must implement the following methods (and inherits from OpenAI Gym Class):
To use the RL baselines with custom environments, they just need to follow the *gymnasium* `interface <https://gymnasium.farama.org/tutorials/gymnasium_basics/environment_creation/#sphx-glr-tutorials-gymnasium-basics-environment-creation-py>`_.
That is to say, your environment must implement the following methods (and inherits from Gym Class):
.. note::
If you are using images as input, the observation must be of type ``np.uint8`` and be contained in [0, 255].
By default, the observation is normalized by SB3 pre-processing (dividing by 255 to have values in [0, 1]) when using CNN policies.
Images can be either channel-first or channel-last.
If you are using images as input, the observation must be of type ``np.uint8`` and be contained in [0, 255].
By default, the observation is normalized by SB3 pre-processing (dividing by 255 to have values in [0, 1]) when using CNN policies.
Images can be either channel-first or channel-last.
If you want to use ``CnnPolicy`` or ``MultiInputPolicy`` with image-like observation (3D tensor) that are already normalized, you must pass ``normalize_images=False``
to the policy (using ``policy_kwargs`` parameter, ``policy_kwargs=dict(normalize_images=False)``)
and make sure your image is in the **channel-first** format.
to the policy (using ``policy_kwargs`` parameter, ``policy_kwargs=dict(normalize_images=False)``)
and make sure your image is in the **channel-first** format.
.. note::
@ -34,7 +35,7 @@ That is to say, your environment must implement the following methods (and inher
class CustomEnv(gym.Env):
"""Custom Environment that follows gym interface."""
metadata = {"render.modes": ["human"]}
metadata = {"render_modes": ["human"], "render_fps": 30}
def __init__(self, arg1, arg2, ...):
super().__init__()
@ -48,11 +49,11 @@ That is to say, your environment must implement the following methods (and inher
def step(self, action):
...
return observation, reward, done, info
return observation, reward, terminated, truncated, info
def reset(self):
def reset(self, seed=None, options=None):
...
return observation # reward, done, info can't be included
return observation, info
def render(self):
...
@ -81,11 +82,11 @@ To check that your environment follows the Gym interface that SB3 supports, plea
# It will check your custom environment and output additional warnings if needed
check_env(env)
Gym also have its own `env checker <https://www.gymlibrary.ml/content/api/#checking-api-conformity>`_ but it checks a superset of what SB3 supports (SB3 does not support all Gym features).
Gymnasium also have its own `env checker <https://gymnasium.farama.org/api/utils/#gymnasium.utils.env_checker.check_env>`_ but it checks a superset of what SB3 supports (SB3 does not support all Gym features).
We have created a `colab notebook <https://colab.research.google.com/github/araffin/rl-tutorial-jnrr19/blob/master/5_custom_gym_env.ipynb>`_ for a concrete example on creating a custom environment along with an example of using it with Stable-Baselines3 interface.
We have created a `colab notebook <https://colab.research.google.com/github/araffin/rl-tutorial-jnrr19/blob/sb3/5_custom_gym_env.ipynb>`_ for a concrete example on creating a custom environment along with an example of using it with Stable-Baselines3 interface.
Alternatively, you may look at OpenAI Gym `built-in environments <https://www.gymlibrary.ml/>`_. However, the readers are cautioned as per OpenAI Gym `official wiki <https://github.com/openai/gym/wiki/FAQ>`_, its advised not to customize their built-in environments. It is better to copy and create new ones if you need to modify them.
Alternatively, you may look at Gymnasium `built-in environments <https://gymnasium.farama.org>`_.
Optionally, you can also register the environment with gym, that will allow you to create the RL agent in one line (and use ``gym.make()`` to instantiate the env):

View file

@ -71,7 +71,7 @@ In the following example, we will train, save and load a DQN model on the Lunar
# Create environment
env = gym.make("LunarLander-v2")
env = gym.make("LunarLander-v2", render_mode="rgb_array")
# Instantiate the agent
model = DQN("MlpPolicy", env, verbose=1)
@ -99,7 +99,7 @@ In the following example, we will train, save and load a DQN model on the Lunar
for i in range(1000):
action, _states = model.predict(obs, deterministic=True)
obs, rewards, dones, info = vec_env.step(action)
vec_env.render()
vec_env.render("human")
Multiprocessing: Unleashing the Power of Vectorized Environments
@ -116,7 +116,6 @@ Multiprocessing: Unleashing the Power of Vectorized Environments
.. code-block:: python
import gymnasium as gym
import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
@ -512,6 +511,7 @@ The parking env is a goal-conditioned continuous control task, in which the vehi
# Load saved model
# Because it needs access to `env.compute_reward()`
# HER must be loaded with the env
env = gym.make("parking-v0", render_mode="human") # Change the render mode
model = SAC.load("her_sac_highway", env=env)
obs, info = env.reset()
@ -521,7 +521,6 @@ The parking env is a goal-conditioned continuous control task, in which the vehi
for _ in range(100):
action, _ = model.predict(obs, deterministic=True)
obs, reward, terminated, truncated, info = env.step(action)
env.render()
episode_reward += reward
if terminated or truncated or info.get("is_success", False):
print("Reward:", episode_reward, "Success?", info.get("is_success", False))

View file

@ -20,7 +20,7 @@ Here is a quick example of how to train and run A2C on a CartPole environment:
from stable_baselines3 import A2C
env = gym.make("CartPole-v1")
env = gym.make("CartPole-v1", render_mode="rgb_array")
model = A2C("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10_000)
@ -30,7 +30,7 @@ Here is a quick example of how to train and run A2C on a CartPole environment:
for i in range(1000):
action, _state = model.predict(obs, deterministic=True)
obs, reward, done, info = vec_env.step(action)
vec_env.render()
vec_env.render("human")
# VecEnv resets automatically
# if done:
# obs = vec_env.reset()
@ -40,8 +40,8 @@ Here is a quick example of how to train and run A2C on a CartPole environment:
You can find explanations about the logger output and names in the :ref:`Logger <logger>` section.
Or just train a model with a one liner if
`the environment is registered in Gym <https://github.com/openai/gym/wiki/Environments>`_ and if
Or just train a model with a one line if
`the environment is registered in Gymnasium <https://gymnasium.farama.org/tutorials/gymnasium_basics/environment_creation/#registering-envs>`_ and if
the policy is registered:
.. code-block:: python

View file

@ -210,14 +210,14 @@ If you want to quickly try a random agent on your environment, you can also do:
.. code-block:: python
env = YourEnv()
obs = env.reset()
obs, info = env.reset()
n_steps = 10
for _ in range(n_steps):
# Random action
action = env.action_space.sample()
obs, reward, done, info = env.step(action)
obs, reward, terminated, truncated, info = env.step(action)
if done:
obs = env.reset()
obs, info = env.reset()
**Why should I normalize the action space?**

View file

@ -70,6 +70,7 @@ Documentation:
- Make it more explicit when using ``VecEnv`` vs Gym env
- Added UAV_Navigation_DRL_AirSim to the project page (@heleidsn)
- Added ``EvalCallback`` example (@sidney-tio)
- Update custom env documentation
Release 1.8.0 (2023-04-07)

View file

@ -53,15 +53,13 @@ Train a A2C agent on ``CartPole-v1`` using 4 environments.
.. code-block:: python
import gymnasium as gym
from stable_baselines3 import A2C
from stable_baselines3.common.env_util import make_vec_env
# Parallel environments
env = make_vec_env("CartPole-v1", n_envs=4)
vec_env = make_vec_env("CartPole-v1", n_envs=4)
model = A2C("MlpPolicy", env, verbose=1)
model = A2C("MlpPolicy", vec_env, verbose=1)
model.learn(total_timesteps=25000)
model.save("a2c_cartpole")
@ -69,11 +67,11 @@ Train a A2C agent on ``CartPole-v1`` using 4 environments.
model = A2C.load("a2c_cartpole")
obs = env.reset()
obs = vec_env.reset()
while True:
action, _states = model.predict(obs)
obs, rewards, dones, info = env.step(action)
env.render()
obs, rewards, dones, info = vec_env.step(action)
vec_env.render("human")
.. note::

View file

@ -67,7 +67,7 @@ This example is only to demonstrate the use of the library and its functions, an
from stable_baselines3 import DDPG
from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise
env = gym.make("Pendulum-v1")
env = gym.make("Pendulum-v1", render_mode="rgb_array")
# The noise objects for DDPG
n_actions = env.action_space.shape[-1]
@ -76,17 +76,17 @@ This example is only to demonstrate the use of the library and its functions, an
model = DDPG("MlpPolicy", env, action_noise=action_noise, verbose=1)
model.learn(total_timesteps=10000, log_interval=10)
model.save("ddpg_pendulum")
env = model.get_env()
vec_env = model.get_env()
del model # remove to demonstrate saving and loading
model = DDPG.load("ddpg_pendulum")
obs = env.reset()
obs = vec_env.reset()
while True:
action, _states = model.predict(obs)
obs, rewards, dones, info = env.step(action)
env.render()
obs, rewards, dones, info = vec_env.step(action)
env.render("human")
Results
-------

View file

@ -60,7 +60,7 @@ This example is only to demonstrate the use of the library and its functions, an
from stable_baselines3 import DQN
env = gym.make("CartPole-v1")
env = gym.make("CartPole-v1", render_mode="human")
model = DQN("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10000, log_interval=4)
@ -70,13 +70,12 @@ This example is only to demonstrate the use of the library and its functions, an
model = DQN.load("dqn_cartpole")
obs = env.reset()
obs, info = env.reset()
while True:
action, _states = model.predict(obs, deterministic=True)
obs, reward, done, info = env.step(action)
env.render()
if done:
obs = env.reset()
obs, reward, terminated, truncated, info = env.step(action)
if terminated or truncated:
obs, info = env.reset()
Results

View file

@ -65,7 +65,6 @@ This example is only to demonstrate the use of the library and its functions, an
from stable_baselines3 import HerReplayBuffer, DDPG, DQN, SAC, TD3
from stable_baselines3.her.goal_selection_strategy import GoalSelectionStrategy
from stable_baselines3.common.envs import BitFlippingEnv
from stable_baselines3.common.vec_env import DummyVecEnv
model_class = DQN # works also with SAC, DDPG and TD3
N_BITS = 15
@ -96,13 +95,12 @@ This example is only to demonstrate the use of the library and its functions, an
# HER must be loaded with the env
model = model_class.load("./her_bit_env", env=env)
obs = env.reset()
obs, info = env.reset()
for _ in range(100):
action, _ = model.predict(obs, deterministic=True)
obs, reward, done, _ = env.step(action)
if done:
obs = env.reset()
obs, reward, terminated, truncated, _ = env.step(action)
if terminated or truncated:
obs, info = env.reset()
Results

View file

@ -71,9 +71,9 @@ Train a PPO agent on ``CartPole-v1`` using 4 environments.
from stable_baselines3.common.env_util import make_vec_env
# Parallel environments
env = make_vec_env("CartPole-v1", n_envs=4)
vec_env = make_vec_env("CartPole-v1", n_envs=4)
model = PPO("MlpPolicy", env, verbose=1)
model = PPO("MlpPolicy", vec_env, verbose=1)
model.learn(total_timesteps=25000)
model.save("ppo_cartpole")
@ -81,11 +81,11 @@ Train a PPO agent on ``CartPole-v1`` using 4 environments.
model = PPO.load("ppo_cartpole")
obs = env.reset()
obs = vec_env.reset()
while True:
action, _states = model.predict(obs)
obs, rewards, dones, info = env.step(action)
env.render()
obs, rewards, dones, info = vec_env.step(action)
vec_env.render("human")
Results

View file

@ -69,11 +69,10 @@ This example is only to demonstrate the use of the library and its functions, an
.. code-block:: python
import gymnasium as gym
import numpy as np
from stable_baselines3 import SAC
env = gym.make("Pendulum-v1")
env = gym.make("Pendulum-v1", render_mode="human")
model = SAC("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10000, log_interval=4)
@ -83,13 +82,12 @@ This example is only to demonstrate the use of the library and its functions, an
model = SAC.load("sac_pendulum")
obs = env.reset()
obs, info = env.reset()
while True:
action, _states = model.predict(obs, deterministic=True)
obs, reward, done, info = env.step(action)
env.render()
if done:
obs = env.reset()
obs, reward, terminated, truncated, info = env.step(action)
if terminated or truncated:
obs, info = env.reset()
Results

View file

@ -67,7 +67,7 @@ This example is only to demonstrate the use of the library and its functions, an
from stable_baselines3 import TD3
from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise
env = gym.make("Pendulum-v1")
env = gym.make("Pendulum-v1", render_mode="rgb_array")
# The noise objects for TD3
n_actions = env.action_space.shape[-1]
@ -76,17 +76,17 @@ This example is only to demonstrate the use of the library and its functions, an
model = TD3("MlpPolicy", env, action_noise=action_noise, verbose=1)
model.learn(total_timesteps=10000, log_interval=10)
model.save("td3_pendulum")
env = model.get_env()
vec_env = model.get_env()
del model # remove to demonstrate saving and loading
model = TD3.load("td3_pendulum")
obs = env.reset()
obs = vec_env.reset()
while True:
action, _states = model.predict(obs)
obs, rewards, dones, info = env.step(action)
env.render()
obs, rewards, dones, info = vec_env.step(action)
vec_env.render("human")
Results
-------

View file

@ -149,7 +149,7 @@ setup(
url="https://github.com/DLR-RM/stable-baselines3",
author_email="antonin.raffin@dlr.de",
keywords="reinforcement-learning-algorithms reinforcement-learning machine-learning "
"gym openai stable baselines toolbox python data-science",
"gymnasium gym openai stable baselines toolbox python data-science",
license="MIT",
long_description=long_description,
long_description_content_type="text/markdown",