mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-06-26 03:01:19 +00:00
* Split torch module code into torch_layers file * Updated reference to CNN * Change 'CxWxH' to 'CxHxW', as per common notion * Fix missing import in policies.py * Move PPOPolicy to OnlineActorCriticPolicy * Create OnPolicyRLModel from PPO, and make A2C and PPO inherit * Update A2C optimizer comment * Clean weight init scales for clarity * Fix A2C log_interval default parameter * Rename 'progress' to 'progress_remaining * Rename 'Models' to 'Algorithms' * Rename 'OnlineActorCriticPolicy' to 'ActorCriticPolicy' * Move static functions out from BaseAlgorithm * Move on/off_policy base algorithms to their own files * Add files for A2C/PPO * Fix docs * Fix pytype * Update documentation on OnPolicyAlgorithm * Add proper doctstring for on_policy rollout gathering * Add bit clarification on the mlppolicy/cnnpolicy naming * Move static function is_vectorized_policies to utils.py * Checking docstrings, pep8 fixes * Update changelog * Clean changelog * Remove policy warnings for sac/td3 * Add monitor_wrapper for OnPolicyAlgorithm. Clean tb logging variables. Add parameter keywords to OffPolicyAlgorithm super init Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
286 lines
12 KiB
Python
286 lines
12 KiB
Python
"""
|
|
Save util taken from stable_baselines
|
|
used to serialize data (class parameters) of model classes
|
|
"""
|
|
import os
|
|
import io
|
|
import json
|
|
import base64
|
|
import functools
|
|
from typing import Dict, Any, Tuple, Optional
|
|
import warnings
|
|
import zipfile
|
|
|
|
import torch as th
|
|
import cloudpickle
|
|
|
|
from stable_baselines3.common.type_aliases import TensorDict
|
|
from stable_baselines3.common.utils import get_device
|
|
|
|
|
|
def recursive_getattr(obj: Any, attr: str, *args) -> Any:
    """
    Resolve a dot-separated attribute path, one ``getattr`` call per component.

    Based on https://stackoverflow.com/questions/31174295

    Ex:
        > MyObject.sub_object = SubObject(name='test')
        > recursive_getattr(MyObject, 'sub_object.name')  # return test

    :param obj: (Any) Object the lookup starts from
    :param attr: (str) Attribute to retrieve, nested levels separated by dots
    :param args: Optional default value, forwarded unchanged to every
        ``getattr`` call along the path
    :return: (Any) The attribute found at the end of the path
    """
    current = obj
    for name in attr.split('.'):
        # ``*args`` is either empty or holds the default, exactly as the
        # built-in ``getattr`` expects it.
        current = getattr(current, name, *args)
    return current
|
|
|
|
|
|
def recursive_setattr(obj: Any, attr: str, val: Any) -> None:
    """
    Assign a value to an attribute addressed by a dot-separated path.

    Based on https://stackoverflow.com/questions/31174295

    Ex:
        > MyObject.sub_object = SubObject(name='test')
        > recursive_setattr(MyObject, 'sub_object.name', 'hello')

    :param obj: (Any) Object the path starts from
    :param attr: (str) Attribute to set, nested levels separated by dots
    :param val: (Any) New value of the attribute
    """
    parent_path, _, leaf_name = attr.rpartition('.')
    if parent_path:
        # Walk down to the object that owns the last path component
        owner = functools.reduce(getattr, parent_path.split('.'), obj)
    else:
        # No dot in the path: the attribute lives directly on ``obj``
        owner = obj
    setattr(owner, leaf_name, val)
|
|
|
|
|
|
def is_json_serializable(item: Any) -> bool:
    """
    Test if an object is serializable into JSON

    The check simply attempts ``json.dumps(item)`` and reports whether the
    encoder accepted the object.

    :param item: (object) The object to be tested for JSON serialization.
    :return: (bool) True if object is JSON serializable, false otherwise.
    """
    try:
        json.dumps(item)
    except (TypeError, ValueError, OverflowError):
        # TypeError: the object (or something nested in it) has a type the
        # encoder does not support.
        # ValueError: e.g. a container that references itself
        # ("Circular reference detected").
        # OverflowError: a value the encoder cannot represent.
        return False
    return True
