From df1e7aa0002b0c380d71d8978b63569a93061c53 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Mon, 28 Oct 2019 17:42:39 +0100 Subject: [PATCH] Add docstring --- torchy_baselines/common/base_class.py | 8 ++++++++ torchy_baselines/ppo/policies.py | 1 - 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py index 0904932..6adc45c 100644 --- a/torchy_baselines/common/base_class.py +++ b/torchy_baselines/common/base_class.py @@ -129,6 +129,14 @@ class BaseRLModel(object): self._current_progress = 1.0 - float(num_timesteps) / float(total_timesteps) def _update_learning_rate(self, optimizers): + """ + Update the optimizers' learning rate using the current learning rate schedule + and the current progress (from 1 to 0). + + :param optimizers: ([th.optim.Optimizer] or Optimizer) An optimizer + or a list of optimizers. + """ + # Log the current learning rate logger.logkv("learning_rate", self.learning_rate(self._current_progress)) if not isinstance(optimizers, list): diff --git a/torchy_baselines/ppo/policies.py b/torchy_baselines/ppo/policies.py index c967962..e973858 100644 --- a/torchy_baselines/ppo/policies.py +++ b/torchy_baselines/ppo/policies.py @@ -148,7 +148,6 @@ class PPOPolicy(BasePolicy): self.value_net: 1 }[module] module.apply(partial(self.init_weights, gain=gain)) - # TODO: support linear decay of the learning rate self.optimizer = th.optim.Adam(self.parameters(), lr=learning_rate(1), eps=self.adam_epsilon) def forward(self, obs, deterministic=False):