mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-06-06 00:03:28 +00:00
Fix bug with full HerReplayBuffer (#236)
* Fix bug with full replay buffer * Updated changelog * Update tests/test_her.py Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
This commit is contained in:
parent
d04aad2a20
commit
852961139e
3 changed files with 36 additions and 1 deletions
|
|
@ -30,6 +30,7 @@ Bug Fixes:
|
|||
- Fixed bug where code added VecTranspose on channel-first image environments (thanks @qxcv)
|
||||
- Fixed ``DQN`` predict method when using single ``gym.Env`` with ``deterministic=False``
|
||||
- Fixed bug that the arguments order of ``explained_variance()`` in ``ppo.py`` and ``a2c.py`` is not correct (@thisray)
|
||||
- Fixed bug where full ``HerReplayBuffer`` leads to an index error. (@megan-klaiber)
|
||||
|
||||
Deprecations:
|
||||
^^^^^^^^^^^^^
|
||||
|
|
|
|||
|
|
@ -211,7 +211,13 @@ class HerReplayBuffer(ReplayBuffer):
|
|||
# Select which episodes to use
|
||||
if online_sampling:
|
||||
assert batch_size is not None, "No batch_size specified for online sampling of HER transitions"
|
||||
episode_indices = np.random.randint(0, self.n_episodes_stored, batch_size)
|
||||
# Do not sample the episode with index `self.pos` as the episode is invalid
|
||||
if self.full:
|
||||
episode_indices = (
|
||||
np.random.randint(1, self.n_episodes_stored, batch_size) + self.pos
|
||||
) % self.n_episodes_stored
|
||||
else:
|
||||
episode_indices = np.random.randint(0, self.n_episodes_stored, batch_size)
|
||||
# A subset of the transitions will be relabeled using HER algorithm
|
||||
her_indices = np.arange(batch_size)[: int(self.her_ratio * batch_size)]
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -250,6 +250,34 @@ def test_save_load_replay_buffer(tmp_path, recwarn, online_sampling, truncate_la
|
|||
model.learn(200, reset_num_timesteps=reset_num_timesteps)
|
||||
|
||||
|
||||
def test_full_replay_buffer():
|
||||
"""
|
||||
Test if HER works correctly with a full replay buffer when using online sampling.
|
||||
It should not sample the current episode which is not finished.
|
||||
"""
|
||||
n_bits = 4
|
||||
env = BitFlippingEnv(n_bits=n_bits, continuous=True)
|
||||
|
||||
# use small buffer size to get the buffer full
|
||||
model = HER(
|
||||
"MlpPolicy",
|
||||
env,
|
||||
SAC,
|
||||
goal_selection_strategy="future",
|
||||
online_sampling=True,
|
||||
gradient_steps=1,
|
||||
train_freq=1,
|
||||
n_episodes_rollout=-1,
|
||||
max_episode_length=n_bits,
|
||||
policy_kwargs=dict(net_arch=[64]),
|
||||
learning_starts=1,
|
||||
buffer_size=20,
|
||||
verbose=1,
|
||||
)
|
||||
|
||||
model.learn(total_timesteps=100)
|
||||
|
||||
|
||||
def test_get_max_episode_length():
|
||||
dict_env = DummyVecEnv([lambda: BitFlippingEnv()])
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue