Allow models trained with Python 3.7 to be loaded with Python 3.8+ without the custom_objects workaround (#1123)

* Fix loading

* Remove documentation note

* Update changelog

* Revert save_format change

* Add test for errors while unpickling

* Update version and cleanup

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
This commit is contained in:
Quentin Gallouédec 2022-10-17 17:33:47 +02:00 committed by GitHub
parent 5ef10c8e69
commit d5d1a02c15
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 42 additions and 10 deletions

View file

@ -4,7 +4,7 @@ Changelog
==========
Release 1.7.0a0 (WIP)
Release 1.7.0a1 (WIP)
--------------------------
Breaking Changes:
@ -24,6 +24,7 @@ Bug Fixes:
^^^^^^^^^^
- Fix return type of ``evaluate_actions`` in ``ActorCriticPolicy`` to reflect that entropy is an optional tensor (@Rocamonde)
- Fix type annotation of ``policy`` in ``BaseAlgorithm`` and ``OffPolicyAlgorithm``
- Allowed models trained with Python 3.7 to be loaded with Python 3.8+ without the ``custom_objects`` workaround
Deprecations:
^^^^^^^^^^^^^

View file

@ -2,7 +2,6 @@
import collections
import copy
import warnings
from abc import ABC, abstractmethod
from functools import partial
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union

View file

@ -162,13 +162,15 @@ def json_to_data(json_string: str, custom_objects: Optional[Dict[str, Any]] = No
try:
base64_object = base64.b64decode(serialization.encode())
deserialized_object = cloudpickle.loads(base64_object)
except (RuntimeError, TypeError):
except (RuntimeError, TypeError, AttributeError) as e:
warnings.warn(
f"Could not deserialize object {data_key}. "
+ "Consider using `custom_objects` argument to replace "
+ "this object."
"Consider using `custom_objects` argument to replace "
"this object.\n"
f"Exception: {e}"
)
return_data[data_key] = deserialized_object
else:
return_data[data_key] = deserialized_object
else:
# Read as it is
return_data[data_key] = data_item

View file

@ -1,5 +1,4 @@
import pickle
import warnings
from copy import deepcopy
from typing import Any, Dict, List, Optional, Union

View file

@ -1,4 +1,3 @@
import warnings
from typing import Any, Dict, List, Optional, Tuple, Type, Union
import gym

View file

@ -1 +1 @@
1.7.0a0
1.7.0a1

View file

@ -1,7 +1,10 @@
import base64
import io
import json
import os
import pathlib
import warnings
import zipfile
from collections import OrderedDict
from copy import deepcopy
@ -690,3 +693,33 @@ def test_save_load_large_model(tmp_path):
# clear file from os
os.remove(tmp_path / "test_save.zip")
def test_load_invalid_object(tmp_path):
    """Check that loading a model with an un-deserializable object warns instead of failing.

    A PPO model is saved with a lambda as ``learning_rate``, the serialized
    payload of that entry is corrupted inside the archive, and the model is
    loaded twice: once without ``custom_objects`` (a ``UserWarning`` mentioning
    ``custom_objects`` is expected) and once with a replacement for
    ``learning_rate`` (no warning is expected).

    :param tmp_path: pytest-provided temporary directory used for the archive.
    """
    # See GH Issue #1122 for an example
    # of invalid object loading
    path = str(tmp_path / "ppo_pendulum.zip")
    PPO("MlpPolicy", "Pendulum-v1", learning_rate=lambda _: 1.0).save(path)

    with zipfile.ZipFile(path, mode="r") as archive:
        json_data = json.loads(archive.read("data").decode())

    # Intentionally corrupt the data: renaming a symbol referenced by the
    # serialized payload is presumably what makes deserialization raise.
    serialization = json_data["learning_rate"][":serialized:"]
    base64_object = base64.b64decode(serialization.encode())
    new_bytes = base64_object.replace(b"CodeType", b"CodeTyps")
    base64_encoded = base64.b64encode(new_bytes).decode()
    json_data["learning_rate"][":serialized:"] = base64_encoded
    serialized_data = json.dumps(json_data, indent=4)

    # Replace the "data" entry with the corrupted one by rebuilding the archive.
    # zipfile cannot overwrite an entry in place, so every entry is copied to a
    # new archive (keeping its original metadata) and the result is swapped in.
    # This avoids depending on an external `zip` binary and a POSIX shell.
    corrupted_path = str(tmp_path / "ppo_pendulum_corrupted.zip")
    with zipfile.ZipFile(path, mode="r") as source, zipfile.ZipFile(corrupted_path, mode="w") as target:
        for entry in source.infolist():
            if entry.filename == "data":
                target.writestr(entry, serialized_data)
            else:
                target.writestr(entry, source.read(entry.filename))
    os.replace(corrupted_path, path)

    with pytest.warns(UserWarning, match=r"custom_objects"):
        PPO.load(path)

    # Load with custom object, no warnings
    with warnings.catch_warnings(record=True) as record:
        PPO.load(path, custom_objects=dict(learning_rate=lambda _: 1.0))
    assert len(record) == 0

View file

@ -1,5 +1,4 @@
import operator
import warnings
import gym
import numpy as np