Mirror of https://github.com/saymrwulf/stable-baselines3.git (synced 2026-05-14 20:58:03 +00:00)
Update PPO KL Divergence Estimator (#419)
* remove unused all_kl_divs memory
* new kl approximate equation
* move kl check before update step
* update changelog
* add continue_training flag update to kl check
* add verbose check
* update changelog
* lint with black
* r -> log_ratio
* Add link to PR
* invert ratio
* Fix for Sphinx v4.0

Co-authored-by: Anssi <kaneran21@hotmail.com>
Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
This commit is contained in:
parent 35da0b59b9, commit 2ada2dd0b2

3 changed files with 22 additions and 6 deletions
@@ -10,6 +10,8 @@ Release 1.1.0a5 (WIP)
 Breaking Changes:
 ^^^^^^^^^^^^^^^^^
 - Renamed ``_last_dones`` and ``dones`` to ``_last_episode_starts`` and ``episode_starts`` in ``RolloutBuffer``.
+- Updated the KL Divergence estimator in the PPO algorithm to be positive definite and have lower variance (@09tangriro)
+- Updated the KL Divergence check in the PPO algorithm to be before the gradient update step rather than after end of epoch (@09tangriro)

 New Features:
 ^^^^^^^^^^^^^
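The new estimator is the k3 approximation from Schulman's blog post linked in the code diff below: with probability ratio r = pi_new(a|s) / pi_old(a|s), each sample term (r - 1) - log r is non-negative because log r <= r - 1, and its batch mean is an unbiased estimate of KL(pi_old || pi_new) with lower variance than the old batch mean of -log r, which could come out negative. A minimal sketch contrasting the two estimators on synthetic log-probabilities (the tensor sizes and noise scales here are illustrative, not taken from the patch):

import torch as th

# Illustrative log-probabilities of the same actions under old and new policies.
old_log_prob = th.randn(4096) * 0.1
log_prob = old_log_prob + 0.05 * th.randn(4096)

log_ratio = log_prob - old_log_prob  # log(pi_new / pi_old)

# Old estimator (k1): unbiased, but individual batches can come out negative.
k1 = th.mean(-log_ratio).item()

# New estimator (k3): (r - 1) - log(r) >= 0 for every sample, so the mean
# is always non-negative; the (r - 1) term has zero expectation and acts
# as a control variate, lowering the variance without introducing bias.
k3 = th.mean((th.exp(log_ratio) - 1) - log_ratio).item()

print(f"k1 estimate: {k1:.6f} (sign can flip from batch to batch)")
print(f"k3 estimate: {k3:.6f} (always >= 0)")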
@@ -167,3 +167,4 @@ DDPG Policies

 .. autoclass:: CnnPolicy
   :members:
+  :noindex:
@@ -168,10 +168,12 @@ class PPO(OnPolicyAlgorithm):
         if self.clip_range_vf is not None:
             clip_range_vf = self.clip_range_vf(self._current_progress_remaining)

-        entropy_losses, all_kl_divs = [], []
+        entropy_losses = []
         pg_losses, value_losses = [], []
         clip_fractions = []

+        continue_training = True
+
         # train for n_epochs epochs
         for epoch in range(self.n_epochs):
             approx_kl_divs = []
@@ -231,18 +233,29 @@ class PPO(OnPolicyAlgorithm):

                 loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss

+                # Calculate approximate form of reverse KL Divergence for early stopping
+                # see issue #417: https://github.com/DLR-RM/stable-baselines3/issues/417
+                # and discussion in PR #419: https://github.com/DLR-RM/stable-baselines3/pull/419
+                # and Schulman blog: http://joschu.net/blog/kl-approx.html
+                with th.no_grad():
+                    log_ratio = log_prob - rollout_data.old_log_prob
+                    approx_kl_div = th.mean((th.exp(log_ratio) - 1) - log_ratio).cpu().numpy()
+                    approx_kl_divs.append(approx_kl_div)
+
+                if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl:
+                    continue_training = False
+                    if self.verbose >= 1:
+                        print(f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_div:.2f}")
+                    break
+
                 # Optimization step
                 self.policy.optimizer.zero_grad()
                 loss.backward()
                 # Clip grad norm
                 th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
                 self.policy.optimizer.step()
-                approx_kl_divs.append(th.mean(rollout_data.old_log_prob - log_prob).detach().cpu().numpy())

-            all_kl_divs.append(np.mean(approx_kl_divs))
-
-            if self.target_kl is not None and np.mean(approx_kl_divs) > 1.5 * self.target_kl:
-                print(f"Early stopping at step {epoch} due to reaching max kl: {np.mean(approx_kl_divs):.2f}")
+            if not continue_training:
                 break

         self._n_updates += self.n_epochs
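With this change, a user-supplied target_kl now halts optimization on the current rollout as soon as the per-batch estimate exceeds 1.5 * target_kl, before the offending gradient step is applied, instead of only being checked after a full epoch. A hedged usage sketch (the environment ID and hyperparameter values are illustrative, not part of this commit):

from stable_baselines3 import PPO

# target_kl activates the early-stopping check shown in the diff above.
model = PPO("MlpPolicy", "CartPole-v1", target_kl=0.03, verbose=1)
# With verbose=1, an "Early stopping at step ..." message is printed
# whenever the approximate KL divergence grows past 1.5 * target_kl.
model.learn(total_timesteps=10_000)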