From 411ff697dde31df9ff914ba0c335875d1bd5998d Mon Sep 17 00:00:00 2001 From: adamfrly <45516720+adamfrly@users.noreply.github.com> Date: Mon, 6 Feb 2023 09:48:41 -0500 Subject: [PATCH] Ensure train/n_updates metric accounts for early stopping of training loop (#1311) * Correct _n_updates when target_kl stops loop early * Update changelog * Simplify code --------- Co-authored-by: Antonin Raffin --- docs/misc/changelog.rst | 1 + stable_baselines3/ppo/ppo.py | 3 +-- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 9dee356..acfa5e3 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -26,6 +26,7 @@ New Features: Bug Fixes: ^^^^^^^^^^ - Fixed Atari wrapper that missed the reset condition (@luizapozzobon) +- Fixed PPO train/n_updates metric not accounting for early stopping (@adamfrly) Deprecations: ^^^^^^^^^^^^^ diff --git a/stable_baselines3/ppo/ppo.py b/stable_baselines3/ppo/ppo.py index c934527..3ea6756 100644 --- a/stable_baselines3/ppo/ppo.py +++ b/stable_baselines3/ppo/ppo.py @@ -189,7 +189,6 @@ class PPO(OnPolicyAlgorithm): clip_fractions = [] continue_training = True - # train for n_epochs epochs for epoch in range(self.n_epochs): approx_kl_divs = [] @@ -271,10 +270,10 @@ class PPO(OnPolicyAlgorithm): th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm) self.policy.optimizer.step() + self._n_updates += 1 if not continue_training: break - self._n_updates += self.n_epochs explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten()) # Logs