diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 54f8cc1..74abf17 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,12 +3,15 @@ Changelog ========== -Pre-Release 0.9.0a1 (WIP) +Pre-Release 0.9.0a2 (WIP) ------------------------------ Breaking Changes: ^^^^^^^^^^^^^^^^^ - Removed ``device`` keyword argument of policies; use ``policy.to(device)`` instead. (@qxcv) +- Rename ``BaseClass.get_torch_variables`` -> ``BaseClass._get_torch_save_params`` and + ``BaseClass.excluded_save_params`` -> ``BaseClass._excluded_save_params`` +- Renamed saved items ``tensors`` to ``pytorch_variables`` for clarity New Features: ^^^^^^^^^^^^^ @@ -35,12 +38,9 @@ Others: - Fix type annotation of ``make_vec_env`` (@ManifoldFR) - Removed ``AlreadySteppingError`` and ``NotSteppingError`` that were not used - Fixed typos in SAC and TD3 -- Rename ``BaseClass.get_torch_variables`` -> ``BaseClass._get_torch_save_params`` and - ``BaseClass.excluded_save_params`` -> ``BaseClass._excluded_save_params`` - Reorganized functions for clarity in ``BaseClass`` (save/load functions close to each other, private functions at top) - Clarified docstrings on what is saved and loaded to/from files -- Renamed saved items ``tensors`` to ``pytorch_variables`` for clarity - Simplified ``save_to_zip_file`` function by removing duplicate code - Store library version along with the saved models diff --git a/stable_baselines3/common/save_util.py b/stable_baselines3/common/save_util.py index 7260f11..3f86269 100644 --- a/stable_baselines3/common/save_util.py +++ b/stable_baselines3/common/save_util.py @@ -404,7 +404,8 @@ def load_from_zip_file( # Load the parameters with the right ``map_location``. # Remove ".pth" ending with splitext th_object = th.load(file_content, map_location=device) - if file_path == "pytorch_variables.pth": + # "tensors.pth" was renamed "pytorch_variables.pth" in v0.9.0, see PR #138 + if file_path == "pytorch_variables.pth" or file_path == "tensors.pth": # PyTorch variables (not state_dicts) pytorch_variables = th_object else: diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 8d8f786..63aaf3d 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -0.9.0a1 +0.9.0a2