mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-06-07 00:13:37 +00:00
Add docstring
This commit is contained in:
parent
d67822718c
commit
df1e7aa000
2 changed files with 8 additions and 1 deletions
|
|
@ -129,6 +129,14 @@ class BaseRLModel(object):
|
|||
self._current_progress = 1.0 - float(num_timesteps) / float(total_timesteps)
|
||||
|
||||
def _update_learning_rate(self, optimizers):
|
||||
"""
|
||||
Update the optimizers learning rate using the current learning rate schedule
|
||||
and the current progress (from 1 to 0).
|
||||
|
||||
:param optimizers: ([th.optim.Optimizer] or Optimizer) An optimizer
|
||||
or a list of optimizer.
|
||||
"""
|
||||
# Log the current learning rate
|
||||
logger.logkv("learning_rate", self.learning_rate(self._current_progress))
|
||||
|
||||
if not isinstance(optimizers, list):
|
||||
|
|
|
|||
|
|
@ -148,7 +148,6 @@ class PPOPolicy(BasePolicy):
|
|||
self.value_net: 1
|
||||
}[module]
|
||||
module.apply(partial(self.init_weights, gain=gain))
|
||||
# TODO: support linear decay of the learning rate
|
||||
self.optimizer = th.optim.Adam(self.parameters(), lr=learning_rate(1), eps=self.adam_epsilon)
|
||||
|
||||
def forward(self, obs, deterministic=False):
|
||||
|
|
|
|||
Loading…
Reference in a new issue