Ensure train/n_updates metric accounts for early stopping of training loop (#1311)

* Correct _n_updates when target_kl stops loop early

* Update changelog

* Simplify code

---------

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
This commit is contained in:
adamfrly 2023-02-06 09:48:41 -05:00 committed by GitHub
parent d0c1a87faf
commit 411ff697dd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 2 additions and 2 deletions

View file

@@ -26,6 +26,7 @@ New Features:
Bug Fixes:
^^^^^^^^^^
- Fixed Atari wrapper that missed the reset condition (@luizapozzobon)
+ - Fixed PPO train/n_updates metric not accounting for early stopping (@adamfrly)
Deprecations:
^^^^^^^^^^^^^

View file

@@ -189,7 +189,6 @@ class PPO(OnPolicyAlgorithm):
clip_fractions = []
continue_training = True
# train for n_epochs epochs
for epoch in range(self.n_epochs):
approx_kl_divs = []
@@ -271,10 +270,10 @@ class PPO(OnPolicyAlgorithm):
th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
self.policy.optimizer.step()
+ self._n_updates += 1
if not continue_training:
break
- self._n_updates += self.n_epochs
explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten())
# Logs