From 2ada2dd0b2b10f71d0be38ce032d49391b875c4e Mon Sep 17 00:00:00 2001 From: Rohan Tangri <47002898+09tangriro@users.noreply.github.com> Date: Mon, 10 May 2021 11:21:00 +0100 Subject: [PATCH] Update PPO KL Divergence Estimator (#419) * remove unused all_kl_divs memory * new kl approximate equation * move kl check before update step * update changelog * add continue_training flag update to kl check * add verbose check * update changelog * lint with black * r -> log_ratio * Add link to PR * invert ratio * Fix for Sphinx v4.0 Co-authored-by: Anssi Co-authored-by: Antonin Raffin --- docs/misc/changelog.rst | 2 ++ docs/modules/ddpg.rst | 1 + stable_baselines3/ppo/ppo.py | 25 +++++++++++++++++++------ 3 files changed, 22 insertions(+), 6 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 3a2613a..285cbc5 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -10,6 +10,8 @@ Release 1.1.0a5 (WIP) Breaking Changes: ^^^^^^^^^^^^^^^^^ - Renamed ``_last_dones`` and ``dones`` to ``_last_episode_starts`` and ``episode_starts`` in ``RolloutBuffer``. +- Updated the KL Divergence estimator in the PPO algorithm to be positive definite and have lower variance (@09tangriro) +- Updated the KL Divergence check in the PPO algorithm to be before the gradient update step rather than after end of epoch (@09tangriro) New Features: ^^^^^^^^^^^^^ diff --git a/docs/modules/ddpg.rst b/docs/modules/ddpg.rst index 08f5bf4..75087b1 100644 --- a/docs/modules/ddpg.rst +++ b/docs/modules/ddpg.rst @@ -167,3 +167,4 @@ DDPG Policies .. autoclass:: CnnPolicy :members: + :noindex: diff --git a/stable_baselines3/ppo/ppo.py b/stable_baselines3/ppo/ppo.py index ca22da6..34737c7 100644 --- a/stable_baselines3/ppo/ppo.py +++ b/stable_baselines3/ppo/ppo.py @@ -168,10 +168,12 @@ class PPO(OnPolicyAlgorithm): if self.clip_range_vf is not None: clip_range_vf = self.clip_range_vf(self._current_progress_remaining) - entropy_losses, all_kl_divs = [], [] + entropy_losses = [] pg_losses, value_losses = [], [] clip_fractions = [] + continue_training = True + # train for n_epochs epochs for epoch in range(self.n_epochs): approx_kl_divs = [] @@ -231,18 +233,29 @@ class PPO(OnPolicyAlgorithm): loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss + # Calculate approximate form of reverse KL Divergence for early stopping + # see issue #417: https://github.com/DLR-RM/stable-baselines3/issues/417 + # and discussion in PR #419: https://github.com/DLR-RM/stable-baselines3/pull/419 + # and Schulman blog: http://joschu.net/blog/kl-approx.html + with th.no_grad(): + log_ratio = log_prob - rollout_data.old_log_prob + approx_kl_div = th.mean((th.exp(log_ratio) - 1) - log_ratio).cpu().numpy() + approx_kl_divs.append(approx_kl_div) + + if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl: + continue_training = False + if self.verbose >= 1: + print(f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_div:.2f}") + break + # Optimization step self.policy.optimizer.zero_grad() loss.backward() # Clip grad norm th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm) self.policy.optimizer.step() - approx_kl_divs.append(th.mean(rollout_data.old_log_prob - log_prob).detach().cpu().numpy()) - all_kl_divs.append(np.mean(approx_kl_divs)) - - if self.target_kl is not None and np.mean(approx_kl_divs) > 1.5 * self.target_kl: - print(f"Early stopping at step {epoch} due to reaching max kl: {np.mean(approx_kl_divs):.2f}") + if not continue_training: break self._n_updates += self.n_epochs