From 75a86881b319e6d7d6bbba32c121edff34e3ffed Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Wed, 5 Feb 2020 13:10:02 +0100 Subject: [PATCH] Add save/load for replay buffer --- docs/misc/changelog.rst | 2 ++ tests/test_save_load.py | 22 ++++++++++++++++++++++ torchy_baselines/common/base_class.py | 22 ++++++++++++++++++---- 3 files changed, 42 insertions(+), 4 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 708a83a..535030d 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -11,11 +11,13 @@ Breaking Changes: - Python 2 support was dropped, Torchy Baselines now requires Python 3.6 or above - Return type of `evaluation.evaluate_policy()` has been changed - Refactored the replay buffer to avoid transformation between PyTorch and NumPy +- Created `OffPolicyRLModel` base class New Features: ^^^^^^^^^^^^^ - Add `seed()` method to `VecEnv` class - Add support for Callback (cf https://github.com/hill-a/stable-baselines/pull/644) +- Add methods for saving and loading replay buffer Bug Fixes: ^^^^^^^^^^ diff --git a/tests/test_save_load.py b/tests/test_save_load.py index 5171042..38c2f0b 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -133,3 +133,25 @@ def test_exclude_include_saved_params(model_class): # clear file from os os.remove("test_save.zip") + +@pytest.mark.parametrize("model_class", [SAC, TD3]) +def test_save_load_replay_buffer(model_class): + log_folder = 'logs' + replay_path = os.path.join('logs', 'replay_buffer.pkl') + os.makedirs(log_folder, exist_ok=True) + buffer_size = 10000 + model = model_class('MlpPolicy', 'Pendulum-v0', buffer_size=buffer_size) + model.learn(500) + old_replay_buffer = deepcopy(model.replay_buffer) + model.save_replay_buffer(log_folder) + model.replay_buffer = None + model.load_replay_buffer(replay_path) + + assert np.allclose(old_replay_buffer.observations, model.replay_buffer.observations) + assert np.allclose(old_replay_buffer.actions, model.replay_buffer.actions) + assert np.allclose(old_replay_buffer.next_observations, model.replay_buffer.next_observations) + assert np.allclose(old_replay_buffer.rewards, model.replay_buffer.rewards) + assert np.allclose(old_replay_buffer.dones, model.replay_buffer.dones) + + # clear file from os + os.remove(replay_path) diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py index ee6ec7f..68c9025 100644 --- a/torchy_baselines/common/base_class.py +++ b/torchy_baselines/common/base_class.py @@ -2,6 +2,7 @@ import time import os import io import zipfile +import pickle from typing import Union, Type, Optional, Dict, Any, List, Tuple, Callable from abc import ABC, abstractmethod from collections import deque @@ -707,11 +708,24 @@ class OffPolicyRLModel(BaseRLModel): self.replay_buffer = None # type: Optional[ReplayBuffer] self.use_sde_at_warmup = use_sde_at_warmup - def save_replay_buffer(self): - pass + def save_replay_buffer(self, path: str): + """ + Save the replay buffer as a pickle file. - def load_replay_buffer(self, path): - pass + :param path: (str) Path to a log folder + """ + assert self.replay_buffer is not None, "The replay buffer is not defined" + with open(os.path.join(path, 'replay_buffer.pkl'), 'wb') as file_handler: + pickle.dump(self.replay_buffer, file_handler) + + def load_replay_buffer(self, path: str): + """ + + :param path: (str) Path to the pickled replay buffer. + """ + with open(path, 'rb') as file_handler: + self.replay_buffer = pickle.load(file_handler) + assert isinstance(self.replay_buffer, ReplayBuffer), 'The replay buffer must inherit from ReplayBuffer class' def collect_rollouts(self, env: VecEnv,