Add learning rate schedule example (#248)
* Add learning rate schedule example

* Update docs/guide/examples.rst

Co-authored-by: Adam Gleave <adam@gleave.me>

* Address comments

Co-authored-by: Adam Gleave <adam@gleave.me>

This commit is contained in:
parent 723b341c61
commit e747e7e2b3

4 changed files with 47 additions and 1 deletion

docs/guide/examples.rst

@@ -422,6 +422,50 @@ The parking env is a goal-conditioned continuous control task, in which the vehicle

    obs = env.reset()

Learning Rate Schedule
----------------------

All algorithms allow you to pass a learning rate schedule that takes as input the current progress remaining (from 1 to 0).
``PPO``'s ``clip_range`` parameter also accepts such a schedule.

The `RL Zoo <https://github.com/DLR-RM/rl-baselines3-zoo>`_ already includes
linear and constant schedules.

.. code-block:: python

    from typing import Callable

    from stable_baselines3 import PPO


    def linear_schedule(initial_value: float) -> Callable[[float], float]:
        """
        Linear learning rate schedule.

        :param initial_value: Initial learning rate.
        :return: schedule that computes
          current learning rate depending on remaining progress
        """
        def func(progress_remaining: float) -> float:
            """
            Progress will decrease from 1 (beginning) to 0.

            :param progress_remaining:
            :return: current learning rate
            """
            return progress_remaining * initial_value

        return func


    # Initial learning rate of 0.001
    model = PPO("MlpPolicy", "CartPole-v1", learning_rate=linear_schedule(0.001), verbose=1)
    model.learn(total_timesteps=20000)
    # By default, `reset_num_timesteps` is True, in which case the learning rate schedule resets.
    # progress_remaining = 1.0 - (num_timesteps / total_timesteps)
    model.learn(total_timesteps=10000, reset_num_timesteps=True)
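
As noted above, ``PPO``'s ``clip_range`` accepts the same kind of schedule as ``learning_rate``. A minimal sketch reusing the ``linear_schedule`` helper (the values are illustrative, not tuned):

.. code-block:: python

    # Learning rate decays linearly from 0.001 to 0,
    # clip range decays linearly from 0.2 to 0
    model = PPO(
        "MlpPolicy",
        "CartPole-v1",
        learning_rate=linear_schedule(0.001),
        clip_range=linear_schedule(0.2),
        verbose=1,
    )
    model.learn(total_timesteps=20000)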

Advanced Saving and Loading
---------------------------------

docs/misc/changelog.rst

@@ -50,6 +50,7 @@ Documentation:
- Fix migration doc for ``A2C`` (epsilon parameter)
- Fix ``clip_range`` docstring
- Fix duplicated parameter in ``EvalCallback`` docstring (thanks @tfederico)
- Added example of learning rate schedule

Pre-Release 0.10.0 (2020-10-28)
-------------------------------

stable_baselines3/a2c/a2c.py

@@ -24,6 +24,7 @@ class A2C(OnPolicyAlgorithm):
:param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
:param env: The environment to learn from (if registered in Gym, can be str)
:param learning_rate: The learning rate, it can be a function
    of the current progress remaining (from 1 to 0)
:param n_steps: The number of steps to run for each environment per update
    (i.e. batch size is n_steps * n_env where n_env is number of environment copies running in parallel)
:param gamma: Discount factor
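
For illustration, the ``learning_rate`` parameter described in this docstring accepts either a plain float or such a function of the progress remaining; a minimal sketch (the constant ``7e-4`` is only an illustrative value):

.. code-block:: python

    from stable_baselines3 import A2C

    # Constant learning rate
    model = A2C("MlpPolicy", "CartPole-v1", learning_rate=7e-4)

    # Learning rate as a function of the progress remaining (goes from 1 to 0)
    model = A2C(
        "MlpPolicy",
        "CartPole-v1",
        learning_rate=lambda progress_remaining: progress_remaining * 7e-4,
    )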

stable_baselines3/dqn/dqn.py

@@ -22,7 +22,7 @@ class DQN(OffPolicyAlgorithm):
:param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
:param env: The environment to learn from (if registered in Gym, can be str)
:param learning_rate: The learning rate, it can be a function
-    of the current progress (from 1 to 0)
+    of the current progress remaining (from 1 to 0)
:param buffer_size: size of the replay buffer
:param learning_starts: how many steps of the model to collect transitions for before learning starts
:param batch_size: Minibatch size for each gradient update