From 08e7519381e800edc6bbd09577f14381b7341873 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Mon, 29 Jun 2020 17:58:55 +0200 Subject: [PATCH] Fix q-target in SAC (#77) * Fix q-target in SAC * [ci skip] Update version --- docs/misc/changelog.rst | 3 ++- stable_baselines3/sac/sac.py | 5 ++--- stable_baselines3/version.txt | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index a2c9645..6725594 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,7 +3,7 @@ Changelog ========== -Pre-Release 0.8.0a1 (WIP) +Pre-Release 0.8.0a2 (WIP) ------------------------------ Breaking Changes: @@ -21,6 +21,7 @@ New Features: Bug Fixes: ^^^^^^^^^^ - Fixed a bug in the ``close()`` method of ``SubprocVecEnv``, causing wrappers further down in the wrapper stack to not be closed. (@NeoExtended) +- Fix target for updating q values in SAC: the entropy term was not conditioned by terminals states Deprecations: ^^^^^^^^^^^^^ diff --git a/stable_baselines3/sac/sac.py b/stable_baselines3/sac/sac.py index 99f8888..203abc4 100644 --- a/stable_baselines3/sac/sac.py +++ b/stable_baselines3/sac/sac.py @@ -202,10 +202,9 @@ class SAC(OffPolicyAlgorithm): next_actions, next_log_prob = self.actor.action_log_prob(replay_data.next_observations) # Compute the target Q value target_q1, target_q2 = self.critic_target(replay_data.next_observations, next_actions) - target_q = th.min(target_q1, target_q2) - target_q = replay_data.rewards + (1 - replay_data.dones) * self.gamma * target_q + target_q = th.min(target_q1, target_q2) - ent_coef * next_log_prob.reshape(-1, 1) # td error + entropy term - q_backup = target_q - ent_coef * next_log_prob.reshape(-1, 1) + q_backup = replay_data.rewards + (1 - replay_data.dones) * self.gamma * target_q # Get current Q estimates # using action from the replay buffer diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 6685a73..8db4718 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -0.8.0a1 +0.8.0a2