mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-07-02 03:55:39 +00:00
Release 2.2.1: Hotfix file closing (#1754)
* new closing policy * revert #1742 * Add tests and update changelog --------- Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
This commit is contained in:
parent
e1eac844af
commit
e3dea4b2e0
4 changed files with 78 additions and 33 deletions
|
|
@ -3,10 +3,16 @@
|
|||
Changelog
|
||||
==========
|
||||
|
||||
Release 2.2.0 (2023-11-16)
|
||||
Release 2.2.1 (2023-11-17)
|
||||
--------------------------
|
||||
**Support for options at reset, bug fixes and better error messages**
|
||||
|
||||
.. note::
|
||||
|
||||
SB3 v2.2.0 was yanked after a breaking change was found in `GH#1751 <https://github.com/DLR-RM/stable-baselines3/issues/1751>`_.
|
||||
Please use SB3 v2.2.1 and not v2.2.0.
|
||||
|
||||
|
||||
Breaking Changes:
|
||||
^^^^^^^^^^^^^^^^^
|
||||
- Switched to ``ruff`` for sorting imports (isort is no longer needed); black and ruff now require a minimum version
|
||||
|
|
@ -32,7 +38,9 @@ Bug Fixes:
|
|||
- Fixed success reward dtype in ``SimpleMultiObsEnv`` (@NixGD)
|
||||
- Fixed check_env for Sequence observation space (@corentinlger)
|
||||
- Prevents instantiating BitFlippingEnv with conflicting observation spaces (@kylesayrs)
|
||||
- Fixed ResourceWarning when loading and saving models (files were not closed)
|
||||
- Fixed ResourceWarning when loading and saving models (files were not closed), please note that only files opened from a path are closed automatically,
|
||||
the behavior stays the same for tempfiles (they need to be closed manually),
|
||||
the behavior is now consistent when loading/saving the replay buffer
|
||||
|
||||
`SB3-Contrib`_
|
||||
^^^^^^^^^^^^^^
|
||||
|
|
@ -76,6 +84,7 @@ Others:
|
|||
- Switched to PyTorch 2.1.0 in the CI (fixes type annotations)
|
||||
- Fixed ``stable_baselines3/common/policies.py`` type hints
|
||||
- Switched to ``mypy`` only for checking types
|
||||
- Added tests to check consistency when saving/loading files
|
||||
|
||||
Documentation:
|
||||
^^^^^^^^^^^^^^
|
||||
|
|
|
|||
|
|
@ -308,28 +308,31 @@ def save_to_zip_file(
|
|||
:param pytorch_variables: Other PyTorch variables expected to contain name and value of the variable.
|
||||
:param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
|
||||
"""
|
||||
with open_path(save_path, "w", verbose=0, suffix="zip") as save_path:
|
||||
# data/params can be None, so do not
|
||||
# try to serialize them blindly
|
||||
if data is not None:
|
||||
serialized_data = data_to_json(data)
|
||||
file = open_path(save_path, "w", verbose=0, suffix="zip")
|
||||
# data/params can be None, so do not
|
||||
# try to serialize them blindly
|
||||
if data is not None:
|
||||
serialized_data = data_to_json(data)
|
||||
|
||||
# Create a zip-archive and write our objects there.
|
||||
with zipfile.ZipFile(save_path, mode="w") as archive:
|
||||
# Do not try to save "None" elements
|
||||
if data is not None:
|
||||
archive.writestr("data", serialized_data)
|
||||
if pytorch_variables is not None:
|
||||
with archive.open("pytorch_variables.pth", mode="w", force_zip64=True) as pytorch_variables_file:
|
||||
th.save(pytorch_variables, pytorch_variables_file)
|
||||
if params is not None:
|
||||
for file_name, dict_ in params.items():
|
||||
with archive.open(file_name + ".pth", mode="w", force_zip64=True) as param_file:
|
||||
th.save(dict_, param_file)
|
||||
# Save metadata: library version when file was saved
|
||||
archive.writestr("_stable_baselines3_version", sb3.__version__)
|
||||
# Save system info about the current python env
|
||||
archive.writestr("system_info.txt", get_system_info(print_info=False)[1])
|
||||
# Create a zip-archive and write our objects there.
|
||||
with zipfile.ZipFile(file, mode="w") as archive:
|
||||
# Do not try to save "None" elements
|
||||
if data is not None:
|
||||
archive.writestr("data", serialized_data)
|
||||
if pytorch_variables is not None:
|
||||
with archive.open("pytorch_variables.pth", mode="w", force_zip64=True) as pytorch_variables_file:
|
||||
th.save(pytorch_variables, pytorch_variables_file)
|
||||
if params is not None:
|
||||
for file_name, dict_ in params.items():
|
||||
with archive.open(file_name + ".pth", mode="w", force_zip64=True) as param_file:
|
||||
th.save(dict_, param_file)
|
||||
# Save metadata: library version when file was saved
|
||||
archive.writestr("_stable_baselines3_version", sb3.__version__)
|
||||
# Save system info about the current python env
|
||||
archive.writestr("system_info.txt", get_system_info(print_info=False)[1])
|
||||
|
||||
if isinstance(save_path, (str, pathlib.Path)):
|
||||
file.close()
|
||||
|
||||
|
||||
def save_to_pkl(path: Union[str, pathlib.Path, io.BufferedIOBase], obj: Any, verbose: int = 0) -> None:
|
||||
|
|
@ -344,10 +347,12 @@ def save_to_pkl(path: Union[str, pathlib.Path, io.BufferedIOBase], obj: Any, ver
|
|||
:param obj: The object to save.
|
||||
:param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
|
||||
"""
|
||||
with open_path(path, "w", verbose=verbose, suffix="pkl") as file_handler:
|
||||
# Use protocol>=4 to support saving replay buffers >= 4Gb
|
||||
# See https://docs.python.org/3/library/pickle.html
|
||||
pickle.dump(obj, file_handler, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
file = open_path(path, "w", verbose=verbose, suffix="pkl")
|
||||
# Use protocol>=4 to support saving replay buffers >= 4Gb
|
||||
# See https://docs.python.org/3/library/pickle.html
|
||||
pickle.dump(obj, file, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
if isinstance(path, (str, pathlib.Path)):
|
||||
file.close()
|
||||
|
||||
|
||||
def load_from_pkl(path: Union[str, pathlib.Path, io.BufferedIOBase], verbose: int = 0) -> Any:
|
||||
|
|
@ -360,8 +365,11 @@ def load_from_pkl(path: Union[str, pathlib.Path, io.BufferedIOBase], verbose: in
|
|||
path actually exists. If path is a io.BufferedIOBase the path exists.
|
||||
:param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
|
||||
"""
|
||||
with open_path(path, "r", verbose=verbose, suffix="pkl") as file_handler:
|
||||
return pickle.load(file_handler)
|
||||
file = open_path(path, "r", verbose=verbose, suffix="pkl")
|
||||
obj = pickle.load(file)
|
||||
if isinstance(path, (str, pathlib.Path)):
|
||||
file.close()
|
||||
return obj
|
||||
|
||||
|
||||
def load_from_zip_file(
|
||||
|
|
@ -391,14 +399,14 @@ def load_from_zip_file(
|
|||
:return: Class parameters, model state_dicts (aka "params", dict of state_dict)
|
||||
and dict of pytorch variables
|
||||
"""
|
||||
load_path = open_path(load_path, "r", verbose=verbose, suffix="zip")
|
||||
file = open_path(load_path, "r", verbose=verbose, suffix="zip")
|
||||
|
||||
# set device to cpu if cuda is not available
|
||||
device = get_device(device=device)
|
||||
|
||||
# Open the zip archive and load data
|
||||
try:
|
||||
with zipfile.ZipFile(load_path) as archive:
|
||||
with zipfile.ZipFile(file) as archive:
|
||||
namelist = archive.namelist()
|
||||
# If data or parameters is not in the
|
||||
# zip archive, assume they were stored
|
||||
|
|
@ -451,5 +459,6 @@ def load_from_zip_file(
|
|||
# load_path wasn't a zip file
|
||||
raise ValueError(f"Error: the file {load_path} wasn't a zip-file") from e
|
||||
finally:
|
||||
load_path.close()
|
||||
if isinstance(load_path, (str, pathlib.Path)):
|
||||
file.close()
|
||||
return data, params, pytorch_variables
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
2.2.0
|
||||
2.2.1
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ import io
|
|||
import json
|
||||
import os
|
||||
import pathlib
|
||||
import tempfile
|
||||
import warnings
|
||||
import zipfile
|
||||
from collections import OrderedDict
|
||||
|
|
@ -752,7 +753,33 @@ def test_dqn_target_update_interval(tmp_path):
|
|||
# Turn warnings into errors
|
||||
@pytest.mark.filterwarnings("error")
|
||||
def test_no_resource_warning(tmp_path):
|
||||
# Check behavior of save/load
|
||||
# see https://github.com/DLR-RM/stable-baselines3/issues/1751
|
||||
|
||||
# check that files are properly closed
|
||||
# Create a PPO agent and save it
|
||||
PPO("MlpPolicy", "CartPole-v1").save(tmp_path / "dqn_cartpole")
|
||||
PPO.load(tmp_path / "dqn_cartpole")
|
||||
|
||||
PPO("MlpPolicy", "CartPole-v1").save(str(tmp_path / "dqn_cartpole"))
|
||||
PPO.load(str(tmp_path / "dqn_cartpole"))
|
||||
|
||||
# Do the same but in memory, should not close the file
|
||||
with tempfile.TemporaryFile() as fp:
|
||||
PPO("MlpPolicy", "CartPole-v1").save(fp)
|
||||
PPO.load(fp)
|
||||
assert not fp.closed
|
||||
|
||||
# Same but with replay buffer
|
||||
model = SAC("MlpPolicy", "Pendulum-v1", buffer_size=200)
|
||||
model.save_replay_buffer(tmp_path / "replay")
|
||||
model.load_replay_buffer(tmp_path / "replay")
|
||||
|
||||
model.save_replay_buffer(str(tmp_path / "replay"))
|
||||
model.load_replay_buffer(str(tmp_path / "replay"))
|
||||
|
||||
with tempfile.TemporaryFile() as fp:
|
||||
model.save_replay_buffer(fp)
|
||||
fp.seek(0)
|
||||
model.load_replay_buffer(fp)
|
||||
assert not fp.closed
|
||||
|
|
|
|||
Loading…
Reference in a new issue