mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-07-02 03:55:39 +00:00
Ensure train/n_updates metric accounts for early stopping of training loop (#1311)
* Correct _n_updates when target_kl stops loop early * Update changelog * Simplify code --------- Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
This commit is contained in:
parent
d0c1a87faf
commit
411ff697dd
2 changed files with 2 additions and 2 deletions
|
|
@ -26,6 +26,7 @@ New Features:
|
|||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
- Fixed Atari wrapper that missed the reset condition (@luizapozzobon)
|
||||
- Fixed PPO train/n_updates metric not accounting for early stopping (@adamfrly)
|
||||
|
||||
Deprecations:
|
||||
^^^^^^^^^^^^^
|
||||
|
|
|
|||
|
|
@ -189,7 +189,6 @@ class PPO(OnPolicyAlgorithm):
|
|||
clip_fractions = []
|
||||
|
||||
continue_training = True
|
||||
|
||||
# train for n_epochs epochs
|
||||
for epoch in range(self.n_epochs):
|
||||
approx_kl_divs = []
|
||||
|
|
@ -271,10 +270,10 @@ class PPO(OnPolicyAlgorithm):
|
|||
th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
|
||||
self.policy.optimizer.step()
|
||||
|
||||
self._n_updates += 1
|
||||
if not continue_training:
|
||||
break
|
||||
|
||||
self._n_updates += self.n_epochs
|
||||
explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten())
|
||||
|
||||
# Logs
|
||||
|
|
|
|||
Loading…
Reference in a new issue