diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 9dee356..acfa5e3 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -26,6 +26,7 @@ New Features: Bug Fixes: ^^^^^^^^^^ - Fixed Atari wrapper that missed the reset condition (@luizapozzobon) +- Fixed PPO ``train/n_updates`` metric not accounting for early stopping (@adamfrly) Deprecations: ^^^^^^^^^^^^^ diff --git a/stable_baselines3/ppo/ppo.py b/stable_baselines3/ppo/ppo.py index c934527..3ea6756 100644 --- a/stable_baselines3/ppo/ppo.py +++ b/stable_baselines3/ppo/ppo.py @@ -189,7 +189,6 @@ class PPO(OnPolicyAlgorithm): clip_fractions = [] continue_training = True - # train for n_epochs epochs for epoch in range(self.n_epochs): approx_kl_divs = [] @@ -271,10 +270,10 @@ class PPO(OnPolicyAlgorithm): th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm) self.policy.optimizer.step() + self._n_updates += 1 if not continue_training: break - self._n_updates += self.n_epochs explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten()) # Logs