Enable force_zip64 (#839)

* Enable force_zip64

* mark tests as expensive

* Update changelog

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
This commit is contained in:
Yifei Cheng 2022-03-28 04:35:33 -04:00 committed by GitHub
parent 30772aa9f5
commit 44e53ff811
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 54 additions and 5 deletions

View file

@ -4,6 +4,32 @@ Changelog
==========
Release 1.5.1a0 (WIP)
---------------------------
Breaking Changes:
^^^^^^^^^^^^^^^^^
New Features:
^^^^^^^^^^^^^
SB3-Contrib
^^^^^^^^^^^
Bug Fixes:
^^^^^^^^^^
- Fixed saving and loading large policies greater than 2GB (@jkterry1, @ycheng517)
Deprecations:
^^^^^^^^^^^^^
Others:
^^^^^^^
Documentation:
^^^^^^^^^^^^^^
Release 1.5.0 (2022-03-25)
---------------------------
@ -931,4 +957,4 @@ And all the contributors:
@wkirgsn @AechPro @CUN-bjy @batu @IljaAvadiev @timokau @kachayev @cleversonahum
@eleurent @ac-93 @cove9988 @theDebugger811 @hsuehch @Demetrio92 @thomasgubler @IperGiove @ScheiklP
@simoninithomas @armandpl @manuel-delverme @Gautam-J @gianlucadecola @buoyancy99 @caburu @xy9485
@Gregwar
@Gregwar @ycheng517

View file

@ -1,2 +1,2 @@
#!/bin/bash
python3 -m pytest --cov-config .coveragerc --cov-report html --cov-report term --cov=. -v --color=yes
python3 -m pytest --cov-config .coveragerc --cov-report html --cov-report term --cov=. -v --color=yes -m "not expensive"

View file

@ -15,6 +15,8 @@ filterwarnings =
ignore::UserWarning:gym
ignore:SelectableGroups dict interface is deprecated.:DeprecationWarning
ignore:`np.bool` is a deprecated alias for the builtin `bool`:DeprecationWarning
markers =
expensive: marks tests as expensive (deselect with '-m "not expensive"')
[pytype]
inputs = stable_baselines3

View file

@ -314,11 +314,11 @@ def save_to_zip_file(
if data is not None:
archive.writestr("data", serialized_data)
if pytorch_variables is not None:
with archive.open("pytorch_variables.pth", mode="w") as pytorch_variables_file:
with archive.open("pytorch_variables.pth", mode="w", force_zip64=True) as pytorch_variables_file:
th.save(pytorch_variables, pytorch_variables_file)
if params is not None:
for file_name, dict_ in params.items():
with archive.open(file_name + ".pth", mode="w") as param_file:
with archive.open(file_name + ".pth", mode="w", force_zip64=True) as param_file:
th.save(dict_, param_file)
# Save metadata: library version when file was saved
archive.writestr("_stable_baselines3_version", sb3.__version__)

View file

@ -1 +1 @@
1.5.0
1.5.1a0

View file

@ -656,3 +656,24 @@ def test_open_file(tmp_path):
with pytest.raises(ValueError):
buff.close()
open_path(buff, "w")
@pytest.mark.expensive
def test_save_load_large_model(tmp_path):
"""
Test saving and loading a model with a large policy that is greater than 2GB. We
test only one algorithm since all algorithms share the same code for loading and
saving the model.
"""
env = select_env(TD3)
kwargs = dict(policy_kwargs=dict(net_arch=[8192, 8192, 8192]), device="cpu")
model = TD3("MlpPolicy", env, **kwargs)
# test saving
model.save(tmp_path / "test_save")
# test loading
model = TD3.load(str(tmp_path / "test_save.zip"), env=env, **kwargs)
# clear file from os
os.remove(tmp_path / "test_save.zip")