diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py index 0904932..6adc45c 100644 --- a/torchy_baselines/common/base_class.py +++ b/torchy_baselines/common/base_class.py @@ -129,6 +129,14 @@ class BaseRLModel(object): self._current_progress = 1.0 - float(num_timesteps) / float(total_timesteps) def _update_learning_rate(self, optimizers): + """ + Update the optimizers learning rate using the current learning rate schedule + and the current progress (from 1 to 0). + + :param optimizers: ([th.optim.Optimizer] or Optimizer) An optimizer + or a list of optimizer. + """ + # Log the current learning rate logger.logkv("learning_rate", self.learning_rate(self._current_progress)) if not isinstance(optimizers, list): diff --git a/torchy_baselines/ppo/policies.py b/torchy_baselines/ppo/policies.py index c967962..e973858 100644 --- a/torchy_baselines/ppo/policies.py +++ b/torchy_baselines/ppo/policies.py @@ -148,7 +148,6 @@ class PPOPolicy(BasePolicy): self.value_net: 1 }[module] module.apply(partial(self.init_weights, gain=gain)) - # TODO: support linear decay of the learning rate self.optimizer = th.optim.Adam(self.parameters(), lr=learning_rate(1), eps=self.adam_epsilon) def forward(self, obs, deterministic=False):