mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-06-02 23:40:09 +00:00
Add extend method
This commit is contained in:
parent
8eb82c86e3
commit
b7dcc8d58e
3 changed files with 14 additions and 1 deletions
|
|
@ -18,6 +18,7 @@ New Features:
|
|||
- Add `seed()` method to `VecEnv` class
|
||||
- Add support for Callback (cf https://github.com/hill-a/stable-baselines/pull/644)
|
||||
- Add methods for saving and loading replay buffer
|
||||
- Add `extend()` method to the buffers
|
||||
|
||||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
|
|
|
|||
|
|
@ -139,7 +139,7 @@ def test_save_load_replay_buffer(model_class):
|
|||
log_folder = 'logs'
|
||||
replay_path = os.path.join('logs', 'replay_buffer.pkl')
|
||||
os.makedirs(log_folder, exist_ok=True)
|
||||
buffer_size = 10000
|
||||
buffer_size = 1000
|
||||
model = model_class('MlpPolicy', 'Pendulum-v0', buffer_size=buffer_size)
|
||||
model.learn(500)
|
||||
old_replay_buffer = deepcopy(model.replay_buffer)
|
||||
|
|
@ -153,5 +153,9 @@ def test_save_load_replay_buffer(model_class):
|
|||
assert np.allclose(old_replay_buffer.rewards, model.replay_buffer.rewards)
|
||||
assert np.allclose(old_replay_buffer.dones, model.replay_buffer.dones)
|
||||
|
||||
# test extending replay buffer
|
||||
model.replay_buffer.extend(old_replay_buffer.observations, old_replay_buffer.next_observations,
|
||||
old_replay_buffer.actions, old_replay_buffer.rewards, old_replay_buffer.dones)
|
||||
|
||||
# remove the saved replay buffer file from disk
|
||||
os.remove(replay_path)
|
||||
|
|
|
|||
|
|
@ -61,6 +61,14 @@ class BaseBuffer(object):
|
|||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def extend(self, *args, **kwargs) -> None:
    """
    Add a new batch of transitions to the buffer.

    Every positional and keyword argument is iterated along its first
    (batch) axis and ``self.add`` is called once per batch element with
    the matching item of each argument. Keyword arguments are forwarded
    to ``add`` under the same keyword name.

    Iteration stops at the shortest argument (``zip`` semantics); with no
    arguments at all, nothing is added.
    """
    n_positional = len(args)
    names = list(kwargs)
    # Do a for loop along the batch axis; keyword values are zipped after
    # the positional ones so each step can be split back into args/kwargs.
    for data in zip(*args, *kwargs.values()):
        self.add(*data[:n_positional], **dict(zip(names, data[n_positional:])))
|
||||
|
||||
def reset(self) -> None:
|
||||
"""
|
||||
Reset the buffer.
|
||||
|
|
|
|||
Loading…
Reference in a new issue