stable-baselines3/torchy_baselines/common/save_util.py
Dormann, Noah 1f0dd60b97 Fix saving on GPU - Loading on CPU (#45)
* removed policy from save, changed th.loads to map to device

* found hack: catch the pickle exception and try th.load with mapping instead; otherwise raise an exception with more information -> loading cuda on cpu raises an exception -> leads to th.load with map being called

* deleted todo

* updated changelog

* start of saving refactor

* first working c

* all tests pass, save refactored

* - backwards compatibility not always
- make pytest all passing
- make typing all passing

* Fixes and simplify the save method

* Remove unused param

* Fix backward compat

* Fix docstring
2020-01-31 13:06:55 +01:00

170 lines
6.6 KiB
Python

"""
Save util taken from stable_baselines
used to serialize data (class parameters) of model classes
"""
import json
import base64
import functools
from typing import Dict, Any, Optional, Union
import torch as th
import cloudpickle
import warnings
def recursive_getattr(obj: Any, attr: str, *args) -> Any:
    """
    Resolve a dotted attribute path on an object, one ``getattr`` per segment.
    Adapted from https://stackoverflow.com/questions/31174295

    Ex:
        > MyObject.sub_object = SubObject(name='test')
        > recursive_getattr(MyObject, 'sub_object.name')  # return test

    :param obj: (Any) Object on which the lookup starts
    :param attr: (str) Attribute to retrieve; nested attributes are
        separated by dots
    :param args: Optional default value, forwarded unchanged to every
        ``getattr`` call along the path
    :return: (Any) The attribute
    """
    current = obj
    for name in attr.split('.'):
        # The optional default applies at each level of the path,
        # not only to the final segment.
        current = getattr(current, name, *args)
    return current
def recursive_setattr(obj: Any, attr: str, val: Any) -> None:
    """
    Assign a value to an attribute addressed by a dotted path.
    Adapted from https://stackoverflow.com/questions/31174295

    Ex:
        > MyObject.sub_object = SubObject(name='test')
        > recursive_setattr(MyObject, 'sub_object.name', 'hello')

    :param obj: (Any) Object on which the lookup starts
    :param attr: (str) Attribute to set; nested attributes are
        separated by dots
    :param val: (Any) New value of the attribute
    """
    # Everything before the last dot identifies the owner of the attribute,
    # the remainder is the name to set on that owner.
    owner_path, _, leaf_name = attr.rpartition('.')
    target = obj
    if owner_path:
        target = recursive_getattr(obj, owner_path)
    setattr(target, leaf_name, val)
def is_json_serializable(item: Any) -> bool:
    """
    Test if an object is serializable into JSON

    :param item: (object) The object to be tested for JSON serialization.
    :return: (bool) True if object is JSON serializable, false otherwise.
    """
    try:
        json.dumps(item)
    except (TypeError, ValueError):
        # TypeError: unsupported value type or dictionary key type.
        # ValueError: the encoder detected a circular reference; such an
        # object cannot be written as JSON either, so report it as
        # non-serializable instead of letting the exception escape.
        return False
    return True
def data_to_json(data: Dict[str, Any]) -> str:
    """
    Turn data (class parameters) into a JSON string for storing

    :param data: (Dict[str, Any]) Dictionary of class parameters to be
        stored. Items that are not JSON serializable will be
        pickled with Cloudpickle and stored as a base64 string in
        the JSON file
    :return: (str) JSON string of the data serialized.
    """
    # Keys that carry the pickled payload and its type; informational
    # entries must never overwrite them.
    reserved_keys = (":type:", ":serialized:")
    # Key types the JSON encoder accepts for dictionary keys.
    json_key_types = (str, int, float, bool, type(None))

    # First, check what elements can not be JSONfied,
    # and turn them into byte-strings
    serializable_data = {}
    for data_key, data_item in data.items():
        # See if object is JSON serializable
        if is_json_serializable(data_item):
            # All good, store as it is
            serializable_data[data_key] = data_item
            continue

        # Not serializable, cloudpickle it into
        # bytes and convert to base64 string for storing.
        # Also store type of the class for consumption
        # from other languages/humans, so we have an
        # idea what was being stored.
        base64_encoded = base64.b64encode(
            cloudpickle.dumps(data_item)
        ).decode()

        # Use ":" to make sure we do
        # not override these keys
        # when we include variables of the object later
        cloudpickle_serialization = {
            ":type:": str(type(data_item)),
            ":serialized:": base64_encoded
        }

        # Add first-level JSON-serializable items of the
        # object for further details (but not deeper than this to
        # avoid deep nesting).
        # First we check that object has attributes (not all do,
        # e.g. numpy scalars)
        if hasattr(data_item, "__dict__") or isinstance(data_item, dict):
            # Take elements from __dict__ for custom classes
            item_generator = (
                data_item.items if isinstance(data_item, dict) else data_item.__dict__.items
            )
            for variable_name, variable_item in item_generator():
                # A dict may use keys the JSON encoder rejects (e.g. tuples);
                # these entries are informational only, so store the key's
                # string form rather than failing in the final json.dumps.
                if not isinstance(variable_name, json_key_types):
                    variable_name = str(variable_name)
                # Never let a member clobber the pickled payload or its type.
                if variable_name in reserved_keys:
                    continue
                # Check if serializable. If not, just include the
                # string-representation of the object.
                if is_json_serializable(variable_item):
                    cloudpickle_serialization[variable_name] = variable_item
                else:
                    cloudpickle_serialization[variable_name] = str(variable_item)

        serializable_data[data_key] = cloudpickle_serialization
    json_string = json.dumps(serializable_data, indent=4)
    return json_string
def json_to_data(json_string: str,
                 device: Union[th.device, str] = 'cpu',
                 custom_objects: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """
    Turn JSON serialization of class-parameters back into dictionary.

    :param json_string: (str) JSON serialization of the class-parameters
        that should be loaded.
    :param device: (Union[th.device, str]) Device to which the data should
        be mapped if errors occur. NOTE: this argument is currently not
        referenced by the body of this function.
    :param custom_objects: (dict) Dictionary of objects to replace
        upon loading. If a variable is present in this dictionary as a
        key, it will not be deserialized and the corresponding item
        will be used instead. Similar to custom_objects in
        `keras.models.load_model`. Useful when you have an object in
        file that can not be deserialized.
    :return: (dict) Loaded class parameters. An entry whose pickled payload
        fails to deserialize with a ``RuntimeError`` is left out of the
        result and a warning is emitted for it.
    :raises ValueError: If ``custom_objects`` is neither a dict nor None.
    """
    if custom_objects is not None and not isinstance(custom_objects, dict):
        raise ValueError("custom_objects argument must be a dict or None")

    json_dict = json.loads(json_string)
    # This will be filled with deserialized data
    return_data = {}
    for data_key, data_item in json_dict.items():
        if custom_objects is not None and data_key in custom_objects:
            # If item is provided in custom_objects, replace
            # the one from JSON with the one in custom_objects
            return_data[data_key] = custom_objects[data_key]
        elif isinstance(data_item, dict) and ":serialized:" in data_item:
            # If item is dictionary with ":serialized:"
            # key, this means it is serialized with cloudpickle.
            serialization = data_item[":serialized:"]
            # Try-except deserialization in case we run into
            # errors. If so, we can tell bit more information to
            # user.
            try:
                base64_object = base64.b64decode(serialization.encode())
                # SECURITY NOTE: unpickling can execute arbitrary code;
                # only load JSON strings that come from a trusted source.
                deserialized_object = cloudpickle.loads(base64_object)
            except RuntimeError:
                warnings.warn(f"Could not deserialize object {data_key}. " +
                              "Consider using `custom_objects` argument to replace " +
                              "this object.")
            else:
                # Only store the object when deserialization succeeded.
                # Storing it unconditionally would either fail on an unbound
                # name or reuse the object of a previously loaded key.
                return_data[data_key] = deserialized_object
        else:
            # Read as it is
            return_data[data_key] = data_item
    return return_data