""" Save util taken from stable_baselines used to serialize data (class parameters) of model classes """ import json import base64 import functools from typing import Dict, Any, Optional import cloudpickle import warnings def recursive_getattr(obj: Any, attr: str, *args) -> Any: """ Recursive version of getattr taken from https://stackoverflow.com/questions/31174295 Ex: > MyObject.sub_object = SubObject(name='test') > recursive_getattr(MyObject, 'sub_object.name') # return test :param obj: (Any) :param attr: (str) Attribute to retrieve :return: (Any) The attribute """ def _getattr(obj: Any, attr: str) -> Any: return getattr(obj, attr, *args) return functools.reduce(_getattr, [obj] + attr.split('.')) def recursive_setattr(obj: Any, attr: str, val: Any) -> None: """ Recursive version of setattr taken from https://stackoverflow.com/questions/31174295 Ex: > MyObject.sub_object = SubObject(name='test') > recursive_setattr(MyObject, 'sub_object.name', 'hello') :param obj: (Any) :param attr: (str) Attribute to set :param val: (Any) New value of the attribute """ pre, _, post = attr.rpartition('.') return setattr(recursive_getattr(obj, pre) if pre else obj, post, val) def is_json_serializable(item: Any) -> bool: """ Test if an object is serializable into JSON :param item: (object) The object to be tested for JSON serialization. :return: (bool) True if object is JSON serializable, false otherwise. """ # Try with try-except struct. json_serializable = True try: _ = json.dumps(item) except TypeError: json_serializable = False return json_serializable def data_to_json(data: Dict[str, Any]) -> str: """ Turn data (class parameters) into a JSON string for storing :param data: (Dict[str, Any]) Dictionary of class parameters to be stored. Items that are not JSON serializable will be pickled with Cloudpickle and stored as bytearray in the JSON file :return: (str) JSON string of the data serialized. """ # First, check what elements can not be JSONfied, # and turn them into byte-strings serializable_data = {} for data_key, data_item in data.items(): # See if object is JSON serializable if is_json_serializable(data_item): # All good, store as it is serializable_data[data_key] = data_item else: # Not serializable, cloudpickle it into # bytes and convert to base64 string for storing. # Also store type of the class for consumption # from other languages/humans, so we have an # idea what was being stored. base64_encoded = base64.b64encode( cloudpickle.dumps(data_item) ).decode() # Use ":" to make sure we do # not override these keys # when we include variables of the object later cloudpickle_serialization = { ":type:": str(type(data_item)), ":serialized:": base64_encoded } # Add first-level JSON-serializable items of the # object for further details (but not deeper than this to # avoid deep nesting). # First we check that object has attributes (not all do, # e.g. numpy scalars) if hasattr(data_item, "__dict__") or isinstance(data_item, dict): # Take elements from __dict__ for custom classes item_generator = ( data_item.items if isinstance(data_item, dict) else data_item.__dict__.items ) for variable_name, variable_item in item_generator(): # Check if serializable. If not, just include the # string-representation of the object. if is_json_serializable(variable_item): cloudpickle_serialization[variable_name] = variable_item else: cloudpickle_serialization[variable_name] = str(variable_item) serializable_data[data_key] = cloudpickle_serialization json_string = json.dumps(serializable_data, indent=4) return json_string def json_to_data(json_string: str, custom_objects: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: """ Turn JSON serialization of class-parameters back into dictionary. :param json_string: (str) JSON serialization of the class-parameters that should be loaded. :param custom_objects: (dict) Dictionary of objects to replace upon loading. If a variable is present in this dictionary as a key, it will not be deserialized and the corresponding item will be used instead. Similar to custom_objects in `keras.models.load_model`. Useful when you have an object in file that can not be deserialized. :return: (dict) Loaded class parameters. """ if custom_objects is not None and not isinstance(custom_objects, dict): raise ValueError("custom_objects argument must be a dict or None") json_dict = json.loads(json_string) # This will be filled with deserialized data return_data = {} for data_key, data_item in json_dict.items(): if custom_objects is not None and data_key in custom_objects.keys(): # If item is provided in custom_objects, replace # the one from JSON with the one in custom_objects return_data[data_key] = custom_objects[data_key] elif isinstance(data_item, dict) and ":serialized:" in data_item.keys(): # If item is dictionary with ":serialized:" # key, this means it is serialized with cloudpickle. serialization = data_item[":serialized:"] # Try-except deserialization in case we run into # errors. If so, we can tell bit more information to # user. try: base64_object = base64.b64decode(serialization.encode()) deserialized_object = cloudpickle.loads(base64_object) except RuntimeError: warnings.warn(f"Could not deserialize object {data_key}. " + "Consider using `custom_objects` argument to replace " + "this object.") return_data[data_key] = deserialized_object else: # Read as it is return_data[data_key] = data_item return return_data