Fix HER goal selection (#848)

* Goal sampled from next_achieved_goal instead of achived_goal

* No need to have special case for future anymore

* Update changelog

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
This commit is contained in:
Quentin Gallouédec 2022-04-11 17:50:02 +02:00 committed by GitHub
parent 254bb10c42
commit 16703b1314
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 5 additions and 12 deletions

View file

@ -4,7 +4,7 @@ Changelog
==========
Release 1.5.1a1 (WIP)
Release 1.5.1a2(WIP)
---------------------------
Breaking Changes:
@ -21,6 +21,7 @@ SB3-Contrib
Bug Fixes:
^^^^^^^^^^
- Fixed saving and loading large policies greater than 2GB (@jkterry1, @ycheng517)
- Fixed final goal selection strategy that did not sample the final achieved goal (@qgallouedec)
Deprecations:
^^^^^^^^^^^^^

View file

@ -252,7 +252,7 @@ class HerReplayBuffer(DictReplayBuffer):
elif self.goal_selection_strategy == GoalSelectionStrategy.FUTURE:
# replay with random state which comes from the same episode and was observed after current transition
transitions_indices = np.random.randint(
transitions_indices[her_indices] + 1, self.episode_lengths[her_episode_indices]
transitions_indices[her_indices], self.episode_lengths[her_episode_indices]
)
elif self.goal_selection_strategy == GoalSelectionStrategy.EPISODE:
@ -262,7 +262,7 @@ class HerReplayBuffer(DictReplayBuffer):
else:
raise ValueError(f"Strategy {self.goal_selection_strategy} for sampling goals not supported!")
return self._buffer["achieved_goal"][her_episode_indices, transitions_indices]
return self._buffer["next_achieved_goal"][her_episode_indices, transitions_indices]
def _sample_transitions(
self,
@ -304,14 +304,6 @@ class HerReplayBuffer(DictReplayBuffer):
ep_lengths = self.episode_lengths[episode_indices]
# Special case when using the "future" goal sampling strategy
# we cannot sample all transitions, we have to remove the last timestep
if self.goal_selection_strategy == GoalSelectionStrategy.FUTURE:
# restrict the sampling domain when ep_lengths > 1
# otherwise filter out the indices
her_indices = her_indices[ep_lengths[her_indices] > 1]
ep_lengths[her_indices] -= 1
if online_sampling:
# Select which transitions to use
transitions_indices = np.random.randint(ep_lengths)

View file

@ -1 +1 @@
1.5.1a1
1.5.1a2