mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-06-28 03:21:16 +00:00
Merge branch 'master' into sde
This commit is contained in:
commit
c76c657f2d
4 changed files with 7 additions and 3 deletions
|
|
@@ -8,12 +8,14 @@ Pre-Release 0.8.0a0 (WIP)
|
|||
|
||||
Breaking Changes:
|
||||
^^^^^^^^^^^^^^^^^
|
||||
- ``save_replay_buffer`` now takes the file path as its argument instead of the folder path (@tirafesi)
|
||||
|
||||
New Features:
|
||||
^^^^^^^^^^^^^
|
||||
|
||||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
- Fixed a bug in the ``close()`` method of ``SubprocVecEnv`` that caused wrappers further down in the wrapper stack not to be closed. (@NeoExtended)
|
||||
|
||||
Deprecations:
|
||||
^^^^^^^^^^^^^
|
||||
|
|
@@ -323,3 +325,4 @@ And all the contributors:
|
|||
@Miffyli @dwiel @miguelrass @qxcv @jaberkow @eavelardev @ruifeng96150 @pedrohbtp @srivatsankrishnan @evilsocket
|
||||
@MarvineGothic @jdossgollin @SyllogismRXS @rusu24edward @jbulow @Antymon @seheevic @justinkterry @edbeeching
|
||||
@flodorner @KuKuXia @NeoExtended @PartiallyTyped @mmcenta @richardwu @kinalmehta @rolandgvc @tkelestemur @mloo3
|
||||
@tirafesi
|
||||
|
|
|
|||
|
|
@@ -106,10 +106,10 @@ class OffPolicyAlgorithm(BaseAlgorithm):
|
|||
"""
|
||||
Save the replay buffer as a pickle file.
|
||||
|
||||
:param path: (str) Path to a log folder
|
||||
:param path: (str) Path to the file where the replay buffer should be saved
|
||||
"""
|
||||
assert self.replay_buffer is not None, "The replay buffer is not defined"
|
||||
with open(os.path.join(path, 'replay_buffer.pkl'), 'wb') as file_handler:
|
||||
with open(path, 'wb') as file_handler:
|
||||
pickle.dump(self.replay_buffer, file_handler)
|
||||
|
||||
def load_replay_buffer(self, path: str):
|
||||
|
|
|
|||
|
|
@@ -29,6 +29,7 @@ def _worker(remote, parent_remote, env_fn_wrapper):
|
|||
elif cmd == 'render':
|
||||
remote.send(env.render(data))
|
||||
elif cmd == 'close':
|
||||
env.close()
|
||||
remote.close()
|
||||
break
|
||||
elif cmd == 'get_spaces':
|
||||
|
|
|
|||
|
|
@@ -145,7 +145,7 @@ def test_save_load_replay_buffer(model_class):
|
|||
model = model_class('MlpPolicy', 'Pendulum-v0', buffer_size=1000)
|
||||
model.learn(500)
|
||||
old_replay_buffer = deepcopy(model.replay_buffer)
|
||||
model.save_replay_buffer(log_folder)
|
||||
model.save_replay_buffer(replay_path)
|
||||
model.replay_buffer = None
|
||||
model.load_replay_buffer(replay_path)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue