diff --git a/README.md b/README.md
index c0398b5..ac13763 100644
--- a/README.md
+++ b/README.md
@@ -36,7 +36,7 @@ you can take a look at the issues [#48](https://github.com/DLR-RM/stable-baselin
| Type hints | :heavy_check_mark: |
-### Planned features (v1.1+)
+### Planned features
Please take a look at the [Roadmap](https://github.com/DLR-RM/stable-baselines3/issues/1) and [Milestones](https://github.com/DLR-RM/stable-baselines3/milestones).
@@ -48,11 +48,13 @@ A migration guide from SB2 to SB3 can be found in the [documentation](https://st
Documentation is available online: [https://stable-baselines3.readthedocs.io/](https://stable-baselines3.readthedocs.io/)
-## RL Baselines3 Zoo: A Collection of Trained RL Agents
+## RL Baselines3 Zoo: A Training Framework for Stable Baselines3 Reinforcement Learning Agents
-[RL Baselines3 Zoo](https://github.com/DLR-RM/rl-baselines3-zoo). is a collection of pre-trained Reinforcement Learning agents using Stable-Baselines3.
+[RL Baselines3 Zoo](https://github.com/DLR-RM/rl-baselines3-zoo) is a training framework for Reinforcement Learning (RL).
-It also provides basic scripts for training, evaluating agents, tuning hyperparameters, plotting results and recording videos.
+It provides scripts for training, evaluating agents, tuning hyperparameters, plotting results and recording videos.
+
+In addition, it includes a collection of tuned hyperparameters for common environments and RL algorithms, and agents trained with those settings.
Goals of this repository:
@@ -110,9 +112,9 @@ import gym
from stable_baselines3 import PPO
-env = gym.make('CartPole-v1')
+env = gym.make("CartPole-v1")
-model = PPO('MlpPolicy', env, verbose=1)
+model = PPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10000)
obs = env.reset()
diff --git a/docs/_static/img/net_arch.png b/docs/_static/img/net_arch.png
new file mode 100644
index 0000000..20143ac
Binary files /dev/null and b/docs/_static/img/net_arch.png differ
diff --git a/docs/_static/img/sb3_loop.png b/docs/_static/img/sb3_loop.png
new file mode 100644
index 0000000..a0bdb3e
Binary files /dev/null and b/docs/_static/img/sb3_loop.png differ
diff --git a/docs/_static/img/sb3_policy.png b/docs/_static/img/sb3_policy.png
new file mode 100644
index 0000000..d79389d
Binary files /dev/null and b/docs/_static/img/sb3_policy.png differ
diff --git a/docs/guide/custom_policy.rst b/docs/guide/custom_policy.rst
index f8aecfe..03829ff 100644
--- a/docs/guide/custom_policy.rst
+++ b/docs/guide/custom_policy.rst
@@ -13,9 +13,49 @@ and other type of input features (MlpPolicies).
which handles bounds more correctly.
+SB3 Policy
+^^^^^^^^^^
-Custom Policy Architecture
-^^^^^^^^^^^^^^^^^^^^^^^^^^
+SB3 networks are separated into two main parts (see figure below):
+
+- A features extractor (usually shared between actor and critic when applicable, to save computation)
+  whose role is to extract features (i.e. convert to a feature vector) from high-dimensional observations, for instance, a CNN that extracts features from images.
+  This is the ``features_extractor_class`` parameter. You can change the default parameters of that features extractor
+  by passing a ``features_extractor_kwargs`` parameter.
+
+- A (fully-connected) network that maps the features to actions/value. Its architecture is controlled by the ``net_arch`` parameter.
+
+
+.. note::
+
+   All observations are first pre-processed (e.g. images are normalized, discrete obs are converted to one-hot vectors, ...) before being fed to the features extractor.
+   In the case of vector observations, the features extractor is just a ``Flatten`` layer.
+
+
+.. image:: ../_static/img/net_arch.png
+
+
+SB3 policies are usually composed of several networks (actor/critic networks + target networks when applicable) together
+with the associated optimizers.
+
+Each of these networks has a features extractor followed by a fully-connected network.
+
+.. note::
+
+   When we refer to "policy" in Stable-Baselines3, this is usually an abuse of language compared to RL terminology.
+   In SB3, "policy" refers to the class that handles all the networks useful for training,
+   so not only the network used to predict actions (the "learned controller").
+
+
+.. image:: ../_static/img/sb3_policy.png
+
+
+.. .. figure:: https://cdn-images-1.medium.com/max/960/1*h4WTQNVIsvMXJTCpXm_TAw.gif
+
+
+Custom Network Architecture
+^^^^^^^^^^^^^^^^^^^^^^^^^^^
One way of customising the policy network architecture is to pass arguments when creating the model,
using ``policy_kwargs`` parameter:
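To see the new ``policy_kwargs`` documentation in action, here is a minimal sketch of the pattern it describes, written against the SB3 1.0 API (the layer sizes, activation and timestep budget are illustrative values, not taken from this patch):

```python
import gym
import torch as th

from stable_baselines3 import PPO

# Illustrative architecture: one shared 32-unit layer, then separate
# policy (pi) and value (vf) heads of 32 units each, with ReLU activations.
policy_kwargs = dict(
    activation_fn=th.nn.ReLU,
    net_arch=[32, dict(pi=[32], vf=[32])],
)

env = gym.make("CartPole-v1")
model = PPO("MlpPolicy", env, policy_kwargs=policy_kwargs, verbose=1)
model.learn(total_timesteps=10000)
```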
diff --git a/docs/guide/developer.rst b/docs/guide/developer.rst
index d930594..e69e510 100644
--- a/docs/guide/developer.rst
+++ b/docs/guide/developer.rst
@@ -31,6 +31,9 @@ Each algorithm has two main methods:
- ``.train()`` which updates the parameters using samples from the buffer
+.. image:: ../_static/img/sb3_loop.png
+
+
Where to start?
===============
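The new ``sb3_loop.png`` figure illustrates how data collection and gradient updates alternate during learning. Schematically, the loop looks like the sketch below (simplified pseudo-Python for orientation only: the helper name ``learn_loop`` is made up, and the real methods take more arguments and handle callbacks, logging and schedules):

```python
def learn_loop(algo, total_timesteps: int) -> None:
    """Schematic of the training loop; not the actual SB3 implementation."""
    while algo.num_timesteps < total_timesteps:
        algo.collect_rollouts()  # step the environment(s), fill the rollout/replay buffer
        algo.train()             # update the parameters using samples from that buffer
```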
diff --git a/docs/guide/migration.rst b/docs/guide/migration.rst
index fa23584..9b25f95 100644
--- a/docs/guide/migration.rst
+++ b/docs/guide/migration.rst
@@ -98,7 +98,7 @@ Base-class (all algorithms)
Policies
^^^^^^^^
-- ``cnn_extractor`` -> ``feature_extractor``, as ``feature_extractor`` in now used with ``MlpPolicy`` too
+- ``cnn_extractor`` -> ``features_extractor``, as ``features_extractor`` is now used with ``MlpPolicy`` too
A2C
^^^
diff --git a/docs/guide/rl_zoo.rst b/docs/guide/rl_zoo.rst
index c592978..9b25573 100644
--- a/docs/guide/rl_zoo.rst
+++ b/docs/guide/rl_zoo.rst
@@ -4,9 +4,11 @@ RL Baselines3 Zoo
==================
-`RL Baselines3 Zoo <https://github.com/DLR-RM/rl-baselines3-zoo>`_. is a collection of pre-trained Reinforcement Learning agents using
-Stable-Baselines3.
-It also provides basic scripts for training, evaluating agents, tuning hyperparameters and recording videos.
+`RL Baselines3 Zoo <https://github.com/DLR-RM/rl-baselines3-zoo>`_ is a training framework for Reinforcement Learning (RL).
+
+It provides scripts for training, evaluating agents, tuning hyperparameters, plotting results and recording videos.
+
+In addition, it includes a collection of tuned hyperparameters for common environments and RL algorithms, and agents trained with those settings.
Goals of this repository:
diff --git a/docs/index.rst b/docs/index.rst
index 61ac1d5..4f41153 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -12,9 +12,9 @@ It is the next major version of `Stable Baselines <https://github.com/hill-a/stable-baselines>`_
diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst
--- a/docs/misc/changelog.rst
+++ b/docs/misc/changelog.rst
@@ ... @@
+.. warning::
+
+  A refactoring of the ``HER`` algorithm is planned together with support for dictionary observations
+  (see `#351 <https://github.com/DLR-RM/stable-baselines3/issues/351>`_).
+  This will be a backward incompatible change (model trained with previous version of ``HER`` won't work with the new version).
+
+
+
New Features:
^^^^^^^^^^^^^
- Added support for ``custom_objects`` when loading models
@@ -24,7 +33,9 @@ Documentation:
- Added new project using SB3: rl_reach (@PierreExeter)
- Added note about slow-down when switching to PyTorch
- Add a note on continual learning and resetting environment
-
+- Updated RL-Zoo to reflect the fact that it is more than a collection of trained agents
+- Added images to illustrate the training loop and custom policies (created with https://excalidraw.com/)
+- Updated the custom policy section

Pre-Release 0.11.1 (2021-02-27)
-------------------------------
diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt
index db805d3..d3827e7 100644
--- a/stable_baselines3/version.txt
+++ b/stable_baselines3/version.txt
@@ -1 +1 @@
-1.0rc2
+1.0
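As a closing usage note, the ``custom_objects`` support listed under New Features can be exercised as in the sketch below (the checkpoint name and the overridden value are hypothetical; ``custom_objects`` replaces the named saved attributes instead of deserializing them from the file):

```python
from stable_baselines3 import PPO

# "ppo_cartpole" is a hypothetical checkpoint path; the learning rate here
# overrides whatever value was stored in the saved model.
model = PPO.load("ppo_cartpole", custom_objects={"learning_rate": 1e-4})
```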