mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-06-07 00:13:37 +00:00
Add save/load for replay buffer
This commit is contained in:
parent
31a862c3a9
commit
75a86881b3
3 changed files with 42 additions and 4 deletions
|
|
@ -11,11 +11,13 @@ Breaking Changes:
|
|||
- Python 2 support was dropped, Torchy Baselines now requires Python 3.6 or above
|
||||
- Return type of `evaluation.evaluate_policy()` has been changed
|
||||
- Refactored the replay buffer to avoid transformation between PyTorch and NumPy
|
||||
- Created `OffPolicyRLModel` base class
|
||||
|
||||
New Features:
|
||||
^^^^^^^^^^^^^
|
||||
- Add `seed()` method to `VecEnv` class
|
||||
- Add support for Callback (cf https://github.com/hill-a/stable-baselines/pull/644)
|
||||
- Add methods for saving and loading the replay buffer (`save_replay_buffer()` and `load_replay_buffer()`)
|
||||
|
||||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
|
|
|
|||
|
|
@ -133,3 +133,25 @@ def test_exclude_include_saved_params(model_class):
|
|||
|
||||
# clear file from os
|
||||
os.remove("test_save.zip")
|
||||
|
||||
@pytest.mark.parametrize("model_class", [SAC, TD3])
def test_save_load_replay_buffer(model_class):
    """
    Check that a replay buffer survives a save/load round trip unchanged.

    The buffer is filled by a short training run, pickled into the log
    folder, dropped from the model, reloaded, and then compared field by
    field with a deep copy taken before saving.

    :param model_class: Off-policy algorithm class under test (SAC or TD3)
    """
    log_folder = 'logs'
    # Build the file path from ``log_folder`` so both always agree:
    # ``save_replay_buffer`` writes ``replay_buffer.pkl`` inside that folder.
    replay_path = os.path.join(log_folder, 'replay_buffer.pkl')
    os.makedirs(log_folder, exist_ok=True)

    buffer_size = 10000
    model = model_class('MlpPolicy', 'Pendulum-v0', buffer_size=buffer_size)
    model.learn(500)

    # Keep an independent copy: the model's own buffer is discarded below.
    old_replay_buffer = deepcopy(model.replay_buffer)
    model.save_replay_buffer(log_folder)

    try:
        model.replay_buffer = None
        model.load_replay_buffer(replay_path)

        assert np.allclose(old_replay_buffer.observations, model.replay_buffer.observations)
        assert np.allclose(old_replay_buffer.actions, model.replay_buffer.actions)
        assert np.allclose(old_replay_buffer.next_observations, model.replay_buffer.next_observations)
        assert np.allclose(old_replay_buffer.rewards, model.replay_buffer.rewards)
        assert np.allclose(old_replay_buffer.dones, model.replay_buffer.dones)
    finally:
        # clear file from os, even when one of the checks above fails
        os.remove(replay_path)
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import time
|
|||
import os
|
||||
import io
|
||||
import zipfile
|
||||
import pickle
|
||||
from typing import Union, Type, Optional, Dict, Any, List, Tuple, Callable
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import deque
|
||||
|
|
@ -707,11 +708,24 @@ class OffPolicyRLModel(BaseRLModel):
|
|||
self.replay_buffer = None # type: Optional[ReplayBuffer]
|
||||
self.use_sde_at_warmup = use_sde_at_warmup
|
||||
|
||||
def save_replay_buffer(self, path: str):
    """
    Save the replay buffer as a pickle file.

    The buffer is written to ``replay_buffer.pkl`` inside ``path``.
    The folder is created first if it does not exist yet.

    :param path: (str) Path to a log folder
    """
    assert self.replay_buffer is not None, "The replay buffer is not defined"
    # An empty path means "current directory": nothing to create there,
    # and ``os.makedirs('')`` would raise.
    if path:
        os.makedirs(path, exist_ok=True)
    with open(os.path.join(path, 'replay_buffer.pkl'), 'wb') as file_handler:
        pickle.dump(self.replay_buffer, file_handler)
|
||||
|
||||
def load_replay_buffer(self, path: str):
    """
    Load a replay buffer from a pickle file.

    The model's current buffer is replaced only after the loaded object
    has passed the type check, so a failed check leaves the previous
    buffer in place.

    :param path: (str) Path to the pickled replay buffer.
    """
    # NOTE: unpickling can execute arbitrary code; only load files that
    # come from a trusted source.
    with open(path, 'rb') as file_handler:
        replay_buffer = pickle.load(file_handler)
    assert isinstance(replay_buffer, ReplayBuffer), 'The replay buffer must inherit from ReplayBuffer class'
    self.replay_buffer = replay_buffer
|
||||
|
||||
def collect_rollouts(self,
|
||||
env: VecEnv,
|
||||
|
|
|
|||
Loading…
Reference in a new issue