mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-18 21:30:19 +00:00
Fix for HER with custom objects (#343)
This commit is contained in:
parent
c62e9259db
commit
237223f834
4 changed files with 22 additions and 4 deletions
|
|
@ -3,7 +3,7 @@
|
|||
Changelog
|
||||
==========
|
||||
|
||||
Release 1.0rc1 (WIP)
|
||||
Release 1.0rc2 (WIP)
|
||||
-------------------------------
|
||||
|
||||
Breaking Changes:
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import io
|
||||
import pathlib
|
||||
import warnings
|
||||
from typing import Any, Iterable, List, Optional, Tuple, Type, Union
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union
|
||||
|
||||
import numpy as np
|
||||
import torch as th
|
||||
|
|
@ -443,6 +443,7 @@ class HER(BaseAlgorithm):
|
|||
path: Union[str, pathlib.Path, io.BufferedIOBase],
|
||||
env: Optional[GymEnv] = None,
|
||||
device: Union[th.device, str] = "auto",
|
||||
custom_objects: Optional[Dict[str, Any]] = None,
|
||||
**kwargs,
|
||||
) -> "BaseAlgorithm":
|
||||
"""
|
||||
|
|
@ -453,9 +454,15 @@ class HER(BaseAlgorithm):
|
|||
:param env: the new environment to run the loaded model on
|
||||
(can be None if you only need prediction from a trained model) has priority over any saved environment
|
||||
:param device: Device on which the code should run.
|
||||
:param custom_objects: Dictionary of objects to replace
|
||||
upon loading. If a variable is present in this dictionary as a
|
||||
key, it will not be deserialized and the corresponding item
|
||||
will be used instead. Similar to custom_objects in
|
||||
``keras.models.load_model``. Useful when you have an object in
|
||||
file that can not be deserialized.
|
||||
:param kwargs: extra arguments to change the model when loading
|
||||
"""
|
||||
data, params, pytorch_variables = load_from_zip_file(path, device=device)
|
||||
data, params, pytorch_variables = load_from_zip_file(path, device=device, custom_objects=custom_objects)
|
||||
|
||||
# Remove stored device information and replace with ours
|
||||
if "policy_kwargs" in data:
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
1.0rc1
|
||||
1.0rc2
|
||||
|
|
|
|||
|
|
@ -149,6 +149,17 @@ def test_save_load(tmp_path, model_class, use_sde, online_sampling):
|
|||
# Check
|
||||
model.save(tmp_path / "test_save.zip")
|
||||
del model
|
||||
|
||||
# test custom_objects
|
||||
# Load with custom objects
|
||||
custom_objects = dict(learning_rate=2e-5, dummy=1.0)
|
||||
model_ = HER.load(str(tmp_path / "test_save.zip"), env=env, custom_objects=custom_objects, verbose=2)
|
||||
assert model_.verbose == 2
|
||||
# Check that the custom object was taken into account
|
||||
assert model_.learning_rate == custom_objects["learning_rate"]
|
||||
# Check that only parameters that are here already are replaced
|
||||
assert not hasattr(model_, "dummy")
|
||||
|
||||
model = HER.load(str(tmp_path / "test_save.zip"), env=env)
|
||||
|
||||
# check if params are still the same after load
|
||||
|
|
|
|||
Loading…
Reference in a new issue