Add extend method

This commit is contained in:
Antonin Raffin 2020-02-11 16:40:44 +01:00
parent 8eb82c86e3
commit b7dcc8d58e
3 changed files with 14 additions and 1 deletions

View file

@ -18,6 +18,7 @@ New Features:
- Add `seed()` method to `VecEnv` class
- Add support for Callback (cf https://github.com/hill-a/stable-baselines/pull/644)
- Add methods for saving and loading replay buffer
- Add `extend()` method to the buffers
Bug Fixes:
^^^^^^^^^^

View file

@ -139,7 +139,7 @@ def test_save_load_replay_buffer(model_class):
log_folder = 'logs'
replay_path = os.path.join('logs', 'replay_buffer.pkl')
os.makedirs(log_folder, exist_ok=True)
buffer_size = 10000
buffer_size = 1000
model = model_class('MlpPolicy', 'Pendulum-v0', buffer_size=buffer_size)
model.learn(500)
old_replay_buffer = deepcopy(model.replay_buffer)
@ -153,5 +153,9 @@ def test_save_load_replay_buffer(model_class):
assert np.allclose(old_replay_buffer.rewards, model.replay_buffer.rewards)
assert np.allclose(old_replay_buffer.dones, model.replay_buffer.dones)
# test extending replay buffer
model.replay_buffer.extend(old_replay_buffer.observations, old_replay_buffer.next_observations,
old_replay_buffer.actions, old_replay_buffer.rewards, old_replay_buffer.dones)
# clear file from os
os.remove(replay_path)

View file

@ -61,6 +61,14 @@ class BaseBuffer(object):
"""
raise NotImplementedError()
def extend(self, *args, **kwargs) -> None:
"""
Add a new batch of transitions to the buffer
"""
# Do a for loop along the batch axis
for data in zip(*args):
self.add(*data)
def reset(self) -> None:
"""
Reset the buffer.