Fix bug with full HerReplayBuffer (#236)

* Fix bug with full replay buffer

* Updated changelog

* Update tests/test_her.py

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
This commit is contained in:
Megan Klaiber 2020-11-20 13:23:03 +01:00 committed by GitHub
parent d04aad2a20
commit 852961139e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 36 additions and 1 deletions

View file

@ -30,6 +30,7 @@ Bug Fixes:
- Fixed bug where code added VecTranspose on channel-first image environments (thanks @qxcv)
- Fixed ``DQN`` predict method when using single ``gym.Env`` with ``deterministic=False``
- Fixed bug where the argument order of ``explained_variance()`` in ``ppo.py`` and ``a2c.py`` was incorrect (@thisray)
- Fixed bug where a full ``HerReplayBuffer`` led to an index error. (@megan-klaiber)
Deprecations:
^^^^^^^^^^^^^

View file

@ -211,7 +211,13 @@ class HerReplayBuffer(ReplayBuffer):
# Select which episodes to use
if online_sampling:
assert batch_size is not None, "No batch_size specified for online sampling of HER transitions"
episode_indices = np.random.randint(0, self.n_episodes_stored, batch_size)
# Do not sample the episode at index `self.pos`: it is the current, unfinished episode and is therefore invalid
if self.full:
episode_indices = (
np.random.randint(1, self.n_episodes_stored, batch_size) + self.pos
) % self.n_episodes_stored
else:
episode_indices = np.random.randint(0, self.n_episodes_stored, batch_size)
# A subset of the transitions will be relabeled using HER algorithm
her_indices = np.arange(batch_size)[: int(self.her_ratio * batch_size)]
else:

View file

@ -250,6 +250,34 @@ def test_save_load_replay_buffer(tmp_path, recwarn, online_sampling, truncate_la
model.learn(200, reset_num_timesteps=reset_num_timesteps)
def test_full_replay_buffer():
    """
    Run HER with online sampling long enough for its replay buffer to become full.

    Once the buffer is full, sampling must skip the current episode,
    which has not finished yet.
    """
    n_bits = 4
    # Keep the buffer deliberately small so that it gets full during the run
    her_kwargs = dict(
        goal_selection_strategy="future",
        online_sampling=True,
        gradient_steps=1,
        train_freq=1,
        n_episodes_rollout=-1,
        max_episode_length=n_bits,
        policy_kwargs=dict(net_arch=[64]),
        learning_starts=1,
        buffer_size=20,
        verbose=1,
    )
    bit_flipping_env = BitFlippingEnv(n_bits=n_bits, continuous=True)
    agent = HER("MlpPolicy", bit_flipping_env, SAC, **her_kwargs)
    agent.learn(total_timesteps=100)
def test_get_max_episode_length():
dict_env = DummyVecEnv([lambda: BitFlippingEnv()])