WARNING: Stable Baselines3 is currently in a beta version; breaking changes may occur before 1.0 is released.
Stable Baselines3
Stable Baselines3 is a set of improved implementations of reinforcement learning algorithms in PyTorch. It is the next major version of Stable Baselines.
You can read a detailed presentation of Stable Baselines in the Medium article.
These algorithms will make it easier for the research community and industry to replicate, refine, and identify new ideas, and will create good baselines to build projects on top of. We expect these tools will be used as a base around which new ideas can be added, and as a tool for comparing a new approach against existing ones. We also hope that the simplicity of these tools will allow beginners to experiment with a more advanced toolset, without being buried in implementation details.
Note: despite its simplicity of use, Stable Baselines3 (SB3) assumes you have some knowledge about Reinforcement Learning (RL). You should not use this library without some practice. To that end, we provide good resources in the documentation to get started with RL.
Main Features
| Features | Stable-Baselines3 |
|---|---|
| State of the art RL methods | ✔️ |
| Documentation | ✔️ |
| Custom environments | ✔️ |
| Custom policies | ✔️ |
| Common interface | ✔️ |
| IPython / Notebook friendly | ✔️ |
| TensorBoard support | ✔️ |
| PEP8 code style | ✔️ |
| Custom callback | ✔️ |
| High code coverage | ✔️ |
| Type hints | ✔️ |
Roadmap to V1.0
Please look at the issue for more details. Planned features:
- [ ] DDPG (you can use its successor TD3 for now)
- [ ] HER
Planned features (v1.1+)
- [ ] DQN extensions (prioritized replay, double Q-learning, ...)
- [ ] Support for Tuple and Dict observation spaces
- [ ] Recurrent Policies
- [ ] TRPO
Migration guide
TODO: migration guide from Stable-Baselines in the documentation
Documentation
Documentation is available online: https://stable-baselines3.readthedocs.io/
RL Baselines3 Zoo: A Collection of Trained RL Agents
RL Baselines3 Zoo is a collection of pre-trained Reinforcement Learning agents using Stable-Baselines3.
It also provides basic scripts for training, evaluating agents, tuning hyperparameters, plotting results and recording videos.
Goals of this repository:
- Provide a simple interface to train and enjoy RL agents
- Benchmark the different Reinforcement Learning algorithms
- Provide tuned hyperparameters for each environment and RL algorithm
- Have fun with the trained agents!
Github repo: https://github.com/DLR-RM/rl-baselines3-zoo
Documentation: https://stable-baselines3.readthedocs.io/en/master/guide/rl_zoo.html
Installation
Note: Stable-Baselines3 supports PyTorch 1.4+.
Prerequisites
Stable Baselines3 requires Python 3.6+.
Windows 10
To install Stable Baselines3 on Windows, please look at the documentation.
Install using pip
Install the Stable Baselines3 package:
pip install stable-baselines3[extra]
This includes optional dependencies like TensorBoard, OpenCV or atari-py to train on Atari games. If you do not need those, you can use:
pip install stable-baselines3
Please read the documentation for more details and alternatives (from source, using docker).
Example
Most of the library tries to follow a sklearn-like syntax for the Reinforcement Learning algorithms.
Here is a quick example of how to train and run PPO on a cartpole environment:
import gym

from stable_baselines3 import PPO

env = gym.make('CartPole-v1')

# Create the agent with a multilayer-perceptron policy and train it
model = PPO('MlpPolicy', env, verbose=1)
model.learn(total_timesteps=10000)

# Run the trained agent, resetting the environment at the end of each episode
obs = env.reset()
for i in range(1000):
    action, _states = model.predict(obs, deterministic=True)
    obs, reward, done, info = env.step(action)
    env.render()
    if done:
        obs = env.reset()

env.close()
Or just train a model with a one-liner if the environment is registered in Gym and if the policy is registered:
from stable_baselines3 import PPO
model = PPO('MlpPolicy', 'CartPole-v1').learn(10000)
Please read the documentation for more examples.
Try it online with Colab Notebooks!
All the following examples can be executed online using Google Colab notebooks:
- Full Tutorial
- All Notebooks
- Getting Started
- Training, Saving, Loading
- Multiprocessing
- Monitor Training and Plotting
- Atari Games
- RL Baselines Zoo
Implemented Algorithms
| Name | Recurrent | Box | Discrete | MultiDiscrete | MultiBinary | Multi Processing |
| A2C | ❌ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ |
| PPO | ❌ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ |
| SAC | ❌ | ✔️ | ❌ | ❌ | ❌ | ❌ |
| TD3 | ❌ | ✔️ | ❌ | ❌ | ❌ | ❌ |
Action spaces (gym.spaces):
- Box: An N-dimensional box that contains every point in the action space.
- Discrete: A list of possible actions, where at each timestep only one of the actions can be used.
- MultiDiscrete: A list of possible actions, where at each timestep only one action of each discrete set can be used.
- MultiBinary: A list of possible actions, where at each timestep any of the actions can be used in any combination.
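The following minimal sketch makes the four action-space definitions concrete by drawing one example action from each kind of space. The helper functions are hypothetical and use only the Python standard library. With gym itself you would instead construct gym.spaces.Box, Discrete, MultiDiscrete or MultiBinary and call the space's sample() method.

```python
import random
from typing import List, Sequence

rng = random.Random(0)  # fixed seed so the sketch is reproducible


def sample_box(low: Sequence[float], high: Sequence[float]) -> List[float]:
    # Box: one real number per dimension, each within [low[i], high[i]].
    return [rng.uniform(lo, hi) for lo, hi in zip(low, high)]


def sample_discrete(n: int) -> int:
    # Discrete(n): a single integer in {0, ..., n - 1}; exactly one action per step.
    return rng.randrange(n)


def sample_multi_discrete(nvec: Sequence[int]) -> List[int]:
    # MultiDiscrete: one integer per discrete sub-set, each chosen independently.
    return [rng.randrange(n) for n in nvec]


def sample_multi_binary(n: int) -> List[int]:
    # MultiBinary(n): n independent on/off flags; any combination is allowed.
    return [rng.randint(0, 1) for _ in range(n)]


box_action = sample_box(low=[-1.0, -1.0], high=[1.0, 1.0])
discrete_action = sample_discrete(3)
multi_discrete_action = sample_multi_discrete([3, 2])
multi_binary_action = sample_multi_binary(4)
print(box_action, discrete_action, multi_discrete_action, multi_binary_action)
```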
Testing the installation
All unit tests in Stable Baselines3 can be run using the pytest runner:
pip install pytest pytest-cov
make pytest
You can also do a static type check using pytype:
pip install pytype
make type
Codestyle check with flake8:
pip install flake8
make lint
Projects Using Stable-Baselines3
We try to maintain a list of projects using Stable-Baselines3 in the documentation; please tell us if you want your project to appear on that page ;)
Citing the Project
To cite this repository in publications:
@misc{stable-baselines3,
author = {Raffin, Antonin and Hill, Ashley and Ernestus, Maximilian and Gleave, Adam and Kanervisto, Anssi and Dormann, Noah},
title = {Stable Baselines3},
year = {2019},
publisher = {GitHub},
journal = {GitHub repository},
howpublished = {\url{https://github.com/DLR-RM/stable-baselines3}},
}
Maintainers
Stable-Baselines3 is currently maintained by Ashley Hill (aka @hill-a), Antonin Raffin (aka @araffin), Maximilian Ernestus (aka @erniejunior), Adam Gleave (@AdamGleave) and Anssi Kanervisto (@Miffyli).
Important Note: We do not provide technical support or consulting, and we do not answer personal questions by email.
How To Contribute
To anyone interested in making the baselines better: there is still some documentation that needs to be done. If you want to contribute, please read the CONTRIBUTING.md guide first.
Acknowledgments
The initial work to develop Stable Baselines3 was partially funded by the project Reduced Complexity Models from the Helmholtz-Gemeinschaft Deutscher Forschungszentren.
The original version, Stable Baselines, was created in the robotics lab U2IS (INRIA Flowers team) at ENSTA ParisTech.
Logo credits: L.M. Tenkes