mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-07-02 03:55:39 +00:00
Fix HER goal selection (#848)
* Goal sampled from next_achieved_goal instead of achieved_goal * No need to have special case for future anymore * Update changelog Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
This commit is contained in:
parent
254bb10c42
commit
16703b1314
3 changed files with 5 additions and 12 deletions
|
|
@ -4,7 +4,7 @@ Changelog
|
|||
==========
|
||||
|
||||
|
||||
Release 1.5.1a1 (WIP)
|
||||
Release 1.5.1a2 (WIP)
|
||||
---------------------------
|
||||
|
||||
Breaking Changes:
|
||||
|
|
@ -21,6 +21,7 @@ SB3-Contrib
|
|||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
- Fixed saving and loading large policies greater than 2GB (@jkterry1, @ycheng517)
|
||||
- Fixed final goal selection strategy that did not sample the final achieved goal (@qgallouedec)
|
||||
|
||||
Deprecations:
|
||||
^^^^^^^^^^^^^
|
||||
|
|
|
|||
|
|
@ -252,7 +252,7 @@ class HerReplayBuffer(DictReplayBuffer):
|
|||
elif self.goal_selection_strategy == GoalSelectionStrategy.FUTURE:
|
||||
# replay with random state which comes from the same episode and was observed after current transition
|
||||
transitions_indices = np.random.randint(
|
||||
transitions_indices[her_indices] + 1, self.episode_lengths[her_episode_indices]
|
||||
transitions_indices[her_indices], self.episode_lengths[her_episode_indices]
|
||||
)
|
||||
|
||||
elif self.goal_selection_strategy == GoalSelectionStrategy.EPISODE:
|
||||
|
|
@ -262,7 +262,7 @@ class HerReplayBuffer(DictReplayBuffer):
|
|||
else:
|
||||
raise ValueError(f"Strategy {self.goal_selection_strategy} for sampling goals not supported!")
|
||||
|
||||
return self._buffer["achieved_goal"][her_episode_indices, transitions_indices]
|
||||
return self._buffer["next_achieved_goal"][her_episode_indices, transitions_indices]
|
||||
|
||||
def _sample_transitions(
|
||||
self,
|
||||
|
|
@ -304,14 +304,6 @@ class HerReplayBuffer(DictReplayBuffer):
|
|||
|
||||
ep_lengths = self.episode_lengths[episode_indices]
|
||||
|
||||
# Special case when using the "future" goal sampling strategy
|
||||
# we cannot sample all transitions, we have to remove the last timestep
|
||||
if self.goal_selection_strategy == GoalSelectionStrategy.FUTURE:
|
||||
# restrict the sampling domain when ep_lengths > 1
|
||||
# otherwise filter out the indices
|
||||
her_indices = her_indices[ep_lengths[her_indices] > 1]
|
||||
ep_lengths[her_indices] -= 1
|
||||
|
||||
if online_sampling:
|
||||
# Select which transitions to use
|
||||
transitions_indices = np.random.randint(ep_lengths)
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
1.5.1a1
|
||||
1.5.1a2
|
||||
|
|
|
|||
Loading…
Reference in a new issue