diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index fcbfb66..e17c3df 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -4,6 +4,32 @@ Changelog ========== +Release 1.5.1a0 (WIP) +--------------------------- + +Breaking Changes: +^^^^^^^^^^^^^^^^^ + +New Features: +^^^^^^^^^^^^^ + +SB3-Contrib +^^^^^^^^^^^ + +Bug Fixes: +^^^^^^^^^^ +- Fixed saving and loading large policies greater than 2GB (@jkterry1, @ycheng517) + +Deprecations: +^^^^^^^^^^^^^ + +Others: +^^^^^^^ + +Documentation: +^^^^^^^^^^^^^^ + + Release 1.5.0 (2022-03-25) --------------------------- @@ -931,4 +957,4 @@ And all the contributors: @wkirgsn @AechPro @CUN-bjy @batu @IljaAvadiev @timokau @kachayev @cleversonahum @eleurent @ac-93 @cove9988 @theDebugger811 @hsuehch @Demetrio92 @thomasgubler @IperGiove @ScheiklP @simoninithomas @armandpl @manuel-delverme @Gautam-J @gianlucadecola @buoyancy99 @caburu @xy9485 -@Gregwar +@Gregwar @ycheng517 diff --git a/scripts/run_tests.sh b/scripts/run_tests.sh index a379507..9ecb207 100755 --- a/scripts/run_tests.sh +++ b/scripts/run_tests.sh @@ -1,2 +1,2 @@ #!/bin/bash -python3 -m pytest --cov-config .coveragerc --cov-report html --cov-report term --cov=. -v --color=yes +python3 -m pytest --cov-config .coveragerc --cov-report html --cov-report term --cov=. -v --color=yes -m "not expensive" diff --git a/setup.cfg b/setup.cfg index e23ad45..5bc66c2 100644 --- a/setup.cfg +++ b/setup.cfg @@ -15,6 +15,8 @@ filterwarnings = ignore::UserWarning:gym ignore:SelectableGroups dict interface is deprecated.:DeprecationWarning ignore:`np.bool` is a deprecated alias for the builtin `bool`:DeprecationWarning +markers = + expensive: marks tests as expensive (deselect with '-m "not expensive"') [pytype] inputs = stable_baselines3 diff --git a/stable_baselines3/common/save_util.py b/stable_baselines3/common/save_util.py index dcacfba..e0b104f 100644 --- a/stable_baselines3/common/save_util.py +++ b/stable_baselines3/common/save_util.py @@ -314,11 +314,11 @@ def save_to_zip_file( if data is not None: archive.writestr("data", serialized_data) if pytorch_variables is not None: - with archive.open("pytorch_variables.pth", mode="w") as pytorch_variables_file: + with archive.open("pytorch_variables.pth", mode="w", force_zip64=True) as pytorch_variables_file: th.save(pytorch_variables, pytorch_variables_file) if params is not None: for file_name, dict_ in params.items(): - with archive.open(file_name + ".pth", mode="w") as param_file: + with archive.open(file_name + ".pth", mode="w", force_zip64=True) as param_file: th.save(dict_, param_file) # Save metadata: library version when file was saved archive.writestr("_stable_baselines3_version", sb3.__version__) diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index bc80560..33271c4 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -1.5.0 +1.5.1a0 diff --git a/tests/test_save_load.py b/tests/test_save_load.py index 19b2c90..452e6fb 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -656,3 +656,24 @@ def test_open_file(tmp_path): with pytest.raises(ValueError): buff.close() open_path(buff, "w") + + +@pytest.mark.expensive +def test_save_load_large_model(tmp_path): + """ + Test saving and loading a model with a large policy that is greater than 2GB. We + test only one algorithm since all algorithms share the same code for loading and + saving the model. + """ + env = select_env(TD3) + kwargs = dict(policy_kwargs=dict(net_arch=[8192, 8192, 8192]), device="cpu") + model = TD3("MlpPolicy", env, **kwargs) + + # test saving + model.save(tmp_path / "test_save") + + # test loading + model = TD3.load(str(tmp_path / "test_save.zip"), env=env, **kwargs) + + # clear file from os + os.remove(tmp_path / "test_save.zip")