From e747e7e2b3f92a6650a7d69fd839544db4c5ffba Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Wed, 2 Dec 2020 14:54:18 +0100 Subject: [PATCH] Add learning rate schedule example (#248) * Add learning rate schedule example * Update docs/guide/examples.rst Co-authored-by: Adam Gleave * Address comments Co-authored-by: Adam Gleave --- docs/guide/examples.rst | 44 ++++++++++++++++++++++++++++++++++++ docs/misc/changelog.rst | 1 + stable_baselines3/a2c/a2c.py | 1 + stable_baselines3/dqn/dqn.py | 2 +- 4 files changed, 47 insertions(+), 1 deletion(-) diff --git a/docs/guide/examples.rst b/docs/guide/examples.rst index 4b85e02..ed04148 100644 --- a/docs/guide/examples.rst +++ b/docs/guide/examples.rst @@ -422,6 +422,50 @@ The parking env is a goal-conditioned continuous control task, in which the vehi obs = env.reset() +Learning Rate Schedule +---------------------- + +All algorithms allow you to pass a learning rate schedule that takes as input the current progress remaining (from 1 to 0). +``PPO``'s ``clip_range``` parameter also accepts such schedule. + +The `RL Zoo `_ already includes +linear and constant schedules. + + +.. code-block:: python + + from typing import Callable + + from stable_baselines3 import PPO + + + def linear_schedule(initial_value: float) -> Callable[[float], float]: + """ + Linear learning rate schedule. + + :param initial_value: Initial learning rate. + :return: schedule that computes + current learning rate depending on remaining progress + """ + def func(progress_remaining: float) -> float: + """ + Progress will decrease from 1 (beginning) to 0. + + :param progress_remaining: + :return: current learning rate + """ + return progress_remaining * initial_value + + return func + + # Initial learning rate of 0.001 + model = PPO("MlpPolicy", "CartPole-v1", learning_rate=linear_schedule(0.001), verbose=1) + model.learn(total_timesteps=20000) + # By default, `reset_num_timesteps` is True, in which case the learning rate schedule resets. + # progress_remaining = 1.0 - (num_timesteps / total_timesteps) + model.learn(total_timesteps=10000, reset_num_timesteps=True) + + Advanced Saving and Loading --------------------------------- diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index b154e49..77b02cd 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -50,6 +50,7 @@ Documentation: - Fix migration doc for ``A2C`` (epsilon parameter) - Fix ``clip_range`` docstring - Fix duplicated parameter in ``EvalCallback`` docstring (thanks @tfederico) +- Added example of learning rate schedule Pre-Release 0.10.0 (2020-10-28) ------------------------------- diff --git a/stable_baselines3/a2c/a2c.py b/stable_baselines3/a2c/a2c.py index b78bf22..d2dd7f1 100644 --- a/stable_baselines3/a2c/a2c.py +++ b/stable_baselines3/a2c/a2c.py @@ -24,6 +24,7 @@ class A2C(OnPolicyAlgorithm): :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...) :param env: The environment to learn from (if registered in Gym, can be str) :param learning_rate: The learning rate, it can be a function + of the current progress remaining (from 1 to 0) :param n_steps: The number of steps to run for each environment per update (i.e. batch size is n_steps * n_env where n_env is number of environment copies running in parallel) :param gamma: Discount factor diff --git a/stable_baselines3/dqn/dqn.py b/stable_baselines3/dqn/dqn.py index f715a15..180651a 100644 --- a/stable_baselines3/dqn/dqn.py +++ b/stable_baselines3/dqn/dqn.py @@ -22,7 +22,7 @@ class DQN(OffPolicyAlgorithm): :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...) :param env: The environment to learn from (if registered in Gym, can be str) :param learning_rate: The learning rate, it can be a function - of the current progress (from 1 to 0) + of the current progress remaining (from 1 to 0) :param buffer_size: size of the replay buffer :param learning_starts: how many steps of the model to collect transitions for before learning starts :param batch_size: Minibatch size for each gradient update