mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-14 20:58:03 +00:00
Enable force_zip64 (#839)
* Enable force_zip64 * mark tests as expensive * Update changelog Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
This commit is contained in:
parent
30772aa9f5
commit
44e53ff811
6 changed files with 54 additions and 5 deletions
|
|
@ -4,6 +4,32 @@ Changelog
|
|||
==========
|
||||
|
||||
|
||||
Release 1.5.1a0 (WIP)
|
||||
---------------------------
|
||||
|
||||
Breaking Changes:
|
||||
^^^^^^^^^^^^^^^^^
|
||||
|
||||
New Features:
|
||||
^^^^^^^^^^^^^
|
||||
|
||||
SB3-Contrib
|
||||
^^^^^^^^^^^
|
||||
|
||||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
- Fixed saving and loading large policies greater than 2GB (@jkterry1, @ycheng517)
|
||||
|
||||
Deprecations:
|
||||
^^^^^^^^^^^^^
|
||||
|
||||
Others:
|
||||
^^^^^^^
|
||||
|
||||
Documentation:
|
||||
^^^^^^^^^^^^^^
|
||||
|
||||
|
||||
Release 1.5.0 (2022-03-25)
|
||||
---------------------------
|
||||
|
||||
|
|
@ -931,4 +957,4 @@ And all the contributors:
|
|||
@wkirgsn @AechPro @CUN-bjy @batu @IljaAvadiev @timokau @kachayev @cleversonahum
|
||||
@eleurent @ac-93 @cove9988 @theDebugger811 @hsuehch @Demetrio92 @thomasgubler @IperGiove @ScheiklP
|
||||
@simoninithomas @armandpl @manuel-delverme @Gautam-J @gianlucadecola @buoyancy99 @caburu @xy9485
|
||||
@Gregwar
|
||||
@Gregwar @ycheng517
|
||||
|
|
|
|||
|
|
@ -1,2 +1,2 @@
|
|||
#!/bin/bash
|
||||
python3 -m pytest --cov-config .coveragerc --cov-report html --cov-report term --cov=. -v --color=yes
|
||||
python3 -m pytest --cov-config .coveragerc --cov-report html --cov-report term --cov=. -v --color=yes -m "not expensive"
|
||||
|
|
|
|||
|
|
@ -15,6 +15,8 @@ filterwarnings =
|
|||
ignore::UserWarning:gym
|
||||
ignore:SelectableGroups dict interface is deprecated.:DeprecationWarning
|
||||
ignore:`np.bool` is a deprecated alias for the builtin `bool`:DeprecationWarning
|
||||
markers =
|
||||
expensive: marks tests as expensive (deselect with '-m "not expensive"')
|
||||
|
||||
[pytype]
|
||||
inputs = stable_baselines3
|
||||
|
|
|
|||
|
|
@ -314,11 +314,11 @@ def save_to_zip_file(
|
|||
if data is not None:
|
||||
archive.writestr("data", serialized_data)
|
||||
if pytorch_variables is not None:
|
||||
with archive.open("pytorch_variables.pth", mode="w") as pytorch_variables_file:
|
||||
with archive.open("pytorch_variables.pth", mode="w", force_zip64=True) as pytorch_variables_file:
|
||||
th.save(pytorch_variables, pytorch_variables_file)
|
||||
if params is not None:
|
||||
for file_name, dict_ in params.items():
|
||||
with archive.open(file_name + ".pth", mode="w") as param_file:
|
||||
with archive.open(file_name + ".pth", mode="w", force_zip64=True) as param_file:
|
||||
th.save(dict_, param_file)
|
||||
# Save metadata: library version when file was saved
|
||||
archive.writestr("_stable_baselines3_version", sb3.__version__)
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
1.5.0
|
||||
1.5.1a0
|
||||
|
|
|
|||
|
|
@ -656,3 +656,24 @@ def test_open_file(tmp_path):
|
|||
with pytest.raises(ValueError):
|
||||
buff.close()
|
||||
open_path(buff, "w")
|
||||
|
||||
|
||||
@pytest.mark.expensive
|
||||
def test_save_load_large_model(tmp_path):
|
||||
"""
|
||||
Test saving and loading a model with a large policy that is greater than 2GB. We
|
||||
test only one algorithm since all algorithms share the same code for loading and
|
||||
saving the model.
|
||||
"""
|
||||
env = select_env(TD3)
|
||||
kwargs = dict(policy_kwargs=dict(net_arch=[8192, 8192, 8192]), device="cpu")
|
||||
model = TD3("MlpPolicy", env, **kwargs)
|
||||
|
||||
# test saving
|
||||
model.save(tmp_path / "test_save")
|
||||
|
||||
# test loading
|
||||
model = TD3.load(str(tmp_path / "test_save.zip"), env=env, **kwargs)
|
||||
|
||||
# clear file from os
|
||||
os.remove(tmp_path / "test_save.zip")
|
||||
|
|
|
|||
Loading…
Reference in a new issue