mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-14 20:58:03 +00:00
Fix saving on GPU - Loading on CPU (#45)
* removed policy from save, changed th.loads to map to device * found hack: catch pickle exception and trying th.load with mapping instead, otherwise raise exception with more information -> loading cuda on cpu raises exception -> leads to th.load with map being called * deleted todo * updated changelog * start of saving refactor * first working c * all tests pass, save refactored * - backwards compatibility not always - make pytest all passing - make typing all passing * Fixes and simplify the save method * Remove unused param * Fix backward compat * Fix docstring
This commit is contained in:
parent
cc3b023533
commit
1f0dd60b97
7 changed files with 211 additions and 171 deletions
|
|
@ -16,6 +16,7 @@ New Features:
|
|||
|
||||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
- Fix loading models on CPU that were trained on GPU
|
||||
|
||||
Deprecations:
|
||||
^^^^^^^^^^^^^
|
||||
|
|
|
|||
|
|
@ -16,13 +16,12 @@ MODEL_LIST = [
|
|||
SAC,
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", MODEL_LIST)
|
||||
def test_save_load(model_class):
|
||||
"""
|
||||
Test if 'save' and 'load' saves and loads model correctly
|
||||
and if 'load_parameters' and 'get_policy_parameters' work correctly
|
||||
|
||||
|
||||
''warning does not test function of optimizer parameter load
|
||||
|
||||
:param model_class: (BaseRLModel) A RL model
|
||||
|
|
@ -38,16 +37,15 @@ def test_save_load(model_class):
|
|||
observations = np.squeeze(observations)
|
||||
|
||||
# Get dictionary of current parameters
|
||||
params = deepcopy(model.get_policy_parameters())
|
||||
opt_params = deepcopy(model.get_opt_parameters())
|
||||
params = deepcopy(model.policy.state_dict())
|
||||
|
||||
# Modify all parameters to be random values
|
||||
random_params = dict((param_name, th.rand_like(param)) for param_name, param in params.items())
|
||||
|
||||
# Update model parameters with the new random values
|
||||
model.load_parameters(random_params, opt_params)
|
||||
model.policy.load_state_dict(random_params)
|
||||
|
||||
new_params = model.get_policy_parameters()
|
||||
new_params = model.policy.state_dict()
|
||||
# Check that all params are different now
|
||||
for k in params:
|
||||
assert not th.allclose(params[k], new_params[k]), "Parameters did not change as expected."
|
||||
|
|
@ -63,33 +61,14 @@ def test_save_load(model_class):
|
|||
model = model_class.load("test_save", env=env)
|
||||
|
||||
# check if params are still the same after load
|
||||
new_params = model.get_policy_parameters()
|
||||
new_params = model.policy.state_dict()
|
||||
|
||||
# Check that all params are the same as before save load procedure now
|
||||
for k in params:
|
||||
assert th.allclose(params[k], new_params[k]), "Model parameters not the same after save and load."
|
||||
|
||||
# check if optimizer params are still the same after load
|
||||
new_opt_params = model.get_opt_parameters()
|
||||
# check if keys are the same
|
||||
assert opt_params.keys() == new_opt_params.keys()
|
||||
# check if values are the same: only tested for Adam and RMSProp so far
|
||||
# comparing states not implemented so far. hashes of state_entries are not the same for equal tensors
|
||||
# comparing every sub_entry does not work because of bool value of Tensor with more than one value is ambiguous
|
||||
# so far only comparing param_lists
|
||||
for optimizer, opt_state in opt_params.items():
|
||||
for param_group_idx, param_group in enumerate(opt_state['param_groups']):
|
||||
for param_key, param_value in param_group.items():
|
||||
# don't know how to handle params correctly, therefore only check if we have the same amount
|
||||
if param_key == 'params':
|
||||
assert len(param_value) == len(
|
||||
new_opt_params[optimizer]['param_groups'][param_group_idx][param_key])
|
||||
else:
|
||||
assert param_value == new_opt_params[optimizer]['param_groups'][param_group_idx][param_key]
|
||||
for key in params:
|
||||
assert th.allclose(params[key], new_params[key]), "Model parameters not the same after save and load."
|
||||
|
||||
# check if model still selects the same actions
|
||||
new_selected_actions = [model.predict(observation, deterministic=True) for observation in observations]
|
||||
# for i in range(len(selected_actionsselected_actions)):
|
||||
assert np.allclose(selected_actions, new_selected_actions)
|
||||
|
||||
# check if learn still works
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ from torchy_baselines.common.utils import set_random_seed, get_schedule_fn, upda
|
|||
from torchy_baselines.common.vec_env import DummyVecEnv, VecEnv, unwrap_vec_normalize, sync_envs_normalization
|
||||
from torchy_baselines.common.monitor import Monitor
|
||||
from torchy_baselines.common.evaluation import evaluate_policy
|
||||
from torchy_baselines.common.save_util import data_to_json, json_to_data
|
||||
from torchy_baselines.common.save_util import data_to_json, json_to_data, recursive_getattr, recursive_setattr
|
||||
|
||||
# TODO: define aliases, ex GymEnv = Union[gym.Env, VecEnv]
|
||||
if typing.TYPE_CHECKING:
|
||||
|
|
@ -49,11 +49,12 @@ class BaseRLModel(ABC):
|
|||
:param sde_sample_freq: Sample a new noise matrix every n steps when using SDE
|
||||
Default: -1 (only sample at the beginning of the rollout)
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
policy: Type[BasePolicy],
|
||||
env: Union[gym.Env, VecEnv, str],
|
||||
policy_base: Type[BasePolicy],
|
||||
policy_kwargs : Dict[str, Any] = None,
|
||||
policy_kwargs: Dict[str, Any] = None,
|
||||
verbose: int = 0,
|
||||
device: Union[th.device, str] = 'auto',
|
||||
support_multi_env: bool = False,
|
||||
|
|
@ -129,6 +130,14 @@ class BaseRLModel(ABC):
|
|||
raise ValueError("Error: the model does not support multiple envs requires a single vectorized"
|
||||
" environment.")
|
||||
|
||||
@abstractmethod
|
||||
def _setup_model(self) -> None:
|
||||
"""
|
||||
Setup model so state_dict can be loaded
|
||||
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def _get_eval_env(self, eval_env: Union[gym.Env, VecEnv, None]) -> Union[gym.Env, VecEnv, None]:
|
||||
"""
|
||||
Return the environment that will be used for evaluation.
|
||||
|
|
@ -252,29 +261,18 @@ class BaseRLModel(ABC):
|
|||
self.n_envs = env.num_envs
|
||||
self.env = env
|
||||
|
||||
def get_parameters(self) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||
def get_torch_variables(self) -> Tuple[List[str], List[str]]:
|
||||
"""
|
||||
Returns policy and optimizer parameters as a tuple
|
||||
Get the name of the torch variable that will be saved.
|
||||
`th.save` and `th.load` will be used with the right device
|
||||
instead of the default pickling strategy.
|
||||
|
||||
:return: policy_parameters, opt_parameters
|
||||
:return: (Tuple[List[str], List[str]])
|
||||
name of the variables with state dicts to save, name of additional torch tensors,
|
||||
"""
|
||||
return self.get_policy_parameters(), self.get_opt_parameters()
|
||||
state_dicts = ["policy"]
|
||||
|
||||
def get_policy_parameters(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get current model policy parameters as dictionary of variable name -> tensors.
|
||||
|
||||
:return: Dictionary of variable name -> tensor of model's policy parameters.
|
||||
"""
|
||||
return self.policy.state_dict()
|
||||
|
||||
@abstractmethod
|
||||
def get_opt_parameters(self)-> Dict[str, Any]:
|
||||
"""
|
||||
Get current model optimizer parameters as dictionary of variable names -> tensors
|
||||
:return: (dict) Dictionary of variable name -> tensor of model's optimizer parameters
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
return state_dicts, []
|
||||
|
||||
@abstractmethod
|
||||
def learn(self, total_timesteps: int,
|
||||
|
|
@ -316,21 +314,6 @@ class BaseRLModel(ABC):
|
|||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def load_parameters(self, load_dict: Dict[str, Any], opt_params: Dict[str, Any]) -> None:
|
||||
"""
|
||||
Load model parameters from a dictionary
|
||||
load_dict should contain all keys from torch.model.state_dict()
|
||||
If opt_params are given this does also load agent's optimizer-parameters,
|
||||
but can only be handled in child classes.
|
||||
|
||||
|
||||
:param load_dict: dict of parameters from model.state_dict()
|
||||
:param opt_params: dict of optimizer state_dicts should be handled in child_class
|
||||
"""
|
||||
if opt_params is not None:
|
||||
raise ValueError("Optimizer Parameters where given but no overloaded load function exists for this class")
|
||||
self.policy.load_state_dict(load_dict)
|
||||
|
||||
@classmethod
|
||||
def load(cls, load_path: str, env: Union[gym.Env, VecEnv, None] = None, **kwargs):
|
||||
"""
|
||||
|
|
@ -341,7 +324,7 @@ class BaseRLModel(ABC):
|
|||
(can be None if you only need prediction from a trained model) has priority over any saved environment
|
||||
:param kwargs: extra arguments to change the model when loading
|
||||
"""
|
||||
data, params, opt_params = cls._load_from_file(load_path)
|
||||
data, params, tensors = cls._load_from_file(load_path)
|
||||
|
||||
if 'policy_kwargs' in kwargs and kwargs['policy_kwargs'] != data['policy_kwargs']:
|
||||
raise ValueError(f"The specified policy kwargs do not equal the stored policy kwargs."
|
||||
|
|
@ -359,12 +342,25 @@ class BaseRLModel(ABC):
|
|||
|
||||
# first create model, but only setup if a env was given
|
||||
# noinspection PyArgumentList
|
||||
model = cls(policy=data["policy_class"], env=env, _init_setup_model=env is not None)
|
||||
model = cls(policy=data["policy_class"], env=env, device='auto', _init_setup_model=env is not None)
|
||||
|
||||
# load parameters
|
||||
model.__dict__.update(data)
|
||||
model.__dict__.update(kwargs)
|
||||
model.load_parameters(params, opt_params)
|
||||
if not hasattr(model, "_setup_model") and len(params) > 0:
|
||||
raise NotImplementedError("loading was executed on a model that has no means to create the policies")
|
||||
model._setup_model()
|
||||
|
||||
# put state_dicts back in place
|
||||
for name in params:
|
||||
attr = recursive_getattr(model, name)
|
||||
attr.load_state_dict(params[name])
|
||||
|
||||
# put tensors back in place
|
||||
if tensors is not None:
|
||||
for name in tensors:
|
||||
recursive_setattr(model, name, tensors[name])
|
||||
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -374,8 +370,8 @@ class BaseRLModel(ABC):
|
|||
:param load_path: Where to load the model from
|
||||
:param load_data: Whether we should load and return data
|
||||
(class parameters). Mainly used by 'load_parameters' to only load model parameters (weights)
|
||||
:return: (dict),(dict),(dict) Class parameters, model parameters (state_dict)
|
||||
and dict of optimizer parameters (dict of state_dict)
|
||||
:return: (dict),(dict),(dict) Class parameters, model state_dicts (dict of state_dict)
|
||||
and dict of extra tensors
|
||||
"""
|
||||
# Check if file exists if load_path is a string
|
||||
if isinstance(load_path, str):
|
||||
|
|
@ -385,6 +381,12 @@ class BaseRLModel(ABC):
|
|||
else:
|
||||
raise ValueError(f"Error: the file {load_path} could not be found")
|
||||
|
||||
# set device to cpu if cuda is not available
|
||||
if th.cuda.is_available():
|
||||
device = th.device('cuda')
|
||||
else:
|
||||
device = th.device('cpu')
|
||||
|
||||
# Open the zip archive and load data
|
||||
try:
|
||||
with zipfile.ZipFile(load_path, "r") as archive:
|
||||
|
|
@ -393,31 +395,32 @@ class BaseRLModel(ABC):
|
|||
# zip archive, assume they were stored
|
||||
# as None (_save_to_file_zip allows this).
|
||||
data = None
|
||||
params = None
|
||||
opt_params = None
|
||||
tensors = None
|
||||
params = {}
|
||||
|
||||
if "data" in namelist and load_data:
|
||||
# Load class parameters and convert to string
|
||||
json_data = archive.read("data").decode()
|
||||
data = json_to_data(json_data)
|
||||
data = json_to_data(json_data, device)
|
||||
|
||||
if "params.pth" in namelist:
|
||||
# Load parameters with build in torch function
|
||||
with archive.open("params.pth", mode="r") as param_file:
|
||||
# File has to be seekable, but param_file is not, so load in BytesIO first
|
||||
if "tensors.pth" in namelist and load_data:
|
||||
# Load extra tensors
|
||||
with archive.open('tensors.pth', mode="r") as tensor_file:
|
||||
# File has to be seekable, but opt_param_file is not, so load in BytesIO first
|
||||
# fixed in python >= 3.7
|
||||
file_content = io.BytesIO()
|
||||
file_content.write(param_file.read())
|
||||
file_content.write(tensor_file.read())
|
||||
# go to start of file
|
||||
file_content.seek(0)
|
||||
params = th.load(file_content)
|
||||
# load the parameters with the right `map_location`
|
||||
tensors = th.load(file_content, map_location=device)
|
||||
|
||||
# check for all other .pth files
|
||||
other_files = [file_name for file_name in namelist if
|
||||
os.path.splitext(file_name)[1] == ".pth" and file_name != "params.pth"]
|
||||
os.path.splitext(file_name)[1] == ".pth" and file_name != "tensors.pth"]
|
||||
# if there are any other files which end with .pth and aren't "params.pth"
|
||||
# assume that they each are optimizer parameters
|
||||
if len(other_files) > 0:
|
||||
opt_params = dict()
|
||||
for file_path in other_files:
|
||||
with archive.open(file_path, mode="r") as opt_param_file:
|
||||
# File has to be seekable, but opt_param_file is not, so load in BytesIO first
|
||||
|
|
@ -426,13 +429,27 @@ class BaseRLModel(ABC):
|
|||
file_content.write(opt_param_file.read())
|
||||
# go to start of file
|
||||
file_content.seek(0)
|
||||
# save the parameters in dict with file name but trim file ending
|
||||
opt_params[os.path.splitext(file_path)[0]] = th.load(file_content)
|
||||
# load the parameters with the right `map_location`
|
||||
params[os.path.splitext(file_path)[0]] = th.load(file_content, map_location=device)
|
||||
|
||||
# for backward compatibility
|
||||
if params.get('params') is not None:
|
||||
params_copy = {}
|
||||
for name in params:
|
||||
if name == 'params':
|
||||
params_copy['policy'] = params[name]
|
||||
elif name == 'opt':
|
||||
params_copy['policy.optimizer'] = params[name]
|
||||
# Special case for SAC
|
||||
elif name == 'ent_coef_optimizer':
|
||||
params_copy[name] = params[name]
|
||||
else:
|
||||
params_copy[name + '.optimizer'] = params[name]
|
||||
params = params_copy
|
||||
except zipfile.BadZipFile:
|
||||
# load_path wasn't a zip file
|
||||
raise ValueError(f"Error: the file {load_path} wasn't a zip-file")
|
||||
|
||||
return data, params, opt_params
|
||||
return data, params, tensors
|
||||
|
||||
def set_random_seed(self, seed: Optional[int] = None) -> None:
|
||||
"""
|
||||
|
|
@ -654,15 +671,15 @@ class BaseRLModel(ABC):
|
|||
|
||||
@staticmethod
|
||||
def _save_to_file_zip(save_path: str, data: Dict[str, Any] = None,
|
||||
params: Dict[str, Any] = None, opt_params: Dict[str, Any] = None) -> None:
|
||||
params: Dict[str, Any] = None, tensors: Dict[str, Any] = None) -> None:
|
||||
"""
|
||||
Save model to a zip archive.
|
||||
|
||||
:param save_path: Where to store the model
|
||||
:param data: Class parameters being stored
|
||||
:param params: Model parameters being stored expected to be state_dict
|
||||
:param opt_params: Optimizer parameters being stored expected to contain an entry for every
|
||||
optimizer with its name and the state_dict
|
||||
:param params: Model parameters being stored expected to contain an entry for every
|
||||
state_dict with its name and the state_dict
|
||||
:param tensors: Extra tensor variables expected to contain name and value of tensors
|
||||
"""
|
||||
|
||||
# data/params can be None, so do not
|
||||
|
|
@ -683,23 +700,22 @@ class BaseRLModel(ABC):
|
|||
# Do not try to save "None" elements
|
||||
if data is not None:
|
||||
archive.writestr("data", serialized_data)
|
||||
if tensors is not None:
|
||||
with archive.open('tensors.pth', mode="w") as tensors_file:
|
||||
th.save(tensors, tensors_file)
|
||||
if params is not None:
|
||||
with archive.open('params.pth', mode="w") as param_file:
|
||||
th.save(params, param_file)
|
||||
if opt_params is not None:
|
||||
for file_name, dict_ in opt_params.items():
|
||||
with archive.open(file_name + '.pth', mode="w") as opt_param_file:
|
||||
th.save(dict_, opt_param_file)
|
||||
for file_name, dict_ in params.items():
|
||||
with archive.open(file_name + '.pth', mode="w") as param_file:
|
||||
th.save(dict_, param_file)
|
||||
|
||||
@staticmethod
|
||||
def excluded_save_params() -> List[str]:
|
||||
def excluded_save_params(self) -> List[str]:
|
||||
"""
|
||||
Returns the names of the parameters that should be excluded by default
|
||||
when saving the model.
|
||||
|
||||
:return: ([str]) List of parameters that should be excluded from save
|
||||
"""
|
||||
return ["env", "eval_env", "replay_buffer", "rollout_buffer", "_vec_normalize_env"]
|
||||
return ["policy", "device", "env", "eval_env", "replay_buffer", "rollout_buffer", "_vec_normalize_env"]
|
||||
|
||||
def save(self, path: str, exclude: Optional[List[str]] = None, include: Optional[List[str]] = None) -> None:
|
||||
"""
|
||||
|
|
@ -717,18 +733,40 @@ class BaseRLModel(ABC):
|
|||
else:
|
||||
# append standard exclude params to the given params
|
||||
exclude.extend([param for param in self.excluded_save_params() if param not in exclude])
|
||||
|
||||
# do not exclude params if they are specifically included
|
||||
if include is not None:
|
||||
exclude = [param_name for param_name in exclude if param_name not in include]
|
||||
|
||||
# remove parameter entries of parameters which are to be excluded
|
||||
state_dicts_names, tensors_names = self.get_torch_variables()
|
||||
# any params that are in the save vars must not be saved by data
|
||||
torch_variables = state_dicts_names + tensors_names
|
||||
for torch_var in torch_variables:
|
||||
# we need to get only the name of the top most module as we'll remove that
|
||||
var_name = torch_var.split('.')[0]
|
||||
exclude.append(var_name)
|
||||
|
||||
# Remove parameter entries of parameters which are to be excluded
|
||||
for param_name in exclude:
|
||||
if param_name in data:
|
||||
data.pop(param_name, None)
|
||||
|
||||
params_to_save = self.get_policy_parameters()
|
||||
opt_params_to_save = self.get_opt_parameters()
|
||||
self._save_to_file_zip(path, data=data, params=params_to_save, opt_params=opt_params_to_save)
|
||||
# Build dict of tensor variables
|
||||
tensors = None
|
||||
if tensors_names is not None:
|
||||
tensors = {}
|
||||
for name in tensors_names:
|
||||
attr = recursive_getattr(self, name)
|
||||
tensors[name] = attr
|
||||
|
||||
# Build dict of state_dicts
|
||||
params_to_save = {}
|
||||
for name in state_dicts_names:
|
||||
attr = recursive_getattr(self, name)
|
||||
# Retrieve state dict
|
||||
params_to_save[name] = attr.state_dict()
|
||||
|
||||
self._save_to_file_zip(path, data=data, params=params_to_save, tensors=tensors)
|
||||
|
||||
def _eval_policy(self, eval_freq: int, eval_env: int, n_eval_episodes: int,
|
||||
timesteps_since_eval: int, render: bool = False, deterministic: bool = True) -> int:
|
||||
|
|
|
|||
|
|
@ -2,15 +2,51 @@
|
|||
Save util taken from stable_baselines
|
||||
used to serialize data (class parameters) of model classes
|
||||
"""
|
||||
|
||||
|
||||
import json
|
||||
import base64
|
||||
import pickle
|
||||
import functools
|
||||
from typing import Dict, Any, Optional, Union
|
||||
|
||||
import torch as th
|
||||
import cloudpickle
|
||||
import warnings
|
||||
|
||||
|
||||
def is_json_serializable(item):
|
||||
def recursive_getattr(obj: Any, attr: str, *args) -> Any:
|
||||
"""
|
||||
Recursive version of getattr
|
||||
taken from https://stackoverflow.com/questions/31174295
|
||||
|
||||
Ex:
|
||||
> MyObject.sub_object = SubObject(name='test')
|
||||
> recursive_getattr(MyObject, 'sub_object.name') # return test
|
||||
:param obj: (Any)
|
||||
:param attr: (str) Attribute to retrieve
|
||||
:return: (Any) The attribute
|
||||
"""
|
||||
def _getattr(obj: Any, attr: str) -> Any:
|
||||
return getattr(obj, attr, *args)
|
||||
|
||||
return functools.reduce(_getattr, [obj] + attr.split('.'))
|
||||
|
||||
|
||||
def recursive_setattr(obj: Any, attr: str, val: Any) -> None:
|
||||
"""
|
||||
Recursive version of setattr
|
||||
taken from https://stackoverflow.com/questions/31174295
|
||||
|
||||
Ex:
|
||||
> MyObject.sub_object = SubObject(name='test')
|
||||
> recursive_setattr(MyObject, 'sub_object.name', 'hello')
|
||||
:param obj: (Any)
|
||||
:param attr: (str) Attribute to set
|
||||
:param val: (Any) New value of the attribute
|
||||
"""
|
||||
pre, _, post = attr.rpartition('.')
|
||||
return setattr(recursive_getattr(obj, pre) if pre else obj, post, val)
|
||||
|
||||
|
||||
def is_json_serializable(item: Any) -> bool:
|
||||
"""
|
||||
Test if an object is serializable into JSON
|
||||
|
||||
|
|
@ -26,11 +62,11 @@ def is_json_serializable(item):
|
|||
return json_serializable
|
||||
|
||||
|
||||
def data_to_json(data):
|
||||
def data_to_json(data: Dict[str, Any]) -> str:
|
||||
"""
|
||||
Turn data (class parameters) into a JSON string for storing
|
||||
|
||||
:param data: (Dict) Dictionary of class parameters to be
|
||||
:param data: (Dict[str, Any]) Dictionary of class parameters to be
|
||||
stored. Items that are not JSON serializable will be
|
||||
pickled with Cloudpickle and stored as bytearray in
|
||||
the JSON file
|
||||
|
|
@ -85,12 +121,15 @@ def data_to_json(data):
|
|||
return json_string
|
||||
|
||||
|
||||
def json_to_data(json_string, custom_objects=None):
|
||||
def json_to_data(json_string: str,
|
||||
device: Union[th.device, str] = 'cpu',
|
||||
custom_objects: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Turn JSON serialization of class-parameters back into dictionary.
|
||||
|
||||
:param json_string: (str) JSON serialization of the class-parameters
|
||||
that should be loaded.
|
||||
:param device: torch.device device to which the data should be mapped if errors occur
|
||||
:param custom_objects: (dict) Dictionary of objects to replace
|
||||
upon loading. If a variable is present in this dictionary as a
|
||||
key, it will not be deserialized and the corresponding item
|
||||
|
|
@ -118,15 +157,12 @@ def json_to_data(json_string, custom_objects=None):
|
|||
# errors. If so, we can tell bit more information to
|
||||
# user.
|
||||
try:
|
||||
deserialized_object = cloudpickle.loads(
|
||||
base64.b64decode(serialization.encode())
|
||||
)
|
||||
except pickle.UnpicklingError:
|
||||
raise RuntimeError(
|
||||
f"Could not deserialize object {data_key}. " +
|
||||
"Consider using `custom_objects` argument to replace " +
|
||||
"this object."
|
||||
)
|
||||
base64_object = base64.b64decode(serialization.encode())
|
||||
deserialized_object = cloudpickle.loads(base64_object)
|
||||
except RuntimeError:
|
||||
warnings.warn(f"Could not deserialize object {data_key}. " +
|
||||
"Consider using `custom_objects` argument to replace " +
|
||||
"this object.")
|
||||
return_data[data_key] = deserialized_object
|
||||
else:
|
||||
# Read as it is
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import os
|
||||
import time
|
||||
from typing import List, Tuple
|
||||
|
||||
import gym
|
||||
from gym import spaces
|
||||
|
|
@ -69,6 +70,7 @@ class PPO(BaseRLModel):
|
|||
Setting it to auto, the code will be run on the GPU if possible.
|
||||
:param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance
|
||||
"""
|
||||
|
||||
def __init__(self, policy, env, learning_rate=3e-4,
|
||||
n_steps=2048, batch_size=64, n_epochs=10,
|
||||
gamma=0.99, gae_lambda=0.95, clip_range=0.2, clip_range_vf=None,
|
||||
|
|
@ -310,22 +312,10 @@ class PPO(BaseRLModel):
|
|||
|
||||
return self
|
||||
|
||||
def get_opt_parameters(self):
|
||||
def get_torch_variables(self) -> Tuple[List[str], List[str]]:
|
||||
"""
|
||||
Returns a dict of all the optimizers and their parameters
|
||||
cf base class
|
||||
"""
|
||||
state_dicts = ["policy", "policy.optimizer"]
|
||||
|
||||
:return: (dict) of optimizer names and their state_dict
|
||||
"""
|
||||
return {"opt": self.policy.optimizer.state_dict()}
|
||||
|
||||
def load_parameters(self, load_dict, opt_params):
|
||||
"""
|
||||
Load model parameters and optimizer parameters from a dictionary
|
||||
load_dict should contain all keys from torch.model.state_dict()
|
||||
This does not load agent's hyper-parameters.
|
||||
|
||||
:param load_dict: (dict) dict of parameters from model.state_dict()
|
||||
:param opt_params: (dict of dicts) dict of optimizer state_dicts
|
||||
"""
|
||||
self.policy.optimizer.load_state_dict(opt_params["opt"])
|
||||
self.policy.load_state_dict(load_dict)
|
||||
return state_dicts, []
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
from typing import List, Tuple
|
||||
|
||||
import torch as th
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
|
|
@ -128,7 +130,7 @@ class SAC(BaseRLModel):
|
|||
# Force conversion to float
|
||||
# this will throw an error if a malformed string (different from 'auto')
|
||||
# is passed
|
||||
self.ent_coef = th.tensor(float(self.ent_coef)).to(self.device)
|
||||
self.ent_coef_tensor = th.tensor(float(self.ent_coef)).to(self.device)
|
||||
|
||||
self.replay_buffer = ReplayBuffer(self.buffer_size, obs_dim, action_dim, self.device)
|
||||
self.policy = self.policy_class(self.observation_space, self.action_space,
|
||||
|
|
@ -199,7 +201,7 @@ class SAC(BaseRLModel):
|
|||
ent_coef = th.exp(self.log_ent_coef.detach())
|
||||
ent_coef_loss = -(self.log_ent_coef * (log_prob + self.target_entropy).detach()).mean()
|
||||
else:
|
||||
ent_coef = self.ent_coef
|
||||
ent_coef = self.ent_coef_tensor
|
||||
|
||||
# Optimize entropy coefficient, also called
|
||||
# entropy temperature or alpha in the paper
|
||||
|
|
@ -294,28 +296,24 @@ class SAC(BaseRLModel):
|
|||
|
||||
return self
|
||||
|
||||
def get_opt_parameters(self):
|
||||
def excluded_save_params(self) -> List[str]:
|
||||
"""
|
||||
Returns a dict of all the optimizers and their parameters
|
||||
Returns the names of the parameters that should be excluded by default
|
||||
when saving the model.
|
||||
|
||||
:return: (Dict) of optimizer names and their state_dict
|
||||
:return: (List[str]) List of parameters that should be excluded from save
|
||||
"""
|
||||
opt_dict = {"actor": self.actor.optimizer.state_dict(), "critic": self.critic.optimizer.state_dict()}
|
||||
# Exclude aliases
|
||||
return super(SAC, self).excluded_save_params() + ["actor", "critic", "critic_target"]
|
||||
|
||||
def get_torch_variables(self) -> Tuple[List[str], List[str]]:
|
||||
"""
|
||||
cf base class
|
||||
"""
|
||||
state_dicts = ["policy", "actor.optimizer", "critic.optimizer"]
|
||||
saved_tensors = ['log_ent_coef']
|
||||
if self.ent_coef_optimizer is not None:
|
||||
opt_dict.update({"ent_coef_optimizer": self.ent_coef_optimizer.state_dict()})
|
||||
return opt_dict
|
||||
|
||||
def load_parameters(self, load_dict, opt_params):
|
||||
"""
|
||||
Load model parameters and optimizer parameters from a dictionary
|
||||
load_dict should contain all keys from torch.model.state_dict()
|
||||
This does not load agent's hyper-parameters.
|
||||
|
||||
:param load_dict: (dict) dict of parameters from model.state_dict()
|
||||
:param opt_params: (dict of dicts) dict of optimizer state_dicts should be handled in child_class
|
||||
"""
|
||||
self.actor.optimizer.load_state_dict(opt_params["actor"])
|
||||
self.critic.optimizer.load_state_dict(opt_params["critic"])
|
||||
if "ent_coef_optimizer" in opt_params:
|
||||
self.ent_coef_optimizer.load_state_dict(opt_params["ent_coef_optimizer"])
|
||||
self.policy.load_state_dict(load_dict)
|
||||
state_dicts.append('ent_coef_optimizer')
|
||||
else:
|
||||
saved_tensors.append('ent_coef_tensor')
|
||||
return state_dicts, saved_tensors
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
from typing import List, Tuple
|
||||
|
||||
import torch as th
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
|
|
@ -296,23 +298,19 @@ class TD3(BaseRLModel):
|
|||
|
||||
return self
|
||||
|
||||
def get_opt_parameters(self):
|
||||
def excluded_save_params(self) -> List[str]:
|
||||
"""
|
||||
Returns a dict of all the optimizers and their parameters
|
||||
Returns the names of the parameters that should be excluded by default
|
||||
when saving the model.
|
||||
|
||||
:return: (Dict) of optimizer names and their state_dict
|
||||
:return: (List[str]) List of parameters that should be excluded from save
|
||||
"""
|
||||
return {"actor": self.actor.optimizer.state_dict(), "critic": self.critic.optimizer.state_dict()}
|
||||
# Exclude aliases
|
||||
return super(TD3, self).excluded_save_params() + ["actor", "critic", "vf_net", "actor_target", "critic_target"]
|
||||
|
||||
def load_parameters(self, load_dict, opt_params):
|
||||
def get_torch_variables(self) -> Tuple[List[str], List[str]]:
|
||||
"""
|
||||
Load model parameters and optimizer parameters from a dictionary
|
||||
load_dict should contain all keys from torch.model.state_dict()
|
||||
This does not load agent's hyper-parameters.
|
||||
|
||||
:param load_dict: (dict) dict of parameters from model.state_dict()
|
||||
:param opt_params: (dict of dicts) dict of optimizer state_dicts should be handled in child_class
|
||||
cf base class
|
||||
"""
|
||||
self.actor.optimizer.load_state_dict(opt_params["actor"])
|
||||
self.critic.optimizer.load_state_dict(opt_params["critic"])
|
||||
self.policy.load_state_dict(load_dict)
|
||||
state_dicts = ["policy", "actor.optimizer", "critic.optimizer"]
|
||||
return state_dicts, []
|
||||
|
|
|
|||
Loading…
Reference in a new issue