Release 2.2.1: Hotfix file closing (#1754)

* new closing policy

* revert #1742

* Add tests and update changelog

---------

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
This commit is contained in:
Quentin Gallouédec 2023-11-17 23:50:23 +01:00 committed by GitHub
parent e1eac844af
commit e3dea4b2e0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 78 additions and 33 deletions

View file

@ -3,10 +3,16 @@
Changelog
==========
Release 2.2.0 (2023-11-16)
Release 2.2.1 (2023-11-17)
--------------------------
**Support for options at reset, bug fixes and better error messages**
.. note::
SB3 v2.2.0 was yanked after a breaking change was found in `GH#1751 <https://github.com/DLR-RM/stable-baselines3/issues/1751>`_.
Please use SB3 v2.2.1 and not v2.2.0.
Breaking Changes:
^^^^^^^^^^^^^^^^^
- Switched to ``ruff`` for sorting imports (isort is no longer needed), black and ruff version now require a minimum version
@ -32,7 +38,9 @@ Bug Fixes:
- Fixed success reward dtype in ``SimpleMultiObsEnv`` (@NixGD)
- Fixed check_env for Sequence observation space (@corentinlger)
- Prevents instantiating BitFlippingEnv with conflicting observation spaces (@kylesayrs)
- Fixed ResourceWarning when loading and saving models (files were not closed)
- Fixed ResourceWarning when loading and saving models (files were not closed), please note that only path are closed automatically,
the behavior stay the same for tempfiles (they need to be closed manually),
the behavior is now consistent when loading/saving replay buffer
`SB3-Contrib`_
^^^^^^^^^^^^^^
@ -76,6 +84,7 @@ Others:
- Switched to PyTorch 2.1.0 in the CI (fixes type annotations)
- Fixed ``stable_baselines3/common/policies.py`` type hints
- Switched to ``mypy`` only for checking types
- Added tests to check consistency when saving/loading files
Documentation:
^^^^^^^^^^^^^^

View file

@ -308,28 +308,31 @@ def save_to_zip_file(
:param pytorch_variables: Other PyTorch variables expected to contain name and value of the variable.
:param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
"""
with open_path(save_path, "w", verbose=0, suffix="zip") as save_path:
# data/params can be None, so do not
# try to serialize them blindly
if data is not None:
serialized_data = data_to_json(data)
file = open_path(save_path, "w", verbose=0, suffix="zip")
# data/params can be None, so do not
# try to serialize them blindly
if data is not None:
serialized_data = data_to_json(data)
# Create a zip-archive and write our objects there.
with zipfile.ZipFile(save_path, mode="w") as archive:
# Do not try to save "None" elements
if data is not None:
archive.writestr("data", serialized_data)
if pytorch_variables is not None:
with archive.open("pytorch_variables.pth", mode="w", force_zip64=True) as pytorch_variables_file:
th.save(pytorch_variables, pytorch_variables_file)
if params is not None:
for file_name, dict_ in params.items():
with archive.open(file_name + ".pth", mode="w", force_zip64=True) as param_file:
th.save(dict_, param_file)
# Save metadata: library version when file was saved
archive.writestr("_stable_baselines3_version", sb3.__version__)
# Save system info about the current python env
archive.writestr("system_info.txt", get_system_info(print_info=False)[1])
# Create a zip-archive and write our objects there.
with zipfile.ZipFile(file, mode="w") as archive:
# Do not try to save "None" elements
if data is not None:
archive.writestr("data", serialized_data)
if pytorch_variables is not None:
with archive.open("pytorch_variables.pth", mode="w", force_zip64=True) as pytorch_variables_file:
th.save(pytorch_variables, pytorch_variables_file)
if params is not None:
for file_name, dict_ in params.items():
with archive.open(file_name + ".pth", mode="w", force_zip64=True) as param_file:
th.save(dict_, param_file)
# Save metadata: library version when file was saved
archive.writestr("_stable_baselines3_version", sb3.__version__)
# Save system info about the current python env
archive.writestr("system_info.txt", get_system_info(print_info=False)[1])
if isinstance(save_path, (str, pathlib.Path)):
file.close()
def save_to_pkl(path: Union[str, pathlib.Path, io.BufferedIOBase], obj: Any, verbose: int = 0) -> None:
@ -344,10 +347,12 @@ def save_to_pkl(path: Union[str, pathlib.Path, io.BufferedIOBase], obj: Any, ver
:param obj: The object to save.
:param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
"""
with open_path(path, "w", verbose=verbose, suffix="pkl") as file_handler:
# Use protocol>=4 to support saving replay buffers >= 4Gb
# See https://docs.python.org/3/library/pickle.html
pickle.dump(obj, file_handler, protocol=pickle.HIGHEST_PROTOCOL)
file = open_path(path, "w", verbose=verbose, suffix="pkl")
# Use protocol>=4 to support saving replay buffers >= 4Gb
# See https://docs.python.org/3/library/pickle.html
pickle.dump(obj, file, protocol=pickle.HIGHEST_PROTOCOL)
if isinstance(path, (str, pathlib.Path)):
file.close()
def load_from_pkl(path: Union[str, pathlib.Path, io.BufferedIOBase], verbose: int = 0) -> Any:
@ -360,8 +365,11 @@ def load_from_pkl(path: Union[str, pathlib.Path, io.BufferedIOBase], verbose: in
path actually exists. If path is a io.BufferedIOBase the path exists.
:param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
"""
with open_path(path, "r", verbose=verbose, suffix="pkl") as file_handler:
return pickle.load(file_handler)
file = open_path(path, "r", verbose=verbose, suffix="pkl")
obj = pickle.load(file)
if isinstance(path, (str, pathlib.Path)):
file.close()
return obj
def load_from_zip_file(
@ -391,14 +399,14 @@ def load_from_zip_file(
:return: Class parameters, model state_dicts (aka "params", dict of state_dict)
and dict of pytorch variables
"""
load_path = open_path(load_path, "r", verbose=verbose, suffix="zip")
file = open_path(load_path, "r", verbose=verbose, suffix="zip")
# set device to cpu if cuda is not available
device = get_device(device=device)
# Open the zip archive and load data
try:
with zipfile.ZipFile(load_path) as archive:
with zipfile.ZipFile(file) as archive:
namelist = archive.namelist()
# If data or parameters is not in the
# zip archive, assume they were stored
@ -451,5 +459,6 @@ def load_from_zip_file(
# load_path wasn't a zip file
raise ValueError(f"Error: the file {load_path} wasn't a zip-file") from e
finally:
load_path.close()
if isinstance(load_path, (str, pathlib.Path)):
file.close()
return data, params, pytorch_variables

View file

@ -1 +1 @@
2.2.0
2.2.1

View file

@ -3,6 +3,7 @@ import io
import json
import os
import pathlib
import tempfile
import warnings
import zipfile
from collections import OrderedDict
@ -752,7 +753,33 @@ def test_dqn_target_update_interval(tmp_path):
# Turn warnings into errors
@pytest.mark.filterwarnings("error")
def test_no_resource_warning(tmp_path):
# Check behavior of save/load
# see https://github.com/DLR-RM/stable-baselines3/issues/1751
# check that files are properly closed
# Create a PPO agent and save it
PPO("MlpPolicy", "CartPole-v1").save(tmp_path / "dqn_cartpole")
PPO.load(tmp_path / "dqn_cartpole")
PPO("MlpPolicy", "CartPole-v1").save(str(tmp_path / "dqn_cartpole"))
PPO.load(str(tmp_path / "dqn_cartpole"))
# Do the same but in memory, should not close the file
with tempfile.TemporaryFile() as fp:
PPO("MlpPolicy", "CartPole-v1").save(fp)
PPO.load(fp)
assert not fp.closed
# Same but with replay buffer
model = SAC("MlpPolicy", "Pendulum-v1", buffer_size=200)
model.save_replay_buffer(tmp_path / "replay")
model.load_replay_buffer(tmp_path / "replay")
model.save_replay_buffer(str(tmp_path / "replay"))
model.load_replay_buffer(str(tmp_path / "replay"))
with tempfile.TemporaryFile() as fp:
model.save_replay_buffer(fp)
fp.seek(0)
model.load_replay_buffer(fp)
assert not fp.closed