diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index c3e3253..c2be32a 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,7 +3,7 @@ Changelog ========== -Release 1.0rc1 (WIP) +Release 1.0rc2 (WIP) ------------------------------- Breaking Changes: diff --git a/stable_baselines3/her/her.py b/stable_baselines3/her/her.py index 642986e..43984de 100644 --- a/stable_baselines3/her/her.py +++ b/stable_baselines3/her/her.py @@ -1,7 +1,7 @@ import io import pathlib import warnings -from typing import Any, Iterable, List, Optional, Tuple, Type, Union +from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union import numpy as np import torch as th @@ -443,6 +443,7 @@ class HER(BaseAlgorithm): path: Union[str, pathlib.Path, io.BufferedIOBase], env: Optional[GymEnv] = None, device: Union[th.device, str] = "auto", + custom_objects: Optional[Dict[str, Any]] = None, **kwargs, ) -> "BaseAlgorithm": """ @@ -453,9 +454,15 @@ class HER(BaseAlgorithm): :param env: the new environment to run the loaded model on (can be None if you only need prediction from a trained model) has priority over any saved environment :param device: Device on which the code should run. + :param custom_objects: Dictionary of objects to replace + upon loading. If a variable is present in this dictionary as a + key, it will not be deserialized and the corresponding item + will be used instead. Similar to custom_objects in + ``keras.models.load_model``. Useful when you have an object in + the file that cannot be deserialized. 
:param kwargs: extra arguments to change the model when loading """ - data, params, pytorch_variables = load_from_zip_file(path, device=device) + data, params, pytorch_variables = load_from_zip_file(path, device=device, custom_objects=custom_objects) # Remove stored device information and replace with ours if "policy_kwargs" in data: diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 0f82de4..db805d3 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -1.0rc1 +1.0rc2 diff --git a/tests/test_her.py b/tests/test_her.py index a31ab60..5d76d17 100644 --- a/tests/test_her.py +++ b/tests/test_her.py @@ -149,6 +149,17 @@ def test_save_load(tmp_path, model_class, use_sde, online_sampling): # Check model.save(tmp_path / "test_save.zip") del model + + # test custom_objects + # Load with custom objects + custom_objects = dict(learning_rate=2e-5, dummy=1.0) + model_ = HER.load(str(tmp_path / "test_save.zip"), env=env, custom_objects=custom_objects, verbose=2) + assert model_.verbose == 2 + # Check that the custom object was taken into account + assert model_.learning_rate == custom_objects["learning_rate"] + # Check that only parameters that are here already are replaced + assert not hasattr(model_, "dummy") + model = HER.load(str(tmp_path / "test_save.zip"), env=env) # check if params are still the same after load