stable-baselines3/docs/guide/examples.rst

707 lines
24 KiB
ReStructuredText
Raw Normal View History

2020-05-07 14:08:23 +00:00
.. _examples:
Examples
========
Try it online with Colab Notebooks!
-----------------------------------
All the following examples can be executed online using Google colab |colab|
notebooks:
2020-05-19 08:40:52 +00:00
- `Full Tutorial <https://github.com/araffin/rl-tutorial-jnrr19/tree/sb3>`_
2020-05-07 14:08:23 +00:00
- `All Notebooks <https://github.com/Stable-Baselines-Team/rl-colab-notebooks/tree/sb3>`_
- `Getting Started`_
2020-05-08 10:14:33 +00:00
- `Training, Saving, Loading`_
2020-06-17 10:47:09 +00:00
- `Multiprocessing`_
- `Monitor Training and Plotting`_
- `Atari Games`_
2020-05-08 09:58:43 +00:00
- `RL Baselines zoo`_
- `PyBullet`_
Implement HER (#120) * Added working her version, Online sampling is missing. * Updated test_her. * Added first version of online her sampling. Still problems with tensor dimensions. * Reformat * Fixed tests * Added some comments. * Updated changelog. * Add missing init file * Fixed some small bugs. * Reduced arguments for HER, small changes. * Added getattr. Fixed bug for online sampling. * Updated save/load funtions. Small changes. * Added her to init. * Updated save method. * Updated her ratio. * Move obs_wrapper * Added DQN test. * Fix potential bug * Offline and online her share same sample_goal function. * Changed lists into arrays. * Updated her test. * Fix online sampling * Fixed action bug. Updated time limit for episodes. * Updated convert_dict method to take keys as arguments. * Renamed obs dict wrapper. * Seed bit flipping env * Remove get_episode_dict * Add fast online sampling version * Added documentation. * Vectorized reward computation * Vectorized goal sampling * Update time limit for episodes in online her sampling. * Fix max episode length inference * Bug fix for Fetch envs * Fix for HER + gSDE * Reformat (new black version) * Added info dict to compute new reward. Check her_replay_buffer again. * Fix info buffer * Updated done flag. * Fixes for gSDE * Offline her version uses now HerReplayBuffer as episode storage. * Fix num_timesteps computation * Fix get torch params * Vectorized version for offline sampling. * Modified offline her sampling to use sample method of her_replay_buffer * Updated HER tests. * Updated documentation * Cleanup docstrings * Updated to review comments * Fix pytype * Update according to review comments. * Removed random goal strategy. Updated sample transitions. * Updated migration. Removed time signal removal. * Update doc * Fix potential load issue * Add VecNormalize support for dict obs * Updated saving/loading replay buffer for HER. * Fix test memory usage * Fixed save/load replay buffer. * Fixed save/load replay buffer * Fixed transition index after loading replay buffer in online sampling * Better error handling * Add tests for get_time_limit * More tests for VecNormalize with dict obs * Update doc * Improve HER description * Add test for sde support * Add comments * Add comments * Remove check that was always valid * Fix for terminal observation * Updated buffer size in offline version and reset of HER buffer * Reformat * Update doc * Remove np.empty + add doc * Fix loading * Updated loading replay buffer * Separate online and offline sampling + bug fixes * Update tensorboard log name * Version bump * Bug fix for special case Co-authored-by: Antonin Raffin <antonin.raffin@dlr.de> Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
2020-10-22 09:56:43 +00:00
- `Hindsight Experience Replay`_
- `Advanced Saving and Loading`_
2020-05-07 14:08:23 +00:00
.. _Getting Started: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/stable_baselines_getting_started.ipynb
2020-05-08 09:58:43 +00:00
.. _Training, Saving, Loading: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/saving_loading_dqn.ipynb
.. _Multiprocessing: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/multiprocessing_rl.ipynb
.. _Monitor Training and Plotting: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/monitor_training.ipynb
.. _Atari Games: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/atari_games.ipynb
.. _Hindsight Experience Replay: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/stable_baselines_her.ipynb
.. _RL Baselines zoo: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/rl-baselines-zoo.ipynb
.. _PyBullet: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/pybullet.ipynb
.. _Advanced Saving and Loading: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/advanced_saving_loading.ipynb
2020-05-07 14:08:23 +00:00
.. |colab| image:: ../_static/img/colab.svg
Basic Usage: Training, Saving, Loading
--------------------------------------
Implement DQN (#28) * Created DQN template according to the paper. Next steps: - Create Policy - Complete Training - Debug * Changed Base Class * refactor save, to be consistence with overriding the excluded_save_params function. Do not try to exclude the parameters twice. * Added simple DQN policy * Finished learn and train function - missing correct loss computation * changed collect_rollouts to work with discrete space * moved discrete space collect_rollouts to dqn * basic dqn working * deleted SDE related code * added gradient clipping and moved greedy policy to policy * changed policy to implement target network and added soft update(in fact standart tau is 1 so hard update) * fixed policy setup * rebase target_update_intervall on _n_updates * adapted all tests all tests passing * Move to stable-baseline3 * Fixes for DQN * Fix tests + add CNNPolicy * Allow any optimizer for DQN * added some util functions to create a arbitrary linear schedule, fixed pickle problem with old exploration schedule * more documentation * changed buffer dtype * refactor and document * Added Sphinx Documentation Updated changelog.rst * removed custom collect_rollouts as it is no longer necessary * Implemented suggestions to clean code and documentation. * extracted some functions on tests to reduce duplicated code * added support for exploration_fraction * Fixed exploration_fraction * Added documentation * Fixed get_linear_fn -> proper progress scaling * Merged master * Added nature reference * Changed default parameters to https://www.nature.com/articles/nature14236/tables/1 * Fixed n_updates to be incremented correctly * Correct train_freq * Doc update * added special parameter for DQN in tests * different fix for test_discrete * Update docs/modules/dqn.rst Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org> * Update docs/modules/dqn.rst Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org> * Update docs/modules/dqn.rst Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org> * Added RMSProp in optimizer_kwargs, as described in nature paper * Exploration fraction is inverse of 50.000.000 (total frames) / 1.000.000 (frames with linear schedule) according to nature paper * Changelog update for buffer dtype * standard exlude parameters should be always excluded to assure proper saving only if intentionally included by ``include`` parameter * slightly more iterations on test_discrete to pass the test * added param use_rms_prop instead of mutable default argument * forgot alpha * using huber loss, adam and learning rate 1e-4 * account for train_freq in update_target_network * Added memory check for both buffers * Doc updated for buffer allocation * Added psutil Requirement * Adapted test_identity.py * Fixes with new SB3 version * Fix for tensorboard name * Convert assert to warning and fix tests * Refactor off-policy algorithms * Fixes * test: remove next_obs in replay buffer * Update changelog * Fix tests and use tmp_path where possible * Fix sampling bug in buffer * Do not store next obs on episode termination * Fix replay buffer sampling * Update comment * moved epsilon from policy to model * Update predict method * Update atari wrappers to match SB2 * Minor edit in the buffers * Update changelog * Merge branch 'master' into dqn * Update DQN to new structure * Fix tests and remove hardcoded path * Fix for DQN * Disable memory efficient replay buffer by default * Fix docstring * Add tests for memory efficient buffer * Update changelog * Split collect rollout * Move target update outside `train()` for DQN * Update changelog * Update linear schedule doc * Cleanup DQN code * Minor edit * Update version and docker images Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
2020-06-29 09:16:54 +00:00
In the following example, we will train, save and load a DQN model on the Lunar Lander environment.
2020-05-07 14:15:32 +00:00
2020-05-08 10:14:33 +00:00
.. image:: ../_static/img/colab-badge.svg
:target: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/saving_loading_dqn.ipynb
2020-05-07 14:08:23 +00:00
.. figure:: https://cdn-images-1.medium.com/max/960/1*f4VZPKOI0PYNWiwt0la0Rg.gif
Lunar Lander Environment
.. note::
LunarLander requires the python package ``box2d``.
You can install it using ``apt install swig`` and then ``pip install box2d box2d-kengz``
.. .. note::
.. ``load`` function re-creates model from scratch on each call, which can be slow.
.. If you need to e.g. evaluate same model with multiple different sets of parameters, consider
.. using ``load_parameters`` instead.
.. code-block:: python
import gym
Implement DQN (#28) * Created DQN template according to the paper. Next steps: - Create Policy - Complete Training - Debug * Changed Base Class * refactor save, to be consistence with overriding the excluded_save_params function. Do not try to exclude the parameters twice. * Added simple DQN policy * Finished learn and train function - missing correct loss computation * changed collect_rollouts to work with discrete space * moved discrete space collect_rollouts to dqn * basic dqn working * deleted SDE related code * added gradient clipping and moved greedy policy to policy * changed policy to implement target network and added soft update(in fact standart tau is 1 so hard update) * fixed policy setup * rebase target_update_intervall on _n_updates * adapted all tests all tests passing * Move to stable-baseline3 * Fixes for DQN * Fix tests + add CNNPolicy * Allow any optimizer for DQN * added some util functions to create a arbitrary linear schedule, fixed pickle problem with old exploration schedule * more documentation * changed buffer dtype * refactor and document * Added Sphinx Documentation Updated changelog.rst * removed custom collect_rollouts as it is no longer necessary * Implemented suggestions to clean code and documentation. * extracted some functions on tests to reduce duplicated code * added support for exploration_fraction * Fixed exploration_fraction * Added documentation * Fixed get_linear_fn -> proper progress scaling * Merged master * Added nature reference * Changed default parameters to https://www.nature.com/articles/nature14236/tables/1 * Fixed n_updates to be incremented correctly * Correct train_freq * Doc update * added special parameter for DQN in tests * different fix for test_discrete * Update docs/modules/dqn.rst Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org> * Update docs/modules/dqn.rst Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org> * Update docs/modules/dqn.rst Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org> * Added RMSProp in optimizer_kwargs, as described in nature paper * Exploration fraction is inverse of 50.000.000 (total frames) / 1.000.000 (frames with linear schedule) according to nature paper * Changelog update for buffer dtype * standard exlude parameters should be always excluded to assure proper saving only if intentionally included by ``include`` parameter * slightly more iterations on test_discrete to pass the test * added param use_rms_prop instead of mutable default argument * forgot alpha * using huber loss, adam and learning rate 1e-4 * account for train_freq in update_target_network * Added memory check for both buffers * Doc updated for buffer allocation * Added psutil Requirement * Adapted test_identity.py * Fixes with new SB3 version * Fix for tensorboard name * Convert assert to warning and fix tests * Refactor off-policy algorithms * Fixes * test: remove next_obs in replay buffer * Update changelog * Fix tests and use tmp_path where possible * Fix sampling bug in buffer * Do not store next obs on episode termination * Fix replay buffer sampling * Update comment * moved epsilon from policy to model * Update predict method * Update atari wrappers to match SB2 * Minor edit in the buffers * Update changelog * Merge branch 'master' into dqn * Update DQN to new structure * Fix tests and remove hardcoded path * Fix for DQN * Disable memory efficient replay buffer by default * Fix docstring * Add tests for memory efficient buffer * Update changelog * Split collect rollout * Move target update outside `train()` for DQN * Update changelog * Update linear schedule doc * Cleanup DQN code * Minor edit * Update version and docker images Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
2020-06-29 09:16:54 +00:00
from stable_baselines3 import DQN
2020-05-07 14:08:23 +00:00
from stable_baselines3.common.evaluation import evaluate_policy
# Create environment
env = gym.make('LunarLander-v2')
# Instantiate the agent
Implement DQN (#28) * Created DQN template according to the paper. Next steps: - Create Policy - Complete Training - Debug * Changed Base Class * refactor save, to be consistence with overriding the excluded_save_params function. Do not try to exclude the parameters twice. * Added simple DQN policy * Finished learn and train function - missing correct loss computation * changed collect_rollouts to work with discrete space * moved discrete space collect_rollouts to dqn * basic dqn working * deleted SDE related code * added gradient clipping and moved greedy policy to policy * changed policy to implement target network and added soft update(in fact standart tau is 1 so hard update) * fixed policy setup * rebase target_update_intervall on _n_updates * adapted all tests all tests passing * Move to stable-baseline3 * Fixes for DQN * Fix tests + add CNNPolicy * Allow any optimizer for DQN * added some util functions to create a arbitrary linear schedule, fixed pickle problem with old exploration schedule * more documentation * changed buffer dtype * refactor and document * Added Sphinx Documentation Updated changelog.rst * removed custom collect_rollouts as it is no longer necessary * Implemented suggestions to clean code and documentation. * extracted some functions on tests to reduce duplicated code * added support for exploration_fraction * Fixed exploration_fraction * Added documentation * Fixed get_linear_fn -> proper progress scaling * Merged master * Added nature reference * Changed default parameters to https://www.nature.com/articles/nature14236/tables/1 * Fixed n_updates to be incremented correctly * Correct train_freq * Doc update * added special parameter for DQN in tests * different fix for test_discrete * Update docs/modules/dqn.rst Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org> * Update docs/modules/dqn.rst Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org> * Update docs/modules/dqn.rst Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org> * Added RMSProp in optimizer_kwargs, as described in nature paper * Exploration fraction is inverse of 50.000.000 (total frames) / 1.000.000 (frames with linear schedule) according to nature paper * Changelog update for buffer dtype * standard exlude parameters should be always excluded to assure proper saving only if intentionally included by ``include`` parameter * slightly more iterations on test_discrete to pass the test * added param use_rms_prop instead of mutable default argument * forgot alpha * using huber loss, adam and learning rate 1e-4 * account for train_freq in update_target_network * Added memory check for both buffers * Doc updated for buffer allocation * Added psutil Requirement * Adapted test_identity.py * Fixes with new SB3 version * Fix for tensorboard name * Convert assert to warning and fix tests * Refactor off-policy algorithms * Fixes * test: remove next_obs in replay buffer * Update changelog * Fix tests and use tmp_path where possible * Fix sampling bug in buffer * Do not store next obs on episode termination * Fix replay buffer sampling * Update comment * moved epsilon from policy to model * Update predict method * Update atari wrappers to match SB2 * Minor edit in the buffers * Update changelog * Merge branch 'master' into dqn * Update DQN to new structure * Fix tests and remove hardcoded path * Fix for DQN * Disable memory efficient replay buffer by default * Fix docstring * Add tests for memory efficient buffer * Update changelog * Split collect rollout * Move target update outside `train()` for DQN * Update changelog * Update linear schedule doc * Cleanup DQN code * Minor edit * Update version and docker images Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
2020-06-29 09:16:54 +00:00
model = DQN('MlpPolicy', env, verbose=1)
2020-05-07 14:08:23 +00:00
# Train the agent
model.learn(total_timesteps=int(2e5))
# Save the agent
Implement DQN (#28) * Created DQN template according to the paper. Next steps: - Create Policy - Complete Training - Debug * Changed Base Class * refactor save, to be consistence with overriding the excluded_save_params function. Do not try to exclude the parameters twice. * Added simple DQN policy * Finished learn and train function - missing correct loss computation * changed collect_rollouts to work with discrete space * moved discrete space collect_rollouts to dqn * basic dqn working * deleted SDE related code * added gradient clipping and moved greedy policy to policy * changed policy to implement target network and added soft update(in fact standart tau is 1 so hard update) * fixed policy setup * rebase target_update_intervall on _n_updates * adapted all tests all tests passing * Move to stable-baseline3 * Fixes for DQN * Fix tests + add CNNPolicy * Allow any optimizer for DQN * added some util functions to create a arbitrary linear schedule, fixed pickle problem with old exploration schedule * more documentation * changed buffer dtype * refactor and document * Added Sphinx Documentation Updated changelog.rst * removed custom collect_rollouts as it is no longer necessary * Implemented suggestions to clean code and documentation. * extracted some functions on tests to reduce duplicated code * added support for exploration_fraction * Fixed exploration_fraction * Added documentation * Fixed get_linear_fn -> proper progress scaling * Merged master * Added nature reference * Changed default parameters to https://www.nature.com/articles/nature14236/tables/1 * Fixed n_updates to be incremented correctly * Correct train_freq * Doc update * added special parameter for DQN in tests * different fix for test_discrete * Update docs/modules/dqn.rst Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org> * Update docs/modules/dqn.rst Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org> * Update docs/modules/dqn.rst Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org> * Added RMSProp in optimizer_kwargs, as described in nature paper * Exploration fraction is inverse of 50.000.000 (total frames) / 1.000.000 (frames with linear schedule) according to nature paper * Changelog update for buffer dtype * standard exlude parameters should be always excluded to assure proper saving only if intentionally included by ``include`` parameter * slightly more iterations on test_discrete to pass the test * added param use_rms_prop instead of mutable default argument * forgot alpha * using huber loss, adam and learning rate 1e-4 * account for train_freq in update_target_network * Added memory check for both buffers * Doc updated for buffer allocation * Added psutil Requirement * Adapted test_identity.py * Fixes with new SB3 version * Fix for tensorboard name * Convert assert to warning and fix tests * Refactor off-policy algorithms * Fixes * test: remove next_obs in replay buffer * Update changelog * Fix tests and use tmp_path where possible * Fix sampling bug in buffer * Do not store next obs on episode termination * Fix replay buffer sampling * Update comment * moved epsilon from policy to model * Update predict method * Update atari wrappers to match SB2 * Minor edit in the buffers * Update changelog * Merge branch 'master' into dqn * Update DQN to new structure * Fix tests and remove hardcoded path * Fix for DQN * Disable memory efficient replay buffer by default * Fix docstring * Add tests for memory efficient buffer * Update changelog * Split collect rollout * Move target update outside `train()` for DQN * Update changelog * Update linear schedule doc * Cleanup DQN code * Minor edit * Update version and docker images Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
2020-06-29 09:16:54 +00:00
model.save("dqn_lunar")
2020-05-07 14:08:23 +00:00
del model # delete trained model to demonstrate loading
# Load the trained agent
model = DQN.load("dqn_lunar", env=env)
2020-05-07 14:08:23 +00:00
# Evaluate the agent
# NOTE: If you use wrappers with your environment that modify rewards,
# this will be reflected here. To evaluate with original rewards,
# wrap environment in a "Monitor" wrapper before other wrappers.
2020-05-07 14:08:23 +00:00
mean_reward, std_reward = evaluate_policy(model, model.get_env(), n_eval_episodes=10)
# Enjoy trained agent
obs = env.reset()
for i in range(1000):
Implement DQN (#28) * Created DQN template according to the paper. Next steps: - Create Policy - Complete Training - Debug * Changed Base Class * refactor save, to be consistence with overriding the excluded_save_params function. Do not try to exclude the parameters twice. * Added simple DQN policy * Finished learn and train function - missing correct loss computation * changed collect_rollouts to work with discrete space * moved discrete space collect_rollouts to dqn * basic dqn working * deleted SDE related code * added gradient clipping and moved greedy policy to policy * changed policy to implement target network and added soft update(in fact standart tau is 1 so hard update) * fixed policy setup * rebase target_update_intervall on _n_updates * adapted all tests all tests passing * Move to stable-baseline3 * Fixes for DQN * Fix tests + add CNNPolicy * Allow any optimizer for DQN * added some util functions to create a arbitrary linear schedule, fixed pickle problem with old exploration schedule * more documentation * changed buffer dtype * refactor and document * Added Sphinx Documentation Updated changelog.rst * removed custom collect_rollouts as it is no longer necessary * Implemented suggestions to clean code and documentation. * extracted some functions on tests to reduce duplicated code * added support for exploration_fraction * Fixed exploration_fraction * Added documentation * Fixed get_linear_fn -> proper progress scaling * Merged master * Added nature reference * Changed default parameters to https://www.nature.com/articles/nature14236/tables/1 * Fixed n_updates to be incremented correctly * Correct train_freq * Doc update * added special parameter for DQN in tests * different fix for test_discrete * Update docs/modules/dqn.rst Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org> * Update docs/modules/dqn.rst Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org> * Update docs/modules/dqn.rst Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org> * Added RMSProp in optimizer_kwargs, as described in nature paper * Exploration fraction is inverse of 50.000.000 (total frames) / 1.000.000 (frames with linear schedule) according to nature paper * Changelog update for buffer dtype * standard exlude parameters should be always excluded to assure proper saving only if intentionally included by ``include`` parameter * slightly more iterations on test_discrete to pass the test * added param use_rms_prop instead of mutable default argument * forgot alpha * using huber loss, adam and learning rate 1e-4 * account for train_freq in update_target_network * Added memory check for both buffers * Doc updated for buffer allocation * Added psutil Requirement * Adapted test_identity.py * Fixes with new SB3 version * Fix for tensorboard name * Convert assert to warning and fix tests * Refactor off-policy algorithms * Fixes * test: remove next_obs in replay buffer * Update changelog * Fix tests and use tmp_path where possible * Fix sampling bug in buffer * Do not store next obs on episode termination * Fix replay buffer sampling * Update comment * moved epsilon from policy to model * Update predict method * Update atari wrappers to match SB2 * Minor edit in the buffers * Update changelog * Merge branch 'master' into dqn * Update DQN to new structure * Fix tests and remove hardcoded path * Fix for DQN * Disable memory efficient replay buffer by default * Fix docstring * Add tests for memory efficient buffer * Update changelog * Split collect rollout * Move target update outside `train()` for DQN * Update changelog * Update linear schedule doc * Cleanup DQN code * Minor edit * Update version and docker images Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
2020-06-29 09:16:54 +00:00
action, _states = model.predict(obs, deterministic=True)
2020-05-07 14:08:23 +00:00
obs, rewards, dones, info = env.step(action)
env.render()
Multiprocessing: Unleashing the Power of Vectorized Environments
----------------------------------------------------------------
2020-06-17 10:47:09 +00:00
.. image:: ../_static/img/colab-badge.svg
:target: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/multiprocessing_rl.ipynb
2020-05-07 14:08:23 +00:00
.. figure:: https://cdn-images-1.medium.com/max/960/1*h4WTQNVIsvMXJTCpXm_TAw.gif
CartPole Environment
.. code-block:: python
import gym
import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import SubprocVecEnv
from stable_baselines3.common.env_util import make_vec_env
2020-05-07 14:08:23 +00:00
from stable_baselines3.common.utils import set_random_seed
def make_env(env_id, rank, seed=0):
"""
Utility function for multiprocessed env.
:param env_id: (str) the environment ID
:param num_env: (int) the number of environments you wish to have in subprocesses
:param seed: (int) the inital seed for RNG
:param rank: (int) index of the subprocess
"""
def _init():
env = gym.make(env_id)
env.seed(seed + rank)
return env
set_random_seed(seed)
return _init
if __name__ == '__main__':
env_id = "CartPole-v1"
num_cpu = 4 # Number of processes to use
# Create the vectorized environment
env = SubprocVecEnv([make_env(env_id, i) for i in range(num_cpu)])
# Stable Baselines provides you with make_vec_env() helper
# which does exactly the previous steps for you:
# env = make_vec_env(env_id, n_envs=num_cpu, seed=0)
model = PPO('MlpPolicy', env, verbose=1)
model.learn(total_timesteps=25000)
obs = env.reset()
for _ in range(1000):
action, _states = model.predict(obs)
obs, rewards, dones, info = env.step(action)
env.render()
Using Callback: Monitoring Training
-----------------------------------
.. note::
We recommend reading the `Callback section <callbacks.html>`_
You can define a custom callback function that will be called inside the agent.
This could be useful when you want to monitor training, for instance display live
learning curves in Tensorboard (or in Visdom) or save the best agent.
If your callback returns False, training is aborted early.
2020-06-17 10:47:09 +00:00
.. image:: ../_static/img/colab-badge.svg
:target: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/monitor_training.ipynb
2020-05-07 14:08:23 +00:00
.. code-block:: python
import os
import gym
import numpy as np
import matplotlib.pyplot as plt
from stable_baselines3 import TD3
from stable_baselines3.common import results_plotter
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.results_plotter import load_results, ts2xy, plot_results
from stable_baselines3.common.noise import NormalActionNoise
from stable_baselines3.common.callbacks import BaseCallback
class SaveOnBestTrainingRewardCallback(BaseCallback):
"""
Callback for saving a model (the check is done every ``check_freq`` steps)
based on the training reward (in practice, we recommend using ``EvalCallback``).
:param check_freq: (int)
:param log_dir: (str) Path to the folder where the model will be saved.
It must contains the file created by the ``Monitor`` wrapper.
:param verbose: (int)
"""
def __init__(self, check_freq: int, log_dir: str, verbose=1):
super(SaveOnBestTrainingRewardCallback, self).__init__(verbose)
self.check_freq = check_freq
self.log_dir = log_dir
self.save_path = os.path.join(log_dir, 'best_model')
self.best_mean_reward = -np.inf
def _init_callback(self) -> None:
# Create folder if needed
if self.save_path is not None:
os.makedirs(self.save_path, exist_ok=True)
def _on_step(self) -> bool:
if self.n_calls % self.check_freq == 0:
# Retrieve training reward
x, y = ts2xy(load_results(self.log_dir), 'timesteps')
if len(x) > 0:
# Mean training reward over the last 100 episodes
mean_reward = np.mean(y[-100:])
if self.verbose > 0:
print("Num timesteps: {}".format(self.num_timesteps))
print("Best mean reward: {:.2f} - Last mean reward per episode: {:.2f}".format(self.best_mean_reward, mean_reward))
# New best model, you could save the agent here
if mean_reward > self.best_mean_reward:
self.best_mean_reward = mean_reward
# Example for saving best model
if self.verbose > 0:
print("Saving new best model to {}".format(self.save_path))
self.model.save(self.save_path)
return True
# Create log dir
log_dir = "tmp/"
os.makedirs(log_dir, exist_ok=True)
# Create and wrap the environment
env = gym.make('LunarLanderContinuous-v2')
env = Monitor(env, log_dir)
# Add some action noise for exploration
n_actions = env.action_space.shape[-1]
action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))
# Because we use parameter noise, we should use a MlpPolicy with layer normalization
2020-06-17 10:47:09 +00:00
model = TD3('MlpPolicy', env, action_noise=action_noise, verbose=0)
2020-05-07 14:08:23 +00:00
# Create the callback: check every 1000 steps
callback = SaveOnBestTrainingRewardCallback(check_freq=1000, log_dir=log_dir)
# Train the agent
timesteps = 1e5
model.learn(total_timesteps=int(timesteps), callback=callback)
plot_results([log_dir], timesteps, results_plotter.X_TIMESTEPS, "TD3 LunarLander")
plt.show()
Atari Games
-----------
.. figure:: ../_static/img/breakout.gif
Trained A2C agent on Breakout
.. figure:: https://cdn-images-1.medium.com/max/960/1*UHYJE7lF8IDZS_U5SsAFUQ.gif
Pong Environment
Training a RL agent on Atari games is straightforward thanks to ``make_atari_env`` helper function.
It will do `all the preprocessing <https://danieltakeshi.github.io/2016/11/25/frame-skipping-and-preprocessing-for-deep-q-networks-on-atari-2600-games/>`_
and multiprocessing for you.
2020-06-17 10:47:09 +00:00
.. image:: ../_static/img/colab-badge.svg
:target: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/atari_games.ipynb
2020-05-07 14:08:23 +00:00
..
.. code-block:: python
from stable_baselines3.common.env_util import make_atari_env
2020-05-07 14:08:23 +00:00
from stable_baselines3.common.vec_env import VecFrameStack
from stable_baselines3 import A2C
# There already exists an environment generator
# that will make and wrap atari environments correctly.
# Here we are also multi-worker training (n_envs=4 => 4 environments)
env = make_atari_env('PongNoFrameskip-v4', n_envs=4, seed=0)
# Frame-stacking with 4 frames
env = VecFrameStack(env, n_stack=4)
model = A2C('CnnPolicy', env, verbose=1)
model.learn(total_timesteps=25000)
obs = env.reset()
while True:
action, _states = model.predict(obs)
obs, rewards, dones, info = env.step(action)
env.render()
PyBullet: Normalizing input features
------------------------------------
Normalizing input features may be essential to successful training of an RL agent
(by default, images are scaled but not other types of input),
for instance when training on `PyBullet <https://github.com/bulletphysics/bullet3/>`__ environments. For that, a wrapper exists and
2020-05-07 14:08:23 +00:00
will compute a running average and standard deviation of input features (it can do the same for rewards).
.. note::
you need to install pybullet with ``pip install pybullet``
.. image:: ../_static/img/colab-badge.svg
:target: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/pybullet.ipynb
2020-05-07 14:08:23 +00:00
.. code-block:: python
import os
2020-05-07 14:08:23 +00:00
import gym
2020-05-09 12:38:57 +00:00
import pybullet_envs
2020-05-07 14:08:23 +00:00
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
from stable_baselines3 import PPO
env = DummyVecEnv([lambda: gym.make("HalfCheetahBulletEnv-v0")])
2020-05-09 12:38:57 +00:00
# Automatically normalize the input features and reward
2020-05-07 14:08:23 +00:00
env = VecNormalize(env, norm_obs=True, norm_reward=True,
clip_obs=10.)
2020-05-07 14:15:32 +00:00
model = PPO('MlpPolicy', env)
2020-05-07 14:08:23 +00:00
model.learn(total_timesteps=2000)
# Don't forget to save the VecNormalize statistics when saving the agent
log_dir = "/tmp/"
2020-05-09 12:38:57 +00:00
model.save(log_dir + "ppo_halfcheetah")
stats_path = os.path.join(log_dir, "vec_normalize.pkl")
env.save(stats_path)
# To demonstrate loading
del model, env
# Load the saved statistics
env = DummyVecEnv([lambda: gym.make("HalfCheetahBulletEnv-v0")])
env = VecNormalize.load(stats_path, env)
# do not update them at test time
env.training = False
# reward normalization is not needed at test time
env.norm_reward = False
2020-05-07 14:08:23 +00:00
# Load the agent
model = PPO.load(log_dir + "ppo_halfcheetah", env=env)
2020-05-07 14:08:23 +00:00
Implement HER (#120) * Added working her version, Online sampling is missing. * Updated test_her. * Added first version of online her sampling. Still problems with tensor dimensions. * Reformat * Fixed tests * Added some comments. * Updated changelog. * Add missing init file * Fixed some small bugs. * Reduced arguments for HER, small changes. * Added getattr. Fixed bug for online sampling. * Updated save/load funtions. Small changes. * Added her to init. * Updated save method. * Updated her ratio. * Move obs_wrapper * Added DQN test. * Fix potential bug * Offline and online her share same sample_goal function. * Changed lists into arrays. * Updated her test. * Fix online sampling * Fixed action bug. Updated time limit for episodes. * Updated convert_dict method to take keys as arguments. * Renamed obs dict wrapper. * Seed bit flipping env * Remove get_episode_dict * Add fast online sampling version * Added documentation. * Vectorized reward computation * Vectorized goal sampling * Update time limit for episodes in online her sampling. * Fix max episode length inference * Bug fix for Fetch envs * Fix for HER + gSDE * Reformat (new black version) * Added info dict to compute new reward. Check her_replay_buffer again. * Fix info buffer * Updated done flag. * Fixes for gSDE * Offline her version uses now HerReplayBuffer as episode storage. * Fix num_timesteps computation * Fix get torch params * Vectorized version for offline sampling. * Modified offline her sampling to use sample method of her_replay_buffer * Updated HER tests. * Updated documentation * Cleanup docstrings * Updated to review comments * Fix pytype * Update according to review comments. * Removed random goal strategy. Updated sample transitions. * Updated migration. Removed time signal removal. * Update doc * Fix potential load issue * Add VecNormalize support for dict obs * Updated saving/loading replay buffer for HER. * Fix test memory usage * Fixed save/load replay buffer. * Fixed save/load replay buffer * Fixed transition index after loading replay buffer in online sampling * Better error handling * Add tests for get_time_limit * More tests for VecNormalize with dict obs * Update doc * Improve HER description * Add test for sde support * Add comments * Add comments * Remove check that was always valid * Fix for terminal observation * Updated buffer size in offline version and reset of HER buffer * Reformat * Update doc * Remove np.empty + add doc * Fix loading * Updated loading replay buffer * Separate online and offline sampling + bug fixes * Update tensorboard log name * Version bump * Bug fix for special case Co-authored-by: Antonin Raffin <antonin.raffin@dlr.de> Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
2020-10-22 09:56:43 +00:00
Hindsight Experience Replay (HER)
---------------------------------
For this example, we are using `Highway-Env <https://github.com/eleurent/highway-env>`_ by `@eleurent <https://github.com/eleurent>`_.
.. image:: ../_static/img/colab-badge.svg
:target: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/stable_baselines_her.ipynb
.. figure:: https://raw.githubusercontent.com/eleurent/highway-env/gh-media/docs/media/parking-env.gif
The highway-parking-v0 environment.
The parking env is a goal-conditioned continuous control task, in which the vehicle must park in a given space with the appropriate heading.
.. note::
The hyperparameters in the following example were optimized for that environment.
.. code-block:: python
import gym
import highway_env
import numpy as np
from stable_baselines3 import HER, SAC, DDPG, TD3
from stable_baselines3.common.noise import NormalActionNoise
env = gym.make("parking-v0")
# Create 4 artificial transitions per real transition
n_sampled_goal = 4
# SAC hyperparams:
model = HER(
"MlpPolicy",
env,
SAC,
n_sampled_goal=n_sampled_goal,
goal_selection_strategy="future",
# IMPORTANT: because the env is not wrapped with a TimeLimit wrapper
# we have to manually specify the max number of steps per episode
max_episode_length=100,
verbose=1,
buffer_size=int(1e6),
learning_rate=1e-3,
gamma=0.95,
batch_size=256,
online_sampling=True,
policy_kwargs=dict(net_arch=[256, 256, 256]),
)
model.learn(int(2e5))
model.save("her_sac_highway")
# Load saved model
TD3 Code review (#245) * Removed unneeded overrides of feature_extractor and normalize_images in the TD3 Actor. * Add learning rate schedule example (#248) * Add learning rate schedule example * Update docs/guide/examples.rst Co-authored-by: Adam Gleave <adam@gleave.me> * Address comments Co-authored-by: Adam Gleave <adam@gleave.me> * Add supported action spaces checks (#254) * Add supported action spaces checks * Address comment * Use `pass` in an abstractmethod instead of deleting the arguments. * Remove the "deterministic" keyword from the forward method of the TD3 Actor since it always is deterministic anyways. * Rename _get_data to _get_data_to_reconstruct_model. _get_data was too generic and could have meant anything. * Remove the n_episodes_rollout parameter and allow passing tuples as train_freq instead. * Fix docstring of `train_freq` parameter. * Black fixes. * Fix TD3 delayed update + rename `_get_data()` * Fix TD3 test * Normalize `train_freq` to a tuple in the constructor and turn the warning into an assert. * Make one step the default train frequency. * Black fixes. * Change np.bool to bool. * Use the tuple format to specify an amount of steps in terms of steps or episodes in the collect_collouts of the off policy algorithm. * Use the tuple format to specify an amount of steps in terms of steps or episodes in the collect_collouts of HER. * Use named tuple for train freq * Rename train_freq to train_every and TrainFreq to ExperienceDuration. Also add some type annotations and documentation. * Black fixes. * Revert to train_freq * Fix terminal observation issues * Typo * Fix action noise bug in HER * Add assert when loading HER models * Update version Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org> Co-authored-by: Adam Gleave <adam@gleave.me>
2021-02-27 16:33:50 +00:00
# Because it needs access to `env.compute_reward()`
# HER must be loaded with the env
Implement HER (#120) * Added working her version, Online sampling is missing. * Updated test_her. * Added first version of online her sampling. Still problems with tensor dimensions. * Reformat * Fixed tests * Added some comments. * Updated changelog. * Add missing init file * Fixed some small bugs. * Reduced arguments for HER, small changes. * Added getattr. Fixed bug for online sampling. * Updated save/load funtions. Small changes. * Added her to init. * Updated save method. * Updated her ratio. * Move obs_wrapper * Added DQN test. * Fix potential bug * Offline and online her share same sample_goal function. * Changed lists into arrays. * Updated her test. * Fix online sampling * Fixed action bug. Updated time limit for episodes. * Updated convert_dict method to take keys as arguments. * Renamed obs dict wrapper. * Seed bit flipping env * Remove get_episode_dict * Add fast online sampling version * Added documentation. * Vectorized reward computation * Vectorized goal sampling * Update time limit for episodes in online her sampling. * Fix max episode length inference * Bug fix for Fetch envs * Fix for HER + gSDE * Reformat (new black version) * Added info dict to compute new reward. Check her_replay_buffer again. * Fix info buffer * Updated done flag. * Fixes for gSDE * Offline her version uses now HerReplayBuffer as episode storage. * Fix num_timesteps computation * Fix get torch params * Vectorized version for offline sampling. * Modified offline her sampling to use sample method of her_replay_buffer * Updated HER tests. * Updated documentation * Cleanup docstrings * Updated to review comments * Fix pytype * Update according to review comments. * Removed random goal strategy. Updated sample transitions. * Updated migration. Removed time signal removal. * Update doc * Fix potential load issue * Add VecNormalize support for dict obs * Updated saving/loading replay buffer for HER. * Fix test memory usage * Fixed save/load replay buffer. * Fixed save/load replay buffer * Fixed transition index after loading replay buffer in online sampling * Better error handling * Add tests for get_time_limit * More tests for VecNormalize with dict obs * Update doc * Improve HER description * Add test for sde support * Add comments * Add comments * Remove check that was always valid * Fix for terminal observation * Updated buffer size in offline version and reset of HER buffer * Reformat * Update doc * Remove np.empty + add doc * Fix loading * Updated loading replay buffer * Separate online and offline sampling + bug fixes * Update tensorboard log name * Version bump * Bug fix for special case Co-authored-by: Antonin Raffin <antonin.raffin@dlr.de> Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
2020-10-22 09:56:43 +00:00
model = HER.load("her_sac_highway", env=env)
obs = env.reset()
# Evaluate the agent
episode_reward = 0
for _ in range(100):
action, _ = model.predict(obs, deterministic=True)
obs, reward, done, info = env.step(action)
env.render()
episode_reward += reward
if done or info.get("is_success", False):
print("Reward:", episode_reward, "Success?", info.get("is_success", False))
episode_reward = 0.0
obs = env.reset()
Learning Rate Schedule
----------------------
All algorithms allow you to pass a learning rate schedule that takes as input the current progress remaining (from 1 to 0).
``PPO``'s ``clip_range``` parameter also accepts such schedule.
The `RL Zoo <https://github.com/DLR-RM/rl-baselines3-zoo>`_ already includes
linear and constant schedules.
.. code-block:: python
from typing import Callable
from stable_baselines3 import PPO
def linear_schedule(initial_value: float) -> Callable[[float], float]:
"""
Linear learning rate schedule.
:param initial_value: Initial learning rate.
:return: schedule that computes
current learning rate depending on remaining progress
"""
def func(progress_remaining: float) -> float:
"""
Progress will decrease from 1 (beginning) to 0.
:param progress_remaining:
:return: current learning rate
"""
return progress_remaining * initial_value
return func
# Initial learning rate of 0.001
model = PPO("MlpPolicy", "CartPole-v1", learning_rate=linear_schedule(0.001), verbose=1)
model.learn(total_timesteps=20000)
# By default, `reset_num_timesteps` is True, in which case the learning rate schedule resets.
# progress_remaining = 1.0 - (num_timesteps / total_timesteps)
model.learn(total_timesteps=10000, reset_num_timesteps=True)
Advanced Saving and Loading
---------------------------------
In this example, we show how to use some advanced features of Stable-Baselines3 (SB3):
how to easily create a test environment to evaluate an agent periodically,
use a policy independently from a model (and how to save it, load it) and save/load a replay buffer.
By default, the replay buffer is not saved when calling ``model.save()``, in order to save space on the disk (a replay buffer can be up to several GB when using images).
However, SB3 provides a ``save_replay_buffer()`` and ``load_replay_buffer()`` method to save it separately.
Stable-Baselines3 automatic creation of an environment for evaluation.
For that, you only need to specify ``create_eval_env=True`` when passing the Gym ID of the environment while creating the agent.
Behind the scene, SB3 uses an :ref:`EvalCallback <callbacks>`.
.. note::
For training model after loading it, we recommend loading the replay buffer to ensure stable learning (for off-policy algorithms).
You also need to pass ``reset_num_timesteps=True`` to ``learn`` function which initializes the environment
and agent for training if a new environment was created since saving the model.
.. image:: ../_static/img/colab-badge.svg
:target: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/advanced_saving_loading.ipynb
.. code-block:: python
from stable_baselines3 import SAC
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.sac.policies import MlpPolicy
# Create the model, the training environment
# and the test environment (for evaluation)
model = SAC('MlpPolicy', 'Pendulum-v0', verbose=1,
learning_rate=1e-3, create_eval_env=True)
# Evaluate the model every 1000 steps on 5 test episodes
# and save the evaluation to the "logs/" folder
model.learn(6000, eval_freq=1000, n_eval_episodes=5, eval_log_path="./logs/")
# save the model
model.save("sac_pendulum")
# the saved model does not contain the replay buffer
loaded_model = SAC.load("sac_pendulum")
print(f"The loaded_model has {loaded_model.replay_buffer.size()} transitions in its buffer")
# now save the replay buffer too
model.save_replay_buffer("sac_replay_buffer")
# load it into the loaded_model
loaded_model.load_replay_buffer("sac_replay_buffer")
# now the loaded replay is not empty anymore
print(f"The loaded_model has {loaded_model.replay_buffer.size()} transitions in its buffer")
# Save the policy independently from the model
# Note: if you don't save the complete model with `model.save()`
# you cannot continue training afterward
policy = model.policy
policy.save("sac_policy_pendulum")
# Retrieve the environment
env = model.get_env()
# Evaluate the policy
mean_reward, std_reward = evaluate_policy(policy, env, n_eval_episodes=10, deterministic=True)
print(f"mean_reward={mean_reward:.2f} +/- {std_reward}")
# Load the policy independently from the model
saved_policy = MlpPolicy.load("sac_policy_pendulum")
# Evaluate the loaded policy
mean_reward, std_reward = evaluate_policy(saved_policy, env, n_eval_episodes=10, deterministic=True)
print(f"mean_reward={mean_reward:.2f} +/- {std_reward}")
Accessing and modifying model parameters
----------------------------------------
You can access model's parameters via ``load_parameters`` and ``get_parameters`` functions,
or via ``model.policy.state_dict()`` (and ``load_state_dict()``),
which use dictionaries that map variable names to PyTorch tensors.
These functions are useful when you need to e.g. evaluate large set of models with same network structure,
visualize different layers of the network or modify parameters manually.
Policies also offers a simple way to save/load weights as a NumPy vector, using ``parameters_to_vector()``
and ``load_from_vector()`` method.
Following example demonstrates reading parameters, modifying some of them and loading them to model
by implementing `evolution strategy (es) <http://blog.otoro.net/2017/10/29/visual-evolution-strategies/>`_
for solving the ``CartPole-v1`` environment. The initial guess for parameters is obtained by running
A2C policy gradient updates on the model.
.. code-block:: python
from typing import Dict
import gym
import numpy as np
import torch as th
from stable_baselines3 import A2C
from stable_baselines3.common.evaluation import evaluate_policy
def mutate(params: Dict[str, th.Tensor]) -> Dict[str, th.Tensor]:
"""Mutate parameters by adding normal noise to them"""
return dict((name, param + th.randn_like(param)) for name, param in params.items())
# Create policy with a small network
model = A2C(
"MlpPolicy",
"CartPole-v1",
ent_coef=0.0,
policy_kwargs={"net_arch": [32]},
seed=0,
learning_rate=0.05,
)
# Use traditional actor-critic policy gradient updates to
# find good initial parameters
model.learn(total_timesteps=10000)
# Include only variables with "policy", "action" (policy) or "shared_net" (shared layers)
# in their name: only these ones affect the action.
# NOTE: you can retrieve those parameters using model.get_parameters() too
mean_params = dict(
(key, value)
for key, value in model.policy.state_dict().items()
if ("policy" in key or "shared_net" in key or "action" in key)
)
# population size of 50 invdiduals
pop_size = 50
# Keep top 10%
n_elite = pop_size // 10
# Retrieve the environment
env = model.get_env()
for iteration in range(10):
# Create population of candidates and evaluate them
population = []
for population_i in range(pop_size):
candidate = mutate(mean_params)
# Load new policy parameters to agent.
# Tell function that it should only update parameters
# we give it (policy parameters)
model.policy.load_state_dict(candidate, strict=False)
# Evaluate the candidate
fitness, _ = evaluate_policy(model, env)
population.append((candidate, fitness))
# Take top 10% and use average over their parameters as next mean parameter
top_candidates = sorted(population, key=lambda x: x[1], reverse=True)[:n_elite]
mean_params = dict(
(
name,
th.stack([candidate[0][name] for candidate in top_candidates]).mean(dim=0),
)
for name in mean_params.keys()
)
mean_fitness = sum(top_candidate[1] for top_candidate in top_candidates) / n_elite
print(f"Iteration {iteration + 1:<3} Mean top fitness: {mean_fitness:.2f}")
print(f"Best fitness: {top_candidates[0][1]:.2f}")
2020-05-07 14:08:23 +00:00
Record a Video
--------------
Record a mp4 video (here using a random agent).
.. note::
It requires ``ffmpeg`` or ``avconv`` to be installed on the machine.
.. code-block:: python
import gym
from stable_baselines3.common.vec_env import VecVideoRecorder, DummyVecEnv
env_id = 'CartPole-v1'
video_folder = 'logs/videos/'
video_length = 100
env = DummyVecEnv([lambda: gym.make(env_id)])
obs = env.reset()
# Record the video starting at the first step
env = VecVideoRecorder(env, video_folder,
record_video_trigger=lambda x: x == 0, video_length=video_length,
name_prefix="random-agent-{}".format(env_id))
env.reset()
for _ in range(video_length + 1):
action = [env.action_space.sample()]
obs, _, _, _ = env.step(action)
# Save the video
env.close()
Bonus: Make a GIF of a Trained Agent
------------------------------------
.. note::
For Atari games, you need to use a screen recorder such as `Kazam <https://launchpad.net/kazam>`_.
And then convert the video using `ffmpeg <https://superuser.com/questions/556029/how-do-i-convert-a-video-to-gif-using-ffmpeg-with-reasonable-quality>`_
.. code-block:: python
import imageio
import numpy as np
from stable_baselines3 import A2C
model = A2C("MlpPolicy", "LunarLander-v2").learn(100000)
images = []
obs = model.env.reset()
img = model.env.render(mode='rgb_array')
for i in range(350):
images.append(img)
action, _ = model.predict(obs)
obs, _, _ ,_ = model.env.step(action)
img = model.env.render(mode='rgb_array')
imageio.mimsave('lander_a2c.gif', [np.array(img) for i, img in enumerate(images) if i%2 == 0], fps=29)