From 852961139e3314208ea628c8238a1e5efd04e759 Mon Sep 17 00:00:00 2001 From: Megan Klaiber Date: Fri, 20 Nov 2020 13:23:03 +0100 Subject: [PATCH] Fix bug with full HerReplayBuffer (#236) * Fix bug with full replay buffer * Updated changelog * Update tests/test_her.py Co-authored-by: Antonin RAFFIN --- docs/misc/changelog.rst | 1 + stable_baselines3/her/her_replay_buffer.py | 8 ++++++- tests/test_her.py | 28 ++++++++++++++++++++++ 3 files changed, 36 insertions(+), 1 deletion(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index ee5c331..a47bff9 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -30,6 +30,7 @@ Bug Fixes: - Fixed bug where code added VecTranspose on channel-first image environments (thanks @qxcv) - Fixed ``DQN`` predict method when using single ``gym.Env`` with ``deterministic=False`` - Fixed bug that the arguments order of ``explained_variance()`` in ``ppo.py`` and ``a2c.py`` is not correct (@thisray) +- Fixed bug where full ``HerReplayBuffer`` leads to an index error. (@megan-klaiber) Deprecations: ^^^^^^^^^^^^^ diff --git a/stable_baselines3/her/her_replay_buffer.py b/stable_baselines3/her/her_replay_buffer.py index 5bbf1b9..edca50f 100644 --- a/stable_baselines3/her/her_replay_buffer.py +++ b/stable_baselines3/her/her_replay_buffer.py @@ -211,7 +211,13 @@ class HerReplayBuffer(ReplayBuffer): # Select which episodes to use if online_sampling: assert batch_size is not None, "No batch_size specified for online sampling of HER transitions" - episode_indices = np.random.randint(0, self.n_episodes_stored, batch_size) + # Do not sample the episode with index `self.pos` as the episode is invalid + if self.full: + episode_indices = ( + np.random.randint(1, self.n_episodes_stored, batch_size) + self.pos + ) % self.n_episodes_stored + else: + episode_indices = np.random.randint(0, self.n_episodes_stored, batch_size) # A subset of the transitions will be relabeled using HER algorithm her_indices = np.arange(batch_size)[: int(self.her_ratio * batch_size)] else: diff --git a/tests/test_her.py b/tests/test_her.py index 09d1a78..5e9cccc 100644 --- a/tests/test_her.py +++ b/tests/test_her.py @@ -250,6 +250,34 @@ def test_save_load_replay_buffer(tmp_path, recwarn, online_sampling, truncate_la model.learn(200, reset_num_timesteps=reset_num_timesteps) +def test_full_replay_buffer(): + """ + Test if HER works correctly with a full replay buffer when using online sampling. + It should not sample the current episode which is not finished. + """ + n_bits = 4 + env = BitFlippingEnv(n_bits=n_bits, continuous=True) + + # use small buffer size to get the buffer full + model = HER( + "MlpPolicy", + env, + SAC, + goal_selection_strategy="future", + online_sampling=True, + gradient_steps=1, + train_freq=1, + n_episodes_rollout=-1, + max_episode_length=n_bits, + policy_kwargs=dict(net_arch=[64]), + learning_starts=1, + buffer_size=20, + verbose=1, + ) + + model.learn(total_timesteps=100) + + def test_get_max_episode_length(): dict_env = DummyVecEnv([lambda: BitFlippingEnv()])