Add learning rate schedule example (#248)

* Add learning rate schedule example

* Update docs/guide/examples.rst

Co-authored-by: Adam Gleave <adam@gleave.me>

* Address comments

Co-authored-by: Adam Gleave <adam@gleave.me>
Antonin RAFFIN 2020-12-02 14:54:18 +01:00 committed by GitHub
parent 723b341c61
commit e747e7e2b3
4 changed files with 47 additions and 1 deletion

@@ -422,6 +422,50 @@ The parking env is a goal-conditioned continuous control task, in which the vehi
obs = env.reset()
Learning Rate Schedule
----------------------

All algorithms allow you to pass a learning rate schedule that takes as input the current progress remaining (from 1 to 0).
``PPO``'s ``clip_range`` parameter also accepts such a schedule.

The `RL Zoo <https://github.com/DLR-RM/rl-baselines3-zoo>`_ already includes
linear and constant schedules.

.. code-block:: python

    from typing import Callable

    from stable_baselines3 import PPO


    def linear_schedule(initial_value: float) -> Callable[[float], float]:
        """
        Linear learning rate schedule.

        :param initial_value: Initial learning rate.
        :return: schedule that computes
          current learning rate depending on remaining progress
        """
        def func(progress_remaining: float) -> float:
            """
            Progress will decrease from 1 (beginning) to 0.

            :param progress_remaining:
            :return: current learning rate
            """
            return progress_remaining * initial_value

        return func


    # Initial learning rate of 0.001
    model = PPO("MlpPolicy", "CartPole-v1", learning_rate=linear_schedule(0.001), verbose=1)
    model.learn(total_timesteps=20000)
    # By default, `reset_num_timesteps` is True, in which case the learning rate schedule resets.
    # progress_remaining = 1.0 - (num_timesteps / total_timesteps)
    model.learn(total_timesteps=10000, reset_num_timesteps=True)
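Since ``clip_range`` is only mentioned in the text above, here is a minimal sketch of passing a schedule to it as well, reusing ``linear_schedule`` and the imports from the example above (the initial clip range of 0.2 is just an illustrative value):

.. code-block:: python

    # Sketch: both the learning rate and PPO's clip range can follow a schedule
    model = PPO(
        "MlpPolicy",
        "CartPole-v1",
        learning_rate=linear_schedule(0.001),
        clip_range=linear_schedule(0.2),
        verbose=1,
    )
    model.learn(total_timesteps=20000)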
Advanced Saving and Loading
---------------------------------

@@ -50,6 +50,7 @@ Documentation:
- Fix migration doc for ``A2C`` (epsilon parameter)
- Fix ``clip_range`` docstring
- Fix duplicated parameter in ``EvalCallback`` docstring (thanks @tfederico)
- Added example of learning rate schedule
Pre-Release 0.10.0 (2020-10-28)
-------------------------------

@@ -24,6 +24,7 @@ class A2C(OnPolicyAlgorithm):
:param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
:param env: The environment to learn from (if registered in Gym, can be str)
:param learning_rate: The learning rate, it can be a function
of the current progress remaining (from 1 to 0)
:param n_steps: The number of steps to run for each environment per update
(i.e. batch size is n_steps * n_env where n_env is number of environment copies running in parallel)
:param gamma: Discount factor
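
For illustration, a minimal sketch of what this docstring describes: any callable of the remaining progress works as ``learning_rate`` (the 7e-4 starting value below is just an assumed example, not part of the diff):

from stable_baselines3 import A2C

# The schedule receives progress_remaining, which decreases from 1 (start) to 0 (end of training)
model = A2C(
    "MlpPolicy",
    "CartPole-v1",
    learning_rate=lambda progress_remaining: 7e-4 * progress_remaining,
    verbose=1,
)
model.learn(total_timesteps=10000)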

@@ -22,7 +22,7 @@ class DQN(OffPolicyAlgorithm):
:param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
:param env: The environment to learn from (if registered in Gym, can be str)
:param learning_rate: The learning rate, it can be a function
- of the current progress (from 1 to 0)
+ of the current progress remaining (from 1 to 0)
:param buffer_size: size of the replay buffer
:param learning_starts: how many steps of the model to collect transitions for before learning starts
:param batch_size: Minibatch size for each gradient update
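
The corrected wording is the important part: the schedule is called with the *remaining* progress (1 at the start of training, 0 at the end), not the elapsed progress. A minimal sketch, assuming a helper like the ``linear_schedule`` from the examples doc:

from stable_baselines3 import DQN

def linear_schedule(initial_value: float):
    # Decays linearly as progress_remaining goes from 1 (start) to 0 (end)
    def func(progress_remaining: float) -> float:
        return progress_remaining * initial_value
    return func

model = DQN("MlpPolicy", "CartPole-v1", learning_rate=linear_schedule(1e-3), verbose=1)
model.learn(total_timesteps=10000)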