|
|
|
|
|
|
def data_to_json(data: Dict[str, Any]) -> str:
    """
    Turn data (class parameters) into a JSON string for storing

    Items that are JSON serializable are stored as they are. Every other item
    is pickled with cloudpickle and stored as a dictionary holding:

    - ``":type:"``: ``str(type(item))``, for consumption by humans/other languages
    - ``":serialized:"``: the base64-encoded cloudpickle bytes
    - the first-level entries of the item (its ``__dict__``, or its own
      entries when the item is a dict), JSON serializable values as they are
      and anything else as its string representation.

    :param data: (Dict[str, Any]) Dictionary of class parameters to be
        stored. Items that are not JSON serializable will be
        pickled with Cloudpickle and stored as a base64 string in
        the JSON file
    :return: (str) JSON string of the data serialized.
    """
    # Keys written by this function itself. An attribute of the stored object
    # must never replace them, otherwise the pickled payload would be lost.
    reserved_keys = (":type:", ":serialized:")
    # Key types that ``json.dumps`` accepts for dictionary keys
    json_key_types = (str, int, float, bool, type(None))

    serializable_data = {}
    for data_key, data_item in data.items():
        if is_json_serializable(data_item):
            # All good, store as it is
            serializable_data[data_key] = data_item
            continue

        # Not serializable, cloudpickle it into bytes and convert to
        # base64 string for storing. Also store type of the class for
        # consumption from other languages/humans, so we have an idea
        # what was being stored.
        base64_encoded = base64.b64encode(cloudpickle.dumps(data_item)).decode()
        cloudpickle_serialization = {
            ":type:": str(type(data_item)),
            ":serialized:": base64_encoded,
        }

        # Add first-level JSON-serializable items of the object for further
        # details (but not deeper than this to avoid deep nesting).
        # Not every object has attributes (e.g. numpy scalars), hence the
        # empty fallback.
        if isinstance(data_item, dict):
            attributes = data_item
        else:
            attributes = getattr(data_item, "__dict__", {})

        for variable_name, variable_item in attributes.items():
            if not isinstance(variable_name, json_key_types):
                # e.g. tuple keys of a dict: json.dumps would reject them and
                # abort the whole save, so fall back to their string form.
                variable_name = str(variable_name)
            if variable_name in reserved_keys:
                # Informational copy only; never clobber the pickle payload
                continue
            # Check if serializable. If not, just include the
            # string-representation of the object.
            if is_json_serializable(variable_item):
                cloudpickle_serialization[variable_name] = variable_item
            else:
                cloudpickle_serialization[variable_name] = str(variable_item)

        serializable_data[data_key] = cloudpickle_serialization

    return json.dumps(serializable_data, indent=4)
