diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 24175bc..40e09ca 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,7 +3,7 @@ Changelog ========== -Release 2.2.0a11 (WIP) +Release 2.2.0a12 (WIP) -------------------------- **Support for options at reset, bug fixes and better error messages** @@ -32,6 +32,7 @@ Bug Fixes: - Fixed success reward dtype in ``SimpleMultiObsEnv`` (@NixGD) - Fixed check_env for Sequence observation space (@corentinlger) - Prevents instantiating BitFlippingEnv with conflicting observation spaces (@kylesayrs) +- Fixed ResourceWarning when loading and saving models (files were not closed) `SB3-Contrib`_ ^^^^^^^^^^^^^^ diff --git a/stable_baselines3/common/save_util.py b/stable_baselines3/common/save_util.py index 3321585..40681b5 100644 --- a/stable_baselines3/common/save_util.py +++ b/stable_baselines3/common/save_util.py @@ -308,28 +308,28 @@ def save_to_zip_file( :param pytorch_variables: Other PyTorch variables expected to contain name and value of the variable. :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages """ - save_path = open_path(save_path, "w", verbose=0, suffix="zip") - # data/params can be None, so do not - # try to serialize them blindly - if data is not None: - serialized_data = data_to_json(data) - - # Create a zip-archive and write our objects there. - with zipfile.ZipFile(save_path, mode="w") as archive: - # Do not try to save "None" elements + with open_path(save_path, "w", verbose=0, suffix="zip") as save_path: + # data/params can be None, so do not + # try to serialize them blindly if data is not None: - archive.writestr("data", serialized_data) - if pytorch_variables is not None: - with archive.open("pytorch_variables.pth", mode="w", force_zip64=True) as pytorch_variables_file: - th.save(pytorch_variables, pytorch_variables_file) - if params is not None: - for file_name, dict_ in params.items(): - with archive.open(file_name + ".pth", mode="w", force_zip64=True) as param_file: - th.save(dict_, param_file) - # Save metadata: library version when file was saved - archive.writestr("_stable_baselines3_version", sb3.__version__) - # Save system info about the current python env - archive.writestr("system_info.txt", get_system_info(print_info=False)[1]) + serialized_data = data_to_json(data) + + # Create a zip-archive and write our objects there. + with zipfile.ZipFile(save_path, mode="w") as archive: + # Do not try to save "None" elements + if data is not None: + archive.writestr("data", serialized_data) + if pytorch_variables is not None: + with archive.open("pytorch_variables.pth", mode="w", force_zip64=True) as pytorch_variables_file: + th.save(pytorch_variables, pytorch_variables_file) + if params is not None: + for file_name, dict_ in params.items(): + with archive.open(file_name + ".pth", mode="w", force_zip64=True) as param_file: + th.save(dict_, param_file) + # Save metadata: library version when file was saved + archive.writestr("_stable_baselines3_version", sb3.__version__) + # Save system info about the current python env + archive.writestr("system_info.txt", get_system_info(print_info=False)[1]) def save_to_pkl(path: Union[str, pathlib.Path, io.BufferedIOBase], obj: Any, verbose: int = 0) -> None: @@ -450,4 +450,6 @@ def load_from_zip_file( except zipfile.BadZipFile as e: # load_path wasn't a zip file raise ValueError(f"Error: the file {load_path} wasn't a zip-file") from e + finally: + load_path.close() return data, params, pytorch_variables diff --git a/stable_baselines3/common/vec_env/subproc_vec_env.py b/stable_baselines3/common/vec_env/subproc_vec_env.py index 8375884..c598c73 100644 --- a/stable_baselines3/common/vec_env/subproc_vec_env.py +++ b/stable_baselines3/common/vec_env/subproc_vec_env.py @@ -128,14 +128,14 @@ class SubprocVecEnv(VecEnv): def step_wait(self) -> VecEnvStepReturn: results = [remote.recv() for remote in self.remotes] self.waiting = False - obs, rews, dones, infos, self.reset_infos = zip(*results) - return _flatten_obs(obs, self.observation_space), np.stack(rews), np.stack(dones), infos + obs, rews, dones, infos, self.reset_infos = zip(*results) # type: ignore[assignment] + return _flatten_obs(obs, self.observation_space), np.stack(rews), np.stack(dones), infos # type: ignore[return-value] def reset(self) -> VecEnvObs: for env_idx, remote in enumerate(self.remotes): remote.send(("reset", (self._seeds[env_idx], self._options[env_idx]))) results = [remote.recv() for remote in self.remotes] - obs, self.reset_infos = zip(*results) + obs, self.reset_infos = zip(*results) # type: ignore[assignment] # Seeds and options are only used once self._reset_seeds() self._reset_options() diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 13ce6d7..5740e0c 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -2.2.0a11 +2.2.0a12 diff --git a/tests/test_save_load.py b/tests/test_save_load.py index b574d74..778d944 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -747,3 +747,12 @@ def test_dqn_target_update_interval(tmp_path): model = DQN.load(tmp_path / "dqn_cartpole") os.remove(tmp_path / "dqn_cartpole.zip") assert model.target_update_interval == 100 + + +# Turn warnings into errors +@pytest.mark.filterwarnings("error") +def test_no_resource_warning(tmp_path): + # check that files are properly closed + # Create a PPO agent and save it + PPO("MlpPolicy", "CartPole-v1").save(tmp_path / "dqn_cartpole") + PPO.load(tmp_path / "dqn_cartpole")