diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 4f95650..958c78a 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -16,6 +16,7 @@ New Features: Bug Fixes: ^^^^^^^^^^ +- Fix loading models on CPU that were trained on GPU Deprecations: ^^^^^^^^^^^^^ diff --git a/tests/test_save_load.py b/tests/test_save_load.py index f4adc0c..b4dd7b7 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -16,13 +16,12 @@ MODEL_LIST = [ SAC, ] - @pytest.mark.parametrize("model_class", MODEL_LIST) def test_save_load(model_class): """ Test if 'save' and 'load' saves and loads model correctly and if 'load_parameters' and 'get_policy_parameters' work correctly - + ''warning does not test function of optimizer parameter load :param model_class: (BaseRLModel) A RL model @@ -38,16 +37,15 @@ def test_save_load(model_class): observations = np.squeeze(observations) # Get dictionary of current parameters - params = deepcopy(model.get_policy_parameters()) - opt_params = deepcopy(model.get_opt_parameters()) + params = deepcopy(model.policy.state_dict()) # Modify all parameters to be random values random_params = dict((param_name, th.rand_like(param)) for param_name, param in params.items()) # Update model parameters with the new random values - model.load_parameters(random_params, opt_params) + model.policy.load_state_dict(random_params) - new_params = model.get_policy_parameters() + new_params = model.policy.state_dict() # Check that all params are different now for k in params: assert not th.allclose(params[k], new_params[k]), "Parameters did not change as expected." 
@@ -63,33 +61,14 @@ def test_save_load(model_class): model = model_class.load("test_save", env=env) # check if params are still the same after load - new_params = model.get_policy_parameters() + new_params = model.policy.state_dict() # Check that all params are the same as before save load procedure now - for k in params: - assert th.allclose(params[k], new_params[k]), "Model parameters not the same after save and load." - - # check if optimizer params are still the same after load - new_opt_params = model.get_opt_parameters() - # check if keys are the same - assert opt_params.keys() == new_opt_params.keys() - # check if values are the same: only tested for Adam and RMSProp so far - # comparing states not implemented so far. hashes of state_entries are not the same for equal tensors - # comparing every sub_entry does not work because of bool value of Tensor with more than one value is ambiguous - # so far only comparing param_lists - for optimizer, opt_state in opt_params.items(): - for param_group_idx, param_group in enumerate(opt_state['param_groups']): - for param_key, param_value in param_group.items(): - # don't know how to handle params correctly, therefore only check if we have the same amount - if param_key == 'params': - assert len(param_value) == len( - new_opt_params[optimizer]['param_groups'][param_group_idx][param_key]) - else: - assert param_value == new_opt_params[optimizer]['param_groups'][param_group_idx][param_key] + for key in params: + assert th.allclose(params[key], new_params[key]), "Model parameters not the same after save and load." 
# check if model still selects the same actions new_selected_actions = [model.predict(observation, deterministic=True) for observation in observations] - # for i in range(len(selected_actionsselected_actions)): assert np.allclose(selected_actions, new_selected_actions) # check if learn still works diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py index 563e419..41743b4 100644 --- a/torchy_baselines/common/base_class.py +++ b/torchy_baselines/common/base_class.py @@ -17,7 +17,7 @@ from torchy_baselines.common.utils import set_random_seed, get_schedule_fn, upda from torchy_baselines.common.vec_env import DummyVecEnv, VecEnv, unwrap_vec_normalize, sync_envs_normalization from torchy_baselines.common.monitor import Monitor from torchy_baselines.common.evaluation import evaluate_policy -from torchy_baselines.common.save_util import data_to_json, json_to_data +from torchy_baselines.common.save_util import data_to_json, json_to_data, recursive_getattr, recursive_setattr # TODO: define aliases, ex GymEnv = Union[gym.Env, VecEnv] if typing.TYPE_CHECKING: @@ -49,11 +49,12 @@ class BaseRLModel(ABC): :param sde_sample_freq: Sample a new noise matrix every n steps when using SDE Default: -1 (only sample at the beginning of the rollout) """ + def __init__(self, policy: Type[BasePolicy], env: Union[gym.Env, VecEnv, str], policy_base: Type[BasePolicy], - policy_kwargs : Dict[str, Any] = None, + policy_kwargs: Dict[str, Any] = None, verbose: int = 0, device: Union[th.device, str] = 'auto', support_multi_env: bool = False, @@ -129,6 +130,14 @@ class BaseRLModel(ABC): raise ValueError("Error: the model does not support multiple envs requires a single vectorized" " environment.") + @abstractmethod + def _setup_model(self) -> None: + """ + Setup model so state_dict can be loaded + + """ + raise NotImplementedError() + def _get_eval_env(self, eval_env: Union[gym.Env, VecEnv, None]) -> Union[gym.Env, VecEnv, None]: """ Return the environment that 
will be used for evaluation. @@ -252,29 +261,18 @@ class BaseRLModel(ABC): self.n_envs = env.num_envs self.env = env - def get_parameters(self) -> Tuple[Dict[str, Any], Dict[str, Any]]: + def get_torch_variables(self) -> Tuple[List[str], List[str]]: """ - Returns policy and optimizer parameters as a tuple + Get the names of the torch variables that will be saved. + `th.save` and `th.load` will be used with the right device + instead of the default pickling strategy. - :return: policy_parameters, opt_parameters + :return: (Tuple[List[str], List[str]]) + names of the variables with state dicts to save, names of additional torch tensors """ - return self.get_policy_parameters(), self.get_opt_parameters() + state_dicts = ["policy"] - def get_policy_parameters(self) -> Dict[str, Any]: - """ - Get current model policy parameters as dictionary of variable name -> tensors. - - :return: Dictionary of variable name -> tensor of model's policy parameters. - """ - return self.policy.state_dict() - - @abstractmethod - def get_opt_parameters(self)-> Dict[str, Any]: - """ - Get current model optimizer parameters as dictionary of variable names -> tensors - :return: (dict) Dictionary of variable name -> tensor of model's optimizer parameters - """ - raise NotImplementedError() + return state_dicts, [] @abstractmethod def learn(self, total_timesteps: int, @@ -316,21 +314,6 @@ class BaseRLModel(ABC): """ raise NotImplementedError() - def load_parameters(self, load_dict: Dict[str, Any], opt_params: Dict[str, Any]) -> None: - """ - Load model parameters from a dictionary - load_dict should contain all keys from torch.model.state_dict() - If opt_params are given this does also load agent's optimizer-parameters, - but can only be handled in child classes. 
- - - :param load_dict: dict of parameters from model.state_dict() - :param opt_params: dict of optimizer state_dicts should be handled in child_class - """ - if opt_params is not None: - raise ValueError("Optimizer Parameters where given but no overloaded load function exists for this class") - self.policy.load_state_dict(load_dict) - @classmethod def load(cls, load_path: str, env: Union[gym.Env, VecEnv, None] = None, **kwargs): """ @@ -341,7 +324,7 @@ class BaseRLModel(ABC): (can be None if you only need prediction from a trained model) has priority over any saved environment :param kwargs: extra arguments to change the model when loading """ - data, params, opt_params = cls._load_from_file(load_path) + data, params, tensors = cls._load_from_file(load_path) if 'policy_kwargs' in kwargs and kwargs['policy_kwargs'] != data['policy_kwargs']: raise ValueError(f"The specified policy kwargs do not equal the stored policy kwargs." @@ -359,12 +342,25 @@ class BaseRLModel(ABC): # first create model, but only setup if a env was given # noinspection PyArgumentList - model = cls(policy=data["policy_class"], env=env, _init_setup_model=env is not None) + model = cls(policy=data["policy_class"], env=env, device='auto', _init_setup_model=env is not None) # load parameters model.__dict__.update(data) model.__dict__.update(kwargs) - model.load_parameters(params, opt_params) + if not hasattr(model, "_setup_model") and len(params) > 0: + raise NotImplementedError("loading was executed on a model that has no means to create the policies") + model._setup_model() + + # put state_dicts back in place + for name in params: + attr = recursive_getattr(model, name) + attr.load_state_dict(params[name]) + + # put tensors back in place + if tensors is not None: + for name in tensors: + recursive_setattr(model, name, tensors[name]) + return model @staticmethod @@ -374,8 +370,8 @@ class BaseRLModel(ABC): :param load_path: Where to load the model from :param load_data: Whether we should load and 
return data (class parameters). Mainly used by 'load_parameters' to only load model parameters (weights) - :return: (dict),(dict),(dict) Class parameters, model parameters (state_dict) - and dict of optimizer parameters (dict of state_dict) + :return: (dict),(dict),(dict) Class parameters, model state_dicts (dict of state_dict) + and dict of extra tensors """ # Check if file exists if load_path is a string if isinstance(load_path, str): @@ -385,6 +381,12 @@ class BaseRLModel(ABC): else: raise ValueError(f"Error: the file {load_path} could not be found") + # set device to cpu if cuda is not available + if th.cuda.is_available(): + device = th.device('cuda') + else: + device = th.device('cpu') + # Open the zip archive and load data try: with zipfile.ZipFile(load_path, "r") as archive: @@ -393,31 +395,32 @@ class BaseRLModel(ABC): # zip archive, assume they were stored # as None (_save_to_file_zip allows this). data = None - params = None - opt_params = None + tensors = None + params = {} + if "data" in namelist and load_data: # Load class parameters and convert to string json_data = archive.read("data").decode() - data = json_to_data(json_data) + data = json_to_data(json_data, device) - if "params.pth" in namelist: - # Load parameters with build in torch function - with archive.open("params.pth", mode="r") as param_file: - # File has to be seekable, but param_file is not, so load in BytesIO first + if "tensors.pth" in namelist and load_data: + # Load extra tensors + with archive.open('tensors.pth', mode="r") as tensor_file: + # File has to be seekable, but tensor_file is not, so load in BytesIO first # fixed in python >= 3.7 file_content = io.BytesIO() - file_content.write(param_file.read()) + file_content.write(tensor_file.read()) # go to start of file file_content.seek(0) - params = th.load(file_content) + # load the tensors with the right `map_location` + tensors = th.load(file_content, map_location=device) # check for all other .pth files other_files = 
[file_name for file_name in namelist if - os.path.splitext(file_name)[1] == ".pth" and file_name != "params.pth"] + os.path.splitext(file_name)[1] == ".pth" and file_name != "tensors.pth"] # if there are any other files which end with .pth and aren't "params.pth" # assume that they each are optimizer parameters if len(other_files) > 0: - opt_params = dict() for file_path in other_files: with archive.open(file_path, mode="r") as opt_param_file: # File has to be seekable, but opt_param_file is not, so load in BytesIO first @@ -426,13 +429,27 @@ class BaseRLModel(ABC): file_content.write(opt_param_file.read()) # go to start of file file_content.seek(0) - # save the parameters in dict with file name but trim file ending - opt_params[os.path.splitext(file_path)[0]] = th.load(file_content) + # load the parameters with the right `map_location` + params[os.path.splitext(file_path)[0]] = th.load(file_content, map_location=device) + + # for backward compatibility + if params.get('params') is not None: + params_copy = {} + for name in params: + if name == 'params': + params_copy['policy'] = params[name] + elif name == 'opt': + params_copy['policy.optimizer'] = params[name] + # Special case for SAC + elif name == 'ent_coef_optimizer': + params_copy[name] = params[name] + else: + params_copy[name + '.optimizer'] = params[name] + params = params_copy except zipfile.BadZipFile: # load_path wasn't a zip file raise ValueError(f"Error: the file {load_path} wasn't a zip-file") - - return data, params, opt_params + return data, params, tensors def set_random_seed(self, seed: Optional[int] = None) -> None: """ @@ -654,15 +671,15 @@ class BaseRLModel(ABC): @staticmethod def _save_to_file_zip(save_path: str, data: Dict[str, Any] = None, - params: Dict[str, Any] = None, opt_params: Dict[str, Any] = None) -> None: + params: Dict[str, Any] = None, tensors: Dict[str, Any] = None) -> None: """ Save model to a zip archive. 
:param save_path: Where to store the model :param data: Class parameters being stored - :param params: Model parameters being stored expected to be state_dict - :param opt_params: Optimizer parameters being stored expected to contain an entry for every - optimizer with its name and the state_dict + :param params: Model parameters being stored expected to contain an entry for every + state_dict with its name and the state_dict + :param tensors: Extra tensor variables expected to contain name and value of tensors """ # data/params can be None, so do not @@ -683,23 +700,22 @@ class BaseRLModel(ABC): # Do not try to save "None" elements if data is not None: archive.writestr("data", serialized_data) + if tensors is not None: + with archive.open('tensors.pth', mode="w") as tensors_file: + th.save(tensors, tensors_file) if params is not None: - with archive.open('params.pth', mode="w") as param_file: - th.save(params, param_file) - if opt_params is not None: - for file_name, dict_ in opt_params.items(): - with archive.open(file_name + '.pth', mode="w") as opt_param_file: - th.save(dict_, opt_param_file) + for file_name, dict_ in params.items(): + with archive.open(file_name + '.pth', mode="w") as param_file: + th.save(dict_, param_file) - @staticmethod - def excluded_save_params() -> List[str]: + def excluded_save_params(self) -> List[str]: """ Returns the names of the parameters that should be excluded by default when saving the model. 
:return: ([str]) List of parameters that should be excluded from save """ - return ["env", "eval_env", "replay_buffer", "rollout_buffer", "_vec_normalize_env"] + return ["policy", "device", "env", "eval_env", "replay_buffer", "rollout_buffer", "_vec_normalize_env"] def save(self, path: str, exclude: Optional[List[str]] = None, include: Optional[List[str]] = None) -> None: """ @@ -717,18 +733,40 @@ class BaseRLModel(ABC): else: # append standard exclude params to the given params exclude.extend([param for param in self.excluded_save_params() if param not in exclude]) + # do not exclude params if they are specifically included if include is not None: exclude = [param_name for param_name in exclude if param_name not in include] - # remove parameter entries of parameters which are to be excluded + state_dicts_names, tensors_names = self.get_torch_variables() + # any params that are in the save vars must not be saved by data + torch_variables = state_dicts_names + tensors_names + for torch_var in torch_variables: + # we need to get only the name of the top most module as we'll remove that + var_name = torch_var.split('.')[0] + exclude.append(var_name) + + # Remove parameter entries of parameters which are to be excluded for param_name in exclude: if param_name in data: data.pop(param_name, None) - params_to_save = self.get_policy_parameters() - opt_params_to_save = self.get_opt_parameters() - self._save_to_file_zip(path, data=data, params=params_to_save, opt_params=opt_params_to_save) + # Build dict of tensor variables + tensors = None + if tensors_names is not None: + tensors = {} + for name in tensors_names: + attr = recursive_getattr(self, name) + tensors[name] = attr + + # Build dict of state_dicts + params_to_save = {} + for name in state_dicts_names: + attr = recursive_getattr(self, name) + # Retrieve state dict + params_to_save[name] = attr.state_dict() + + self._save_to_file_zip(path, data=data, params=params_to_save, tensors=tensors) def _eval_policy(self, 
eval_freq: int, eval_env: int, n_eval_episodes: int, timesteps_since_eval: int, render: bool = False, deterministic: bool = True) -> int: diff --git a/torchy_baselines/common/save_util.py b/torchy_baselines/common/save_util.py index 1b5d801..a9dedbc 100644 --- a/torchy_baselines/common/save_util.py +++ b/torchy_baselines/common/save_util.py @@ -2,15 +2,51 @@ Save util taken from stable_baselines used to serialize data (class parameters) of model classes """ - - import json import base64 -import pickle +import functools +from typing import Dict, Any, Optional, Union + +import torch as th import cloudpickle +import warnings -def is_json_serializable(item): +def recursive_getattr(obj: Any, attr: str, *args) -> Any: + """ + Recursive version of getattr + taken from https://stackoverflow.com/questions/31174295 + + Ex: + > MyObject.sub_object = SubObject(name='test') + > recursive_getattr(MyObject, 'sub_object.name') # return test + :param obj: (Any) + :param attr: (str) Attribute to retrieve + :return: (Any) The attribute + """ + def _getattr(obj: Any, attr: str) -> Any: + return getattr(obj, attr, *args) + + return functools.reduce(_getattr, [obj] + attr.split('.')) + + +def recursive_setattr(obj: Any, attr: str, val: Any) -> None: + """ + Recursive version of setattr + taken from https://stackoverflow.com/questions/31174295 + + Ex: + > MyObject.sub_object = SubObject(name='test') + > recursive_setattr(MyObject, 'sub_object.name', 'hello') + :param obj: (Any) + :param attr: (str) Attribute to set + :param val: (Any) New value of the attribute + """ + pre, _, post = attr.rpartition('.') + return setattr(recursive_getattr(obj, pre) if pre else obj, post, val) + + +def is_json_serializable(item: Any) -> bool: """ Test if an object is serializable into JSON @@ -26,11 +62,11 @@ def is_json_serializable(item): return json_serializable -def data_to_json(data): +def data_to_json(data: Dict[str, Any]) -> str: """ Turn data (class parameters) into a JSON string for storing - 
:param data: (Dict) Dictionary of class parameters to be + :param data: (Dict[str, Any]) Dictionary of class parameters to be stored. Items that are not JSON serializable will be pickled with Cloudpickle and stored as bytearray in the JSON file @@ -85,12 +121,15 @@ def data_to_json(data): return json_string -def json_to_data(json_string, custom_objects=None): +def json_to_data(json_string: str, + device: Union[th.device, str] = 'cpu', + custom_objects: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: """ Turn JSON serialization of class-parameters back into dictionary. :param json_string: (str) JSON serialization of the class-parameters that should be loaded. + :param device: (Union[th.device, str]) device to which the data should be mapped if errors occur :param custom_objects: (dict) Dictionary of objects to replace upon loading. If a variable is present in this dictionary as a key, it will not be deserialized and the corresponding item @@ -118,15 +157,14 @@ def json_to_data(json_string, custom_objects=None): # errors. If so, we can tell bit more information to # user. try: - deserialized_object = cloudpickle.loads( - base64.b64decode(serialization.encode()) - ) - except pickle.UnpicklingError: - raise RuntimeError( - f"Could not deserialize object {data_key}. " + - "Consider using `custom_objects` argument to replace " + - "this object." - ) + base64_object = base64.b64decode(serialization.encode()) + deserialized_object = cloudpickle.loads(base64_object) + except RuntimeError: + warnings.warn(f"Could not deserialize object {data_key}. 
" + + "Consider using `custom_objects` argument to replace " + + "this object.") + # Skip this entry: `deserialized_object` is unbound (or stale from a previous entry) here + continue return_data[data_key] = deserialized_object else: # Read as it is diff --git a/torchy_baselines/ppo/ppo.py b/torchy_baselines/ppo/ppo.py index 77a47cf..9e7c48c 100644 --- a/torchy_baselines/ppo/ppo.py +++ b/torchy_baselines/ppo/ppo.py @@ -1,5 +1,6 @@ import os import time +from typing import List, Tuple import gym from gym import spaces @@ -69,6 +70,7 @@ class PPO(BaseRLModel): Setting it to auto, the code will be run on the GPU if possible. :param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance """ + def __init__(self, policy, env, learning_rate=3e-4, n_steps=2048, batch_size=64, n_epochs=10, gamma=0.99, gae_lambda=0.95, clip_range=0.2, clip_range_vf=None, @@ -310,22 +312,10 @@ class PPO(BaseRLModel): return self - def get_opt_parameters(self): + def get_torch_variables(self) -> Tuple[List[str], List[str]]: """ - Returns a dict of all the optimizers and their parameters + cf base class + """ + state_dicts = ["policy", "policy.optimizer"] - :return: (dict) of optimizer names and their state_dict - """ - return {"opt": self.policy.optimizer.state_dict()} - - def load_parameters(self, load_dict, opt_params): - """ - Load model parameters and optimizer parameters from a dictionary - load_dict should contain all keys from torch.model.state_dict() - This does not load agent's hyper-parameters. 
- - :param load_dict: (dict) dict of parameters from model.state_dict() - :param opt_params: (dict of dicts) dict of optimizer state_dicts - """ - self.policy.optimizer.load_state_dict(opt_params["opt"]) - self.policy.load_state_dict(load_dict) + return state_dicts, [] diff --git a/torchy_baselines/sac/sac.py b/torchy_baselines/sac/sac.py index 1e682cf..7cd92a4 100644 --- a/torchy_baselines/sac/sac.py +++ b/torchy_baselines/sac/sac.py @@ -1,3 +1,5 @@ +from typing import List, Tuple + import torch as th import torch.nn.functional as F import numpy as np @@ -128,7 +130,7 @@ class SAC(BaseRLModel): # Force conversion to float # this will throw an error if a malformed string (different from 'auto') # is passed - self.ent_coef = th.tensor(float(self.ent_coef)).to(self.device) + self.ent_coef_tensor = th.tensor(float(self.ent_coef)).to(self.device) self.replay_buffer = ReplayBuffer(self.buffer_size, obs_dim, action_dim, self.device) self.policy = self.policy_class(self.observation_space, self.action_space, @@ -199,7 +201,7 @@ class SAC(BaseRLModel): ent_coef = th.exp(self.log_ent_coef.detach()) ent_coef_loss = -(self.log_ent_coef * (log_prob + self.target_entropy).detach()).mean() else: - ent_coef = self.ent_coef + ent_coef = self.ent_coef_tensor # Optimize entropy coefficient, also called # entropy temperature or alpha in the paper @@ -294,28 +296,24 @@ class SAC(BaseRLModel): return self - def get_opt_parameters(self): + def excluded_save_params(self) -> List[str]: """ - Returns a dict of all the optimizers and their parameters + Returns the names of the parameters that should be excluded by default + when saving the model. 
- :return: (Dict) of optimizer names and their state_dict + :return: (List[str]) List of parameters that should be excluded from save """ - opt_dict = {"actor": self.actor.optimizer.state_dict(), "critic": self.critic.optimizer.state_dict()} + # Exclude aliases + return super(SAC, self).excluded_save_params() + ["actor", "critic", "critic_target"] + + def get_torch_variables(self) -> Tuple[List[str], List[str]]: + """ + cf base class + """ + state_dicts = ["policy", "actor.optimizer", "critic.optimizer"] + saved_tensors = ['log_ent_coef'] if self.ent_coef_optimizer is not None: - opt_dict.update({"ent_coef_optimizer": self.ent_coef_optimizer.state_dict()}) - return opt_dict - - def load_parameters(self, load_dict, opt_params): - """ - Load model parameters and optimizer parameters from a dictionary - load_dict should contain all keys from torch.model.state_dict() - This does not load agent's hyper-parameters. - - :param load_dict: (dict) dict of parameters from model.state_dict() - :param opt_params: (dict of dicts) dict of optimizer state_dicts should be handled in child_class - """ - self.actor.optimizer.load_state_dict(opt_params["actor"]) - self.critic.optimizer.load_state_dict(opt_params["critic"]) - if "ent_coef_optimizer" in opt_params: - self.ent_coef_optimizer.load_state_dict(opt_params["ent_coef_optimizer"]) - self.policy.load_state_dict(load_dict) + state_dicts.append('ent_coef_optimizer') + else: + saved_tensors.append('ent_coef_tensor') + return state_dicts, saved_tensors diff --git a/torchy_baselines/td3/td3.py b/torchy_baselines/td3/td3.py index 32cc1b2..5729ebf 100644 --- a/torchy_baselines/td3/td3.py +++ b/torchy_baselines/td3/td3.py @@ -1,3 +1,5 @@ +from typing import List, Tuple + import torch as th import torch.nn.functional as F import numpy as np @@ -296,23 +298,19 @@ class TD3(BaseRLModel): return self - def get_opt_parameters(self): + def excluded_save_params(self) -> List[str]: """ - Returns a dict of all the optimizers and their parameters 
+ Returns the names of the parameters that should be excluded by default + when saving the model. - :return: (Dict) of optimizer names and their state_dict + :return: (List[str]) List of parameters that should be excluded from save """ - return {"actor": self.actor.optimizer.state_dict(), "critic": self.critic.optimizer.state_dict()} + # Exclude aliases + return super(TD3, self).excluded_save_params() + ["actor", "critic", "vf_net", "actor_target", "critic_target"] - def load_parameters(self, load_dict, opt_params): + def get_torch_variables(self) -> Tuple[List[str], List[str]]: """ - Load model parameters and optimizer parameters from a dictionary - load_dict should contain all keys from torch.model.state_dict() - This does not load agent's hyper-parameters. - - :param load_dict: (dict) dict of parameters from model.state_dict() - :param opt_params: (dict of dicts) dict of optimizer state_dicts should be handled in child_class + cf base class """ - self.actor.optimizer.load_state_dict(opt_params["actor"]) - self.critic.optimizer.load_state_dict(opt_params["critic"]) - self.policy.load_state_dict(load_dict) + state_dicts = ["policy", "actor.optimizer", "critic.optimizer"] + return state_dicts, []