|
|
|
|
|
|
def json_to_data(json_string: str,
                 custom_objects: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """
    Turn JSON serialization of class-parameters back into dictionary.

    Entries that are dictionaries containing a ``":serialized:"`` key are
    treated as base64-encoded cloudpickle payloads and unpickled; all other
    entries are returned as parsed from JSON. If unpickling an entry raises
    ``RuntimeError``, a warning is emitted and that entry is left out of the
    returned dictionary.

    :param json_string: (str) JSON serialization of the class-parameters
        that should be loaded.
    :param custom_objects: (dict) Dictionary of objects to replace
        upon loading. If a variable is present in this dictionary as a
        key, it will not be deserialized and the corresponding item
        will be used instead. Similar to custom_objects in
        `keras.models.load_model`. Useful when you have an object in
        file that can not be deserialized.
    :return: (dict) Loaded class parameters.
    :raises ValueError: If ``custom_objects`` is neither a dict nor None.
    """
    if custom_objects is not None and not isinstance(custom_objects, dict):
        raise ValueError("custom_objects argument must be a dict or None")

    json_dict = json.loads(json_string)
    # This will be filled with deserialized data
    return_data = {}
    for data_key, data_item in json_dict.items():
        if custom_objects is not None and data_key in custom_objects:
            # If item is provided in custom_objects, replace
            # the one from JSON with the one in custom_objects
            return_data[data_key] = custom_objects[data_key]
        elif isinstance(data_item, dict) and ":serialized:" in data_item:
            # If item is dictionary with ":serialized:"
            # key, this means it is serialized with cloudpickle.
            serialization = data_item[":serialized:"]
            # Try-except deserialization in case we run into
            # errors. If so, we can tell bit more information to
            # user.
            try:
                base64_object = base64.b64decode(serialization.encode())
                # NOTE(security): unpickling can execute arbitrary code,
                # so only load files that come from a trusted source.
                deserialized_object = cloudpickle.loads(base64_object)
            except RuntimeError:
                warnings.warn(f"Could not deserialize object {data_key}. "
                              "Consider using `custom_objects` argument to replace "
                              "this object.")
            else:
                # Only store the object when deserialization succeeded;
                # otherwise the name would be unbound or still hold the
                # object of a previous entry.
                return_data[data_key] = deserialized_object
        else:
            # Read as it is
            return_data[data_key] = data_item
    return return_data
|
|
|
|
|
|
def save_to_zip_file(save_path: str, data: Optional[Dict[str, Any]] = None,
                     params: Optional[Dict[str, Any]] = None,
                     tensors: Optional[Dict[str, Any]] = None) -> None:
    """
    Save a model to a zip archive.

    The archive contains a ``data`` member with the JSON-serialized class
    parameters, a ``tensors.pth`` member with the extra tensors and one
    ``<name>.pth`` member per entry of ``params``. Arguments that are None
    are not written at all.

    :param save_path: Where to store the model. When it is a string without
        a file extension, ".zip" is appended. Anything that is not a string
        is handed to ``zipfile.ZipFile`` as it is (e.g. a file-like object).
    :param data: Class parameters being stored.
    :param params: Model parameters being stored expected to contain an entry for every
        state_dict with its name and the state_dict.
    :param tensors: Extra tensor variables expected to contain name and value of tensors
    """
    # data/params can be None, so do not try to serialize them blindly.
    # Serialize before the archive is created so a serialization error does
    # not leave a half-written file behind.
    serialized_data = data_to_json(data) if data is not None else None

    # Check postfix if save_path is a string
    if isinstance(save_path, str):
        _, ext = os.path.splitext(save_path)
        if ext == "":
            save_path += ".zip"

    # Create a zip-archive and write our objects there. This works when
    # save_path is either str or a file-like
    with zipfile.ZipFile(save_path, "w") as archive:
        # Do not try to save "None" elements
        if serialized_data is not None:
            archive.writestr("data", serialized_data)
        if tensors is not None:
            # The member size is not known in advance; force_zip64 keeps the
            # write from failing once a member grows past 2 GiB.
            with archive.open('tensors.pth', mode="w", force_zip64=True) as tensors_file:
                th.save(tensors, tensors_file)
        if params is not None:
            for file_name, dict_ in params.items():
                with archive.open(file_name + '.pth', mode="w", force_zip64=True) as param_file:
                    th.save(dict_, param_file)
|
|
|
|
|
|
def _load_pth_member(archive: zipfile.ZipFile, member_name: str, device: Any) -> Any:
    """Read one archive member fully into memory and load it with ``th.load``."""
    with archive.open(member_name, mode="r") as member_file:
        # File has to be seekable, but the zip member is not, so load in
        # BytesIO first (fixed in python >= 3.7)
        file_content = io.BytesIO(member_file.read())
    # load the parameters with the right ``map_location``
    return th.load(file_content, map_location=device)


def load_from_zip_file(load_path: str, load_data: bool = True) -> (Tuple[Optional[Dict[str, Any]],
                                                                         Optional[TensorDict],
                                                                         Optional[TensorDict]]):
    """
    Load model data from a .zip archive

    NOTE(review): the archive content is unpickled (``cloudpickle.loads`` via
    ``json_to_data`` and, presumably, pickle inside ``th.load``), so only
    load archives from a trusted source.

    :param load_path: Where to load the model from. When it is a string that
        does not exist on disk, ``load_path + ".zip"`` is tried as well.
    :param load_data: Whether we should load and return data
        (class parameters). Mainly used by 'load_parameters' to only load model parameters (weights).
        When False, the extra tensors are not loaded either.
    :return: (dict),(dict),(dict) Class parameters, model state_dicts (dict of state_dict)
        and dict of extra tensors. Class parameters and extra tensors are None
        when they are missing from the archive or were not requested; the
        state_dicts are keyed by archive member name without the ".pth" extension.
    :raises ValueError: If the file can not be found or is not a zip archive.
    """
    # Check if file exists if load_path is a string
    if isinstance(load_path, str) and not os.path.exists(load_path):
        if os.path.exists(load_path + ".zip"):
            load_path += ".zip"
        else:
            raise ValueError(f"Error: the file {load_path} could not be found")

    # Device the stored tensors are mapped to. Per the original comment this
    # is the cpu if cuda is not available -- get_device is defined elsewhere.
    device = get_device()

    # Open the zip archive and load data
    try:
        with zipfile.ZipFile(load_path, "r") as archive:
            namelist = archive.namelist()
            # If data or parameters is not in the
            # zip archive, assume they were stored
            # as None (save_to_zip_file allows this).
            data = None
            tensors = None
            params = {}

            if "data" in namelist and load_data:
                # Load class parameters and convert to string
                json_data = archive.read("data").decode()
                data = json_to_data(json_data)

            if "tensors.pth" in namelist and load_data:
                # Load extra tensors
                tensors = _load_pth_member(archive, "tensors.pth", device)

            # Every other member which ends with .pth and isn't "tensors.pth"
            # is assumed to hold a state_dict (model or optimizer parameters).
            for file_path in namelist:
                member_stem, member_ext = os.path.splitext(file_path)
                if member_ext == ".pth" and file_path != "tensors.pth":
                    params[member_stem] = _load_pth_member(archive, file_path, device)
    except zipfile.BadZipFile as err:
        # load_path wasn't a zip file; keep the original error as the cause
        raise ValueError(f"Error: the file {load_path} wasn't a zip-file") from err
    return data, params, tensors
|