mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-29 23:07:07 +00:00
Fix resource warning (#1742)
* Fix resource warning * Add test and update changelog * Fix for new mypy version
This commit is contained in:
parent
b413f4c285
commit
23fbeb5975
5 changed files with 38 additions and 26 deletions
|
|
@ -3,7 +3,7 @@
|
|||
Changelog
|
||||
==========
|
||||
|
||||
Release 2.2.0a11 (WIP)
|
||||
Release 2.2.0a12 (WIP)
|
||||
--------------------------
|
||||
**Support for options at reset, bug fixes and better error messages**
|
||||
|
||||
|
|
@ -32,6 +32,7 @@ Bug Fixes:
|
|||
- Fixed success reward dtype in ``SimpleMultiObsEnv`` (@NixGD)
|
||||
- Fixed check_env for Sequence observation space (@corentinlger)
|
||||
- Prevents instantiating BitFlippingEnv with conflicting observation spaces (@kylesayrs)
|
||||
- Fixed ResourceWarning when loading and saving models (files were not closed)
|
||||
|
||||
`SB3-Contrib`_
|
||||
^^^^^^^^^^^^^^
|
||||
|
|
|
|||
|
|
@ -308,28 +308,28 @@ def save_to_zip_file(
|
|||
:param pytorch_variables: Other PyTorch variables expected to contain name and value of the variable.
|
||||
:param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
|
||||
"""
|
||||
save_path = open_path(save_path, "w", verbose=0, suffix="zip")
|
||||
# data/params can be None, so do not
|
||||
# try to serialize them blindly
|
||||
if data is not None:
|
||||
serialized_data = data_to_json(data)
|
||||
|
||||
# Create a zip-archive and write our objects there.
|
||||
with zipfile.ZipFile(save_path, mode="w") as archive:
|
||||
# Do not try to save "None" elements
|
||||
with open_path(save_path, "w", verbose=0, suffix="zip") as save_path:
|
||||
# data/params can be None, so do not
|
||||
# try to serialize them blindly
|
||||
if data is not None:
|
||||
archive.writestr("data", serialized_data)
|
||||
if pytorch_variables is not None:
|
||||
with archive.open("pytorch_variables.pth", mode="w", force_zip64=True) as pytorch_variables_file:
|
||||
th.save(pytorch_variables, pytorch_variables_file)
|
||||
if params is not None:
|
||||
for file_name, dict_ in params.items():
|
||||
with archive.open(file_name + ".pth", mode="w", force_zip64=True) as param_file:
|
||||
th.save(dict_, param_file)
|
||||
# Save metadata: library version when file was saved
|
||||
archive.writestr("_stable_baselines3_version", sb3.__version__)
|
||||
# Save system info about the current python env
|
||||
archive.writestr("system_info.txt", get_system_info(print_info=False)[1])
|
||||
serialized_data = data_to_json(data)
|
||||
|
||||
# Create a zip-archive and write our objects there.
|
||||
with zipfile.ZipFile(save_path, mode="w") as archive:
|
||||
# Do not try to save "None" elements
|
||||
if data is not None:
|
||||
archive.writestr("data", serialized_data)
|
||||
if pytorch_variables is not None:
|
||||
with archive.open("pytorch_variables.pth", mode="w", force_zip64=True) as pytorch_variables_file:
|
||||
th.save(pytorch_variables, pytorch_variables_file)
|
||||
if params is not None:
|
||||
for file_name, dict_ in params.items():
|
||||
with archive.open(file_name + ".pth", mode="w", force_zip64=True) as param_file:
|
||||
th.save(dict_, param_file)
|
||||
# Save metadata: library version when file was saved
|
||||
archive.writestr("_stable_baselines3_version", sb3.__version__)
|
||||
# Save system info about the current python env
|
||||
archive.writestr("system_info.txt", get_system_info(print_info=False)[1])
|
||||
|
||||
|
||||
def save_to_pkl(path: Union[str, pathlib.Path, io.BufferedIOBase], obj: Any, verbose: int = 0) -> None:
|
||||
|
|
@ -450,4 +450,6 @@ def load_from_zip_file(
|
|||
except zipfile.BadZipFile as e:
|
||||
# load_path wasn't a zip file
|
||||
raise ValueError(f"Error: the file {load_path} wasn't a zip-file") from e
|
||||
finally:
|
||||
load_path.close()
|
||||
return data, params, pytorch_variables
|
||||
|
|
|
|||
|
|
@ -128,14 +128,14 @@ class SubprocVecEnv(VecEnv):
|
|||
def step_wait(self) -> VecEnvStepReturn:
|
||||
results = [remote.recv() for remote in self.remotes]
|
||||
self.waiting = False
|
||||
obs, rews, dones, infos, self.reset_infos = zip(*results)
|
||||
return _flatten_obs(obs, self.observation_space), np.stack(rews), np.stack(dones), infos
|
||||
obs, rews, dones, infos, self.reset_infos = zip(*results) # type: ignore[assignment]
|
||||
return _flatten_obs(obs, self.observation_space), np.stack(rews), np.stack(dones), infos # type: ignore[return-value]
|
||||
|
||||
def reset(self) -> VecEnvObs:
|
||||
for env_idx, remote in enumerate(self.remotes):
|
||||
remote.send(("reset", (self._seeds[env_idx], self._options[env_idx])))
|
||||
results = [remote.recv() for remote in self.remotes]
|
||||
obs, self.reset_infos = zip(*results)
|
||||
obs, self.reset_infos = zip(*results) # type: ignore[assignment]
|
||||
# Seeds and options are only used once
|
||||
self._reset_seeds()
|
||||
self._reset_options()
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
2.2.0a11
|
||||
2.2.0a12
|
||||
|
|
|
|||
|
|
@ -747,3 +747,12 @@ def test_dqn_target_update_interval(tmp_path):
|
|||
model = DQN.load(tmp_path / "dqn_cartpole")
|
||||
os.remove(tmp_path / "dqn_cartpole.zip")
|
||||
assert model.target_update_interval == 100
|
||||
|
||||
|
||||
# Turn warnings into errors
|
||||
@pytest.mark.filterwarnings("error")
|
||||
def test_no_resource_warning(tmp_path):
|
||||
# check that files are properly closed
|
||||
# Create a PPO agent and save it
|
||||
PPO("MlpPolicy", "CartPole-v1").save(tmp_path / "dqn_cartpole")
|
||||
PPO.load(tmp_path / "dqn_cartpole")
|
||||
|
|
|
|||
Loading…
Reference in a new issue