From 16703b13143eb2b55e216ac831758d523165ff1a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Mon, 11 Apr 2022 17:50:02 +0200 Subject: [PATCH] Fix HER goal selection (#848) * Goal sampled from next_achieved_goal instead of achived_goal * No need to have special case for future anymore * Update changelog Co-authored-by: Antonin Raffin --- docs/misc/changelog.rst | 3 ++- stable_baselines3/her/her_replay_buffer.py | 12 ++---------- stable_baselines3/version.txt | 2 +- 3 files changed, 5 insertions(+), 12 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index b209f16..652b5a6 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -4,7 +4,7 @@ Changelog ========== -Release 1.5.1a1 (WIP) +Release 1.5.1a2(WIP) --------------------------- Breaking Changes: @@ -21,6 +21,7 @@ SB3-Contrib Bug Fixes: ^^^^^^^^^^ - Fixed saving and loading large policies greater than 2GB (@jkterry1, @ycheng517) +- Fixed final goal selection strategy that did not sample the final achieved goal (@qgallouedec) Deprecations: ^^^^^^^^^^^^^ diff --git a/stable_baselines3/her/her_replay_buffer.py b/stable_baselines3/her/her_replay_buffer.py index 9a41477..f61a786 100644 --- a/stable_baselines3/her/her_replay_buffer.py +++ b/stable_baselines3/her/her_replay_buffer.py @@ -252,7 +252,7 @@ class HerReplayBuffer(DictReplayBuffer): elif self.goal_selection_strategy == GoalSelectionStrategy.FUTURE: # replay with random state which comes from the same episode and was observed after current transition transitions_indices = np.random.randint( - transitions_indices[her_indices] + 1, self.episode_lengths[her_episode_indices] + transitions_indices[her_indices], self.episode_lengths[her_episode_indices] ) elif self.goal_selection_strategy == GoalSelectionStrategy.EPISODE: @@ -262,7 +262,7 @@ class HerReplayBuffer(DictReplayBuffer): else: raise ValueError(f"Strategy {self.goal_selection_strategy} for sampling goals not supported!") - return self._buffer["achieved_goal"][her_episode_indices, transitions_indices] + return self._buffer["next_achieved_goal"][her_episode_indices, transitions_indices] def _sample_transitions( self, @@ -304,14 +304,6 @@ class HerReplayBuffer(DictReplayBuffer): ep_lengths = self.episode_lengths[episode_indices] - # Special case when using the "future" goal sampling strategy - # we cannot sample all transitions, we have to remove the last timestep - if self.goal_selection_strategy == GoalSelectionStrategy.FUTURE: - # restrict the sampling domain when ep_lengths > 1 - # otherwise filter out the indices - her_indices = her_indices[ep_lengths[her_indices] > 1] - ep_lengths[her_indices] -= 1 - if online_sampling: # Select which transitions to use transitions_indices = np.random.randint(ep_lengths) diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 1110517..1a2eef7 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -1.5.1a1 +1.5.1a2