diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 3a2613a..285cbc5 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -10,6 +10,8 @@ Release 1.1.0a5 (WIP) Breaking Changes: ^^^^^^^^^^^^^^^^^ - Renamed ``_last_dones`` and ``dones`` to ``_last_episode_starts`` and ``episode_starts`` in ``RolloutBuffer``. +- Updated the KL Divergence estimator in the PPO algorithm to be positive definite and have lower variance (@09tangriro) +- Updated the KL Divergence check in the PPO algorithm to be before the gradient update step rather than after end of epoch (@09tangriro) New Features: ^^^^^^^^^^^^^ diff --git a/docs/modules/ddpg.rst b/docs/modules/ddpg.rst index 08f5bf4..75087b1 100644 --- a/docs/modules/ddpg.rst +++ b/docs/modules/ddpg.rst @@ -167,3 +167,4 @@ DDPG Policies .. autoclass:: CnnPolicy :members: + :noindex: diff --git a/stable_baselines3/ppo/ppo.py b/stable_baselines3/ppo/ppo.py index ca22da6..34737c7 100644 --- a/stable_baselines3/ppo/ppo.py +++ b/stable_baselines3/ppo/ppo.py @@ -168,10 +168,12 @@ class PPO(OnPolicyAlgorithm): if self.clip_range_vf is not None: clip_range_vf = self.clip_range_vf(self._current_progress_remaining) - entropy_losses, all_kl_divs = [], [] + entropy_losses = [] pg_losses, value_losses = [], [] clip_fractions = [] + continue_training = True + # train for n_epochs epochs for epoch in range(self.n_epochs): approx_kl_divs = [] @@ -231,18 +233,29 @@ class PPO(OnPolicyAlgorithm): loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss + # Calculate approximate form of reverse KL Divergence for early stopping + # see issue #417: https://github.com/DLR-RM/stable-baselines3/issues/417 + # and discussion in PR #419: https://github.com/DLR-RM/stable-baselines3/pull/419 + # and Schulman blog: http://joschu.net/blog/kl-approx.html + with th.no_grad(): + log_ratio = log_prob - rollout_data.old_log_prob + approx_kl_div = th.mean((th.exp(log_ratio) - 1) - log_ratio).cpu().numpy() + approx_kl_divs.append(approx_kl_div) + + if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl: + continue_training = False + if self.verbose >= 1: + print(f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_div:.2f}") + break + # Optimization step self.policy.optimizer.zero_grad() loss.backward() # Clip grad norm th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm) self.policy.optimizer.step() - approx_kl_divs.append(th.mean(rollout_data.old_log_prob - log_prob).detach().cpu().numpy()) - all_kl_divs.append(np.mean(approx_kl_divs)) - - if self.target_kl is not None and np.mean(approx_kl_divs) > 1.5 * self.target_kl: - print(f"Early stopping at step {epoch} due to reaching max kl: {np.mean(approx_kl_divs):.2f}") + if not continue_training: break self._n_updates += self.n_epochs