diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index ed9834f..e11c5c9 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -4,7 +4,7 @@ Changelog ========== -Release 1.7.0a0 (WIP) +Release 1.7.0a1 (WIP) -------------------------- Breaking Changes: @@ -24,6 +24,7 @@ Bug Fixes: ^^^^^^^^^^ - Fix return type of ``evaluate_actions`` in ``ActorCritcPolicy`` to reflect that entropy is an optional tensor (@Rocamonde) - Fix type annotation of ``policy`` in ``BaseAlgorithm`` and ``OffPolicyAlgorithm`` +- Allowed models trained with Python 3.7 to be loaded with Python 3.8+ without the ``custom_objects`` workaround Deprecations: ^^^^^^^^^^^^^ diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py index 876979f..5972fa3 100644 --- a/stable_baselines3/common/policies.py +++ b/stable_baselines3/common/policies.py @@ -2,7 +2,6 @@ import collections import copy -import warnings from abc import ABC, abstractmethod from functools import partial from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union diff --git a/stable_baselines3/common/save_util.py b/stable_baselines3/common/save_util.py index 90392df..facc55a 100644 --- a/stable_baselines3/common/save_util.py +++ b/stable_baselines3/common/save_util.py @@ -162,13 +162,15 @@ def json_to_data(json_string: str, custom_objects: Optional[Dict[str, Any]] = No try: base64_object = base64.b64decode(serialization.encode()) deserialized_object = cloudpickle.loads(base64_object) - except (RuntimeError, TypeError): + except (RuntimeError, TypeError, AttributeError) as e: warnings.warn( f"Could not deserialize object {data_key}. " - + "Consider using `custom_objects` argument to replace " - + "this object." 
+ "Consider using `custom_objects` argument to replace " + "this object.\n" + f"Exception: {e}" ) - return_data[data_key] = deserialized_object + else: + return_data[data_key] = deserialized_object else: # Read as it is return_data[data_key] = data_item diff --git a/stable_baselines3/common/vec_env/vec_normalize.py b/stable_baselines3/common/vec_env/vec_normalize.py index 53e94af..73c890c 100644 --- a/stable_baselines3/common/vec_env/vec_normalize.py +++ b/stable_baselines3/common/vec_env/vec_normalize.py @@ -1,5 +1,4 @@ import pickle -import warnings from copy import deepcopy from typing import Any, Dict, List, Optional, Union diff --git a/stable_baselines3/sac/policies.py b/stable_baselines3/sac/policies.py index 368cd67..ac93249 100644 --- a/stable_baselines3/sac/policies.py +++ b/stable_baselines3/sac/policies.py @@ -1,4 +1,3 @@ -import warnings from typing import Any, Dict, List, Optional, Tuple, Type, Union import gym diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 56fee06..12cd5fb 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -1.7.0a0 +1.7.0a1 diff --git a/tests/test_save_load.py b/tests/test_save_load.py index 988d432..91b0760 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -1,7 +1,10 @@ +import base64 import io +import json import os import pathlib import warnings +import zipfile from collections import OrderedDict from copy import deepcopy @@ -690,3 +693,33 @@ def test_save_load_large_model(tmp_path): # clear file from os os.remove(tmp_path / "test_save.zip") + + +def test_load_invalid_object(tmp_path): + # See GH Issue #1122 for an example + # of invalid object loading + path = str(tmp_path / "ppo_pendulum.zip") + PPO("MlpPolicy", "Pendulum-v1", learning_rate=lambda _: 1.0).save(path) + + with zipfile.ZipFile(path, mode="r") as archive: + json_data = json.loads(archive.read("data").decode()) + + # Intentionally corrupt the data + serialization = 
json_data["learning_rate"][":serialized:"] + base64_object = base64.b64decode(serialization.encode()) + new_bytes = base64_object.replace(b"CodeType", b"CodeTyps") + base64_encoded = base64.b64encode(new_bytes).decode() + json_data["learning_rate"][":serialized:"] = base64_encoded + serialized_data = json.dumps(json_data, indent=4) + + with open(tmp_path / "data", "w") as f: + f.write(serialized_data) + # Replace with the corrupted file + # probably doesn't work on windows + os.system(f"cd {tmp_path}; zip ppo_pendulum.zip data") + with pytest.warns(UserWarning, match=r"custom_objects"): + PPO.load(path) + # Load with custom object, no warnings + with warnings.catch_warnings(record=True) as record: + PPO.load(path, custom_objects=dict(learning_rate=lambda _: 1.0)) + assert len(record) == 0 diff --git a/tests/test_vec_normalize.py b/tests/test_vec_normalize.py index 1fbf5b7..07c720f 100644 --- a/tests/test_vec_normalize.py +++ b/tests/test_vec_normalize.py @@ -1,5 +1,4 @@ import operator -import warnings import gym import numpy as np