Fix for HER with custom objects (#343)

This commit is contained in:
Antonin RAFFIN 2021-03-06 15:57:27 +01:00 committed by GitHub
parent c62e9259db
commit 237223f834
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 22 additions and 4 deletions

View file

@ -3,7 +3,7 @@
Changelog
==========
Release 1.0rc1 (WIP)
Release 1.0rc2 (WIP)
-------------------------------
Breaking Changes:

View file

@ -1,7 +1,7 @@
import io
import pathlib
import warnings
from typing import Any, Iterable, List, Optional, Tuple, Type, Union
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union
import numpy as np
import torch as th
@ -443,6 +443,7 @@ class HER(BaseAlgorithm):
path: Union[str, pathlib.Path, io.BufferedIOBase],
env: Optional[GymEnv] = None,
device: Union[th.device, str] = "auto",
custom_objects: Optional[Dict[str, Any]] = None,
**kwargs,
) -> "BaseAlgorithm":
"""
@ -453,9 +454,15 @@ class HER(BaseAlgorithm):
:param env: the new environment to run the loaded model on
(can be None if you only need prediction from a trained model). It has priority over any saved environment.
:param device: Device on which the code should run.
:param custom_objects: Dictionary of objects to replace
upon loading. If a variable is present in this dictionary as a
key, it will not be deserialized and the corresponding item
will be used instead. Similar to custom_objects in
``keras.models.load_model``. Useful when you have an object in
the file that cannot be deserialized.
:param kwargs: extra arguments to change the model when loading
"""
data, params, pytorch_variables = load_from_zip_file(path, device=device)
data, params, pytorch_variables = load_from_zip_file(path, device=device, custom_objects=custom_objects)
# Remove stored device information and replace with ours
if "policy_kwargs" in data:

View file

@ -1 +1 @@
1.0rc1
1.0rc2

View file

@ -149,6 +149,17 @@ def test_save_load(tmp_path, model_class, use_sde, online_sampling):
# Check
model.save(tmp_path / "test_save.zip")
del model
# test custom_objects
# Load with custom objects
custom_objects = dict(learning_rate=2e-5, dummy=1.0)
model_ = HER.load(str(tmp_path / "test_save.zip"), env=env, custom_objects=custom_objects, verbose=2)
assert model_.verbose == 2
# Check that the custom object was taken into account
assert model_.learning_rate == custom_objects["learning_rate"]
# Check that only parameters already present in the saved model are replaced
assert not hasattr(model_, "dummy")
model = HER.load(str(tmp_path / "test_save.zip"), env=env)
# check if params are still the same after load