diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index b209f16..652b5a6 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -4,7 +4,7 @@ Changelog ========== -Release 1.5.1a1 (WIP) +Release 1.5.1a2(WIP) --------------------------- Breaking Changes: @@ -21,6 +21,7 @@ SB3-Contrib Bug Fixes: ^^^^^^^^^^ - Fixed saving and loading large policies greater than 2GB (@jkterry1, @ycheng517) +- Fixed final goal selection strategy that did not sample the final achieved goal (@qgallouedec) Deprecations: ^^^^^^^^^^^^^ diff --git a/stable_baselines3/her/her_replay_buffer.py b/stable_baselines3/her/her_replay_buffer.py index 9a41477..f61a786 100644 --- a/stable_baselines3/her/her_replay_buffer.py +++ b/stable_baselines3/her/her_replay_buffer.py @@ -252,7 +252,7 @@ class HerReplayBuffer(DictReplayBuffer): elif self.goal_selection_strategy == GoalSelectionStrategy.FUTURE: # replay with random state which comes from the same episode and was observed after current transition transitions_indices = np.random.randint( - transitions_indices[her_indices] + 1, self.episode_lengths[her_episode_indices] + transitions_indices[her_indices], self.episode_lengths[her_episode_indices] ) elif self.goal_selection_strategy == GoalSelectionStrategy.EPISODE: @@ -262,7 +262,7 @@ class HerReplayBuffer(DictReplayBuffer): else: raise ValueError(f"Strategy {self.goal_selection_strategy} for sampling goals not supported!") - return self._buffer["achieved_goal"][her_episode_indices, transitions_indices] + return self._buffer["next_achieved_goal"][her_episode_indices, transitions_indices] def _sample_transitions( self, @@ -304,14 +304,6 @@ class HerReplayBuffer(DictReplayBuffer): ep_lengths = self.episode_lengths[episode_indices] - # Special case when using the "future" goal sampling strategy - # we cannot sample all transitions, we have to remove the last timestep - if self.goal_selection_strategy == GoalSelectionStrategy.FUTURE: - # restrict the sampling domain when ep_lengths > 1 - # otherwise filter out the indices - her_indices = her_indices[ep_lengths[her_indices] > 1] - ep_lengths[her_indices] -= 1 - if online_sampling: # Select which transitions to use transitions_indices = np.random.randint(ep_lengths) diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 1110517..1a2eef7 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -1.5.1a1 +1.5.1a2