diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 84c8a3d..fb1a9b1 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -8,12 +8,14 @@ Pre-Release 0.8.0a0 (WIP) Breaking Changes: ^^^^^^^^^^^^^^^^^ +- ``save_replay_buffer`` now takes the file path as its argument instead of the folder path (@tirafesi) New Features: ^^^^^^^^^^^^^ Bug Fixes: ^^^^^^^^^^ +- Fixed a bug in the ``close()`` method of ``SubprocVecEnv`` that caused wrappers further down in the wrapper stack to not be closed. (@NeoExtended) Deprecations: ^^^^^^^^^^^^^ @@ -323,3 +325,4 @@ And all the contributors: @Miffyli @dwiel @miguelrass @qxcv @jaberkow @eavelardev @ruifeng96150 @pedrohbtp @srivatsankrishnan @evilsocket @MarvineGothic @jdossgollin @SyllogismRXS @rusu24edward @jbulow @Antymon @seheevic @justinkterry @edbeeching @flodorner @KuKuXia @NeoExtended @PartiallyTyped @mmcenta @richardwu @kinalmehta @rolandgvc @tkelestemur @mloo3 +@tirafesi diff --git a/stable_baselines3/common/off_policy_algorithm.py b/stable_baselines3/common/off_policy_algorithm.py index 6707dd7..6794971 100644 --- a/stable_baselines3/common/off_policy_algorithm.py +++ b/stable_baselines3/common/off_policy_algorithm.py @@ -106,10 +106,10 @@ class OffPolicyAlgorithm(BaseAlgorithm): """ Save the replay buffer as a pickle file. 
- :param path: (str) Path to a log folder + :param path: (str) Path to the file where the replay buffer should be saved """ assert self.replay_buffer is not None, "The replay buffer is not defined" - with open(os.path.join(path, 'replay_buffer.pkl'), 'wb') as file_handler: + with open(path, 'wb') as file_handler: pickle.dump(self.replay_buffer, file_handler) def load_replay_buffer(self, path: str): diff --git a/stable_baselines3/common/vec_env/subproc_vec_env.py b/stable_baselines3/common/vec_env/subproc_vec_env.py index 3fcb4a5..b12218f 100644 --- a/stable_baselines3/common/vec_env/subproc_vec_env.py +++ b/stable_baselines3/common/vec_env/subproc_vec_env.py @@ -29,6 +29,7 @@ def _worker(remote, parent_remote, env_fn_wrapper): elif cmd == 'render': remote.send(env.render(data)) elif cmd == 'close': + env.close() remote.close() break elif cmd == 'get_spaces': diff --git a/tests/test_save_load.py b/tests/test_save_load.py index fac954d..dec3ce6 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -145,7 +145,7 @@ def test_save_load_replay_buffer(model_class): model = model_class('MlpPolicy', 'Pendulum-v0', buffer_size=1000) model.learn(500) old_replay_buffer = deepcopy(model.replay_buffer) - model.save_replay_buffer(log_folder) + model.save_replay_buffer(replay_path) model.replay_buffer = None model.load_replay_buffer(replay_path)