From 8f9aaaebe9909c38622c429f1b82e1af39fdbb24 Mon Sep 17 00:00:00 2001 From: Andy Shih Date: Wed, 29 Jul 2020 12:19:41 -0700 Subject: [PATCH] fix approximate entropy calculation in PPO and A2C (#130) --- docs/misc/changelog.rst | 3 ++- stable_baselines3/a2c/a2c.py | 2 +- stable_baselines3/ppo/ppo.py | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index be9a7f3..dfa728a 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -32,6 +32,7 @@ Bug Fixes: - Fix target for updating q values in SAC: the entropy term was not conditioned by terminals states - Use ``cloudpickle.load`` instead of ``pickle.load`` in ``CloudpickleWrapper``. (@shwang) - Fixed a bug with orthogonal initialization when `bias=False` in custom policy (@rk37) +- Fixed approximate entropy calculation in PPO and A2C. (@andyshih12) Deprecations: ^^^^^^^^^^^^^ @@ -357,4 +358,4 @@ And all the contributors: @Miffyli @dwiel @miguelrass @qxcv @jaberkow @eavelardev @ruifeng96150 @pedrohbtp @srivatsankrishnan @evilsocket @MarvineGothic @jdossgollin @SyllogismRXS @rusu24edward @jbulow @Antymon @seheevic @justinkterry @edbeeching @flodorner @KuKuXia @NeoExtended @PartiallyTyped @mmcenta @richardwu @kinalmehta @rolandgvc @tkelestemur @mloo3 -@tirafesi @blurLake @koulakis @joeljosephjin @shwang @rk37 +@tirafesi @blurLake @koulakis @joeljosephjin @shwang @rk37 @andyshih12 diff --git a/stable_baselines3/a2c/a2c.py b/stable_baselines3/a2c/a2c.py index c2c7b34..cc1d078 100644 --- a/stable_baselines3/a2c/a2c.py +++ b/stable_baselines3/a2c/a2c.py @@ -141,7 +141,7 @@ class A2C(OnPolicyAlgorithm): # Entropy loss favor exploration if entropy is None: # Approximate entropy when no analytical form - entropy_loss = -log_prob.mean() + entropy_loss = -th.mean(-log_prob) else: entropy_loss = -th.mean(entropy) diff --git a/stable_baselines3/ppo/ppo.py b/stable_baselines3/ppo/ppo.py index eadf961..cc191d2 100644 --- a/stable_baselines3/ppo/ppo.py +++ b/stable_baselines3/ppo/ppo.py @@ -198,7 +198,7 @@ class PPO(OnPolicyAlgorithm): # Entropy loss favor exploration if entropy is None: # Approximate entropy when no analytical form - entropy_loss = -log_prob.mean() + entropy_loss = -th.mean(-log_prob) else: entropy_loss = -th.mean(entropy)