Update PPO KL Divergence Estimator (#419)

* remove unused all_kl_divs memory

* new kl approximate equation

* move kl check before update step

* update changelog

* add continue_training flag update to kl check

* add verbose check

* update changelog

* lint with black

* r -> log_ratio

* Add link to PR

* invert ratio

* Fix for Sphinx v4.0

Co-authored-by: Anssi <kaneran21@hotmail.com>
Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
Rohan Tangri, 2021-05-10 11:21:00 +01:00, committed by GitHub
parent 35da0b59b9
commit 2ada2dd0b2
3 changed files with 22 additions and 6 deletions


@@ -10,6 +10,8 @@ Release 1.1.0a5 (WIP)
 Breaking Changes:
 ^^^^^^^^^^^^^^^^^
 - Renamed ``_last_dones`` and ``dones`` to ``_last_episode_starts`` and ``episode_starts`` in ``RolloutBuffer``.
+- Updated the KL Divergence estimator in the PPO algorithm to be positive definite and have lower variance (@09tangriro)
+- Updated the KL Divergence check in the PPO algorithm to be before the gradient update step rather than after end of epoch (@09tangriro)
 
 New Features:
 ^^^^^^^^^^^^^

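The two new changelog entries above summarize the estimator change: the previous code averaged ``old_log_prob - log_prob``, which is unbiased but noisy and can be negative per sample, while the new code uses ``(exp(log_ratio) - 1) - log_ratio``, which is non-negative term by term (since e^x - 1 >= x) and in practice has lower variance (see the Schulman note linked in the diff below). A minimal standalone sketch of the comparison; the two Gaussians here are purely illustrative and not part of SB3:

import torch as th

# Two example "policies": actions are scored under an old and a new Gaussian.
old_dist = th.distributions.Normal(0.0, 1.0)
new_dist = th.distributions.Normal(0.1, 1.0)

x = old_dist.sample((100_000,))                           # samples drawn from the old policy
log_ratio = new_dist.log_prob(x) - old_dist.log_prob(x)   # log pi_new(x) - log pi_old(x)

# Old estimator: mean of (old_log_prob - log_prob); individual terms can be negative.
k1 = -log_ratio
# New estimator: (r - 1) - log r with r = exp(log_ratio); every term is >= 0.
k3 = (th.exp(log_ratio) - 1) - log_ratio

print("old estimator  mean/var:", k1.mean().item(), k1.var().item())
print("new estimator  mean/var:", k3.mean().item(), k3.var().item())

Both quantities estimate the same KL divergence between the old and new policy, but the per-sample variance of the new form is typically much smaller, which makes the early-stopping check less prone to spurious triggers.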

@@ -167,3 +167,4 @@ DDPG Policies
 .. autoclass:: CnnPolicy
   :members:
   :noindex:


@@ -168,10 +168,12 @@ class PPO(OnPolicyAlgorithm):
         if self.clip_range_vf is not None:
             clip_range_vf = self.clip_range_vf(self._current_progress_remaining)
 
-        entropy_losses, all_kl_divs = [], []
+        entropy_losses = []
         pg_losses, value_losses = [], []
         clip_fractions = []
 
+        continue_training = True
+
         # train for n_epochs epochs
         for epoch in range(self.n_epochs):
             approx_kl_divs = []
@@ -231,18 +233,29 @@ class PPO(OnPolicyAlgorithm):
 
                 loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss
 
+                # Calculate approximate form of reverse KL Divergence for early stopping
+                # see issue #417: https://github.com/DLR-RM/stable-baselines3/issues/417
+                # and discussion in PR #419: https://github.com/DLR-RM/stable-baselines3/pull/419
+                # and Schulman blog: http://joschu.net/blog/kl-approx.html
+                with th.no_grad():
+                    log_ratio = log_prob - rollout_data.old_log_prob
+                    approx_kl_div = th.mean((th.exp(log_ratio) - 1) - log_ratio).cpu().numpy()
+                    approx_kl_divs.append(approx_kl_div)
+
+                if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl:
+                    continue_training = False
+                    if self.verbose >= 1:
+                        print(f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_div:.2f}")
+                    break
+
                 # Optimization step
                 self.policy.optimizer.zero_grad()
                 loss.backward()
                 # Clip grad norm
                 th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
                 self.policy.optimizer.step()
-                approx_kl_divs.append(th.mean(rollout_data.old_log_prob - log_prob).detach().cpu().numpy())
 
-            all_kl_divs.append(np.mean(approx_kl_divs))
-
-            if self.target_kl is not None and np.mean(approx_kl_divs) > 1.5 * self.target_kl:
-                print(f"Early stopping at step {epoch} due to reaching max kl: {np.mean(approx_kl_divs):.2f}")
+            if not continue_training:
                 break
 
         self._n_updates += self.n_epochs
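With the check now inside the minibatch loop, setting ``target_kl`` aborts the current update as soon as one gradient step would push the approximate KL past ``1.5 * target_kl``, and ``verbose >= 1`` prints the early-stopping message. A small usage sketch; the environment and hyperparameter values are arbitrary and chosen only for illustration:

from stable_baselines3 import PPO

# target_kl enables the early-stopping check added in the diff above;
# verbose=1 makes the "Early stopping at step ..." message visible.
model = PPO(
    "MlpPolicy",
    "CartPole-v1",
    n_epochs=10,
    target_kl=0.02,  # stop the update early once approx. KL exceeds 1.5 * 0.02
    verbose=1,
)
model.learn(total_timesteps=20_000)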