mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-18 21:30:19 +00:00
Allow model trained with python3.7 to be loaded with python3.8+ without the custom_objects workaround (#1123)
* Fix loading * Remove documentation note * Update changelog * Revert save_format change * Add test for errors while unpickling * Update version and cleanup Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
This commit is contained in:
parent
5ef10c8e69
commit
d5d1a02c15
8 changed files with 42 additions and 10 deletions
|
|
@ -4,7 +4,7 @@ Changelog
|
|||
==========
|
||||
|
||||
|
||||
Release 1.7.0a0 (WIP)
|
||||
Release 1.7.0a1 (WIP)
|
||||
--------------------------
|
||||
|
||||
Breaking Changes:
|
||||
|
|
@ -24,6 +24,7 @@ Bug Fixes:
|
|||
^^^^^^^^^^
|
||||
- Fix return type of ``evaluate_actions`` in ``ActorCritcPolicy`` to reflect that entropy is an optional tensor (@Rocamonde)
|
||||
- Fix type annotation of ``policy`` in ``BaseAlgorithm`` and ``OffPolicyAlgorithm``
|
||||
- Allowed model trained with Python 3.7 to be loaded with Python 3.8+ without the ``custom_objects`` workaround
|
||||
|
||||
Deprecations:
|
||||
^^^^^^^^^^^^^
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@
|
|||
|
||||
import collections
|
||||
import copy
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from functools import partial
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
|
||||
|
|
|
|||
|
|
@ -162,13 +162,15 @@ def json_to_data(json_string: str, custom_objects: Optional[Dict[str, Any]] = No
|
|||
try:
|
||||
base64_object = base64.b64decode(serialization.encode())
|
||||
deserialized_object = cloudpickle.loads(base64_object)
|
||||
except (RuntimeError, TypeError):
|
||||
except (RuntimeError, TypeError, AttributeError) as e:
|
||||
warnings.warn(
|
||||
f"Could not deserialize object {data_key}. "
|
||||
+ "Consider using `custom_objects` argument to replace "
|
||||
+ "this object."
|
||||
"Consider using `custom_objects` argument to replace "
|
||||
"this object.\n"
|
||||
f"Exception: {e}"
|
||||
)
|
||||
return_data[data_key] = deserialized_object
|
||||
else:
|
||||
return_data[data_key] = deserialized_object
|
||||
else:
|
||||
# Read as it is
|
||||
return_data[data_key] = data_item
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import pickle
|
||||
import warnings
|
||||
from copy import deepcopy
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
import warnings
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
||||
|
||||
import gym
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
1.7.0a0
|
||||
1.7.0a1
|
||||
|
|
|
|||
|
|
@ -1,7 +1,10 @@
|
|||
import base64
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import pathlib
|
||||
import warnings
|
||||
import zipfile
|
||||
from collections import OrderedDict
|
||||
from copy import deepcopy
|
||||
|
||||
|
|
@ -690,3 +693,33 @@ def test_save_load_large_model(tmp_path):
|
|||
|
||||
# clear file from os
|
||||
os.remove(tmp_path / "test_save.zip")
|
||||
|
||||
|
||||
def test_load_invalid_object(tmp_path):
|
||||
# See GH Issue #1122 for an example
|
||||
# of invalid object loading
|
||||
path = str(tmp_path / "ppo_pendulum.zip")
|
||||
PPO("MlpPolicy", "Pendulum-v1", learning_rate=lambda _: 1.0).save(path)
|
||||
|
||||
with zipfile.ZipFile(path, mode="r") as archive:
|
||||
json_data = json.loads(archive.read("data").decode())
|
||||
|
||||
# Intentionally corrupt the data
|
||||
serialization = json_data["learning_rate"][":serialized:"]
|
||||
base64_object = base64.b64decode(serialization.encode())
|
||||
new_bytes = base64_object.replace(b"CodeType", b"CodeTyps")
|
||||
base64_encoded = base64.b64encode(new_bytes).decode()
|
||||
json_data["learning_rate"][":serialized:"] = base64_encoded
|
||||
serialized_data = json.dumps(json_data, indent=4)
|
||||
|
||||
with open(tmp_path / "data", "w") as f:
|
||||
f.write(serialized_data)
|
||||
# Replace with the corrupted file
|
||||
# probably doesn't work on windows
|
||||
os.system(f"cd {tmp_path}; zip ppo_pendulum.zip data")
|
||||
with pytest.warns(UserWarning, match=r"custom_objects"):
|
||||
PPO.load(path)
|
||||
# Load with custom object, no warnings
|
||||
with warnings.catch_warnings(record=True) as record:
|
||||
PPO.load(path, custom_objects=dict(learning_rate=lambda _: 1.0))
|
||||
assert len(record) == 0
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import operator
|
||||
import warnings
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
|
|
|
|||
Loading…
Reference in a new issue