Add save/load for replay buffer

This commit is contained in:
Antonin Raffin 2020-02-05 13:10:02 +01:00
parent 31a862c3a9
commit 75a86881b3
3 changed files with 42 additions and 4 deletions

View file

@ -11,11 +11,13 @@ Breaking Changes:
- Python 2 support was dropped, Torchy Baselines now requires Python 3.6 or above
- Return type of `evaluation.evaluate_policy()` has been changed
- Refactored the replay buffer to avoid transformation between PyTorch and NumPy
- Created `OffPolicyRLModel` base class
New Features:
^^^^^^^^^^^^^
- Add `seed()` method to `VecEnv` class
- Add support for Callback (cf https://github.com/hill-a/stable-baselines/pull/644)
- Add `save_replay_buffer()` and `load_replay_buffer()` methods to `OffPolicyRLModel` for saving and loading the replay buffer
Bug Fixes:
^^^^^^^^^^

View file

@ -133,3 +133,25 @@ def test_exclude_include_saved_params(model_class):
# clear file from os
os.remove("test_save.zip")
@pytest.mark.parametrize("model_class", [SAC, TD3])
def test_save_load_replay_buffer(model_class):
    """
    Check that a replay buffer written to disk and loaded back matches
    the original one (compared with ``np.allclose``).

    :param model_class: (type) Off-policy algorithm class under test
    """
    log_folder = 'logs'
    os.makedirs(log_folder, exist_ok=True)
    replay_path = os.path.join(log_folder, 'replay_buffer.pkl')

    model = model_class('MlpPolicy', 'Pendulum-v0', buffer_size=10000)
    model.learn(500)
    reference_buffer = deepcopy(model.replay_buffer)

    model.save_replay_buffer(log_folder)
    # Drop the in-memory buffer: the checks below can only pass
    # if loading really restored it from the file
    model.replay_buffer = None
    model.load_replay_buffer(replay_path)

    restored_buffer = model.replay_buffer
    for field in ('observations', 'actions', 'next_observations', 'rewards', 'dones'):
        assert np.allclose(getattr(reference_buffer, field), getattr(restored_buffer, field))

    # clear file from os
    os.remove(replay_path)

View file

@ -2,6 +2,7 @@ import time
import os
import io
import zipfile
import pickle
from typing import Union, Type, Optional, Dict, Any, List, Tuple, Callable
from abc import ABC, abstractmethod
from collections import deque
@ -707,11 +708,24 @@ class OffPolicyRLModel(BaseRLModel):
self.replay_buffer = None # type: Optional[ReplayBuffer]
self.use_sde_at_warmup = use_sde_at_warmup
def save_replay_buffer(self, path: str) -> None:
    """
    Save the replay buffer as a pickle file.

    The buffer is written to ``replay_buffer.pkl`` inside ``path``.
    The folder is created if it does not exist yet.

    :param path: (str) Path to a log folder
    """
    assert self.replay_buffer is not None, "The replay buffer is not defined"
    # Create the log folder up front so saving does not fail on a fresh path.
    # An empty path means "current directory": nothing to create in that case.
    if path:
        os.makedirs(path, exist_ok=True)
    with open(os.path.join(path, 'replay_buffer.pkl'), 'wb') as file_handler:
        pickle.dump(self.replay_buffer, file_handler)
def load_replay_buffer(self, path: str) -> None:
    """
    Load a replay buffer from a pickle file.

    ``self.replay_buffer`` is only replaced after the loaded object
    passed the type check, so a failed check keeps the previous buffer.

    :param path: (str) Path to the pickled replay buffer.
    """
    # NOTE: unpickling can execute arbitrary code,
    # only load replay buffers coming from a trusted source
    with open(path, 'rb') as file_handler:
        loaded_buffer = pickle.load(file_handler)
    # Validate before assigning to avoid leaving the model with an invalid buffer
    assert isinstance(loaded_buffer, ReplayBuffer), 'The replay buffer must inherit from ReplayBuffer class'
    self.replay_buffer = loaded_buffer
def collect_rollouts(self,
env: VecEnv,