diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 535030d..5ab772c 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -18,6 +18,7 @@ New Features: - Add `seed()` method to `VecEnv` class - Add support for Callback (cf https://github.com/hill-a/stable-baselines/pull/644) - Add methods for saving and loading replay buffer +- Add `extend()` method to the buffers Bug Fixes: ^^^^^^^^^^ diff --git a/tests/test_save_load.py b/tests/test_save_load.py index 38c2f0b..da3f7d0 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -139,7 +139,7 @@ def test_save_load_replay_buffer(model_class): log_folder = 'logs' replay_path = os.path.join('logs', 'replay_buffer.pkl') os.makedirs(log_folder, exist_ok=True) - buffer_size = 10000 + buffer_size = 1000 model = model_class('MlpPolicy', 'Pendulum-v0', buffer_size=buffer_size) model.learn(500) old_replay_buffer = deepcopy(model.replay_buffer) @@ -153,5 +153,9 @@ def test_save_load_replay_buffer(model_class): assert np.allclose(old_replay_buffer.rewards, model.replay_buffer.rewards) assert np.allclose(old_replay_buffer.dones, model.replay_buffer.dones) + # test extending replay buffer + model.replay_buffer.extend(old_replay_buffer.observations, old_replay_buffer.next_observations, + old_replay_buffer.actions, old_replay_buffer.rewards, old_replay_buffer.dones) + # clear file from os os.remove(replay_path) diff --git a/torchy_baselines/common/buffers.py b/torchy_baselines/common/buffers.py index 39fe24b..0ac3e2f 100644 --- a/torchy_baselines/common/buffers.py +++ b/torchy_baselines/common/buffers.py @@ -61,6 +61,14 @@ class BaseBuffer(object): """ raise NotImplementedError() + def extend(self, *args, **kwargs) -> None: + """ + Add a new batch of transitions to the buffer + """ + # Do a for loop along the batch axis + for data in zip(*args): + self.add(*data) + def reset(self) -> None: """ Reset the buffer.