Fix saving on GPU - Loading on CPU (#45)

* removed policy from save, changed th.load to map to device

* found hack: catch the pickle exception and try th.load with a device mapping instead, otherwise raise an exception with more information -> loading CUDA tensors on CPU raises an exception -> leads to th.load with map_location being called

* deleted todo

* updated changelog

* start of saving refactor

* first working c

* all tests pass, save refactored

* - backwards compatibility not always
- make pytest all passing
- make typing all passing

* Fixes and simplify the save method

* Remove unused param

* Fix backward compat

* Fix docstring
This commit is contained in:
Dormann, Noah 2020-01-31 13:06:55 +01:00 committed by Raffin, Antonin
parent cc3b023533
commit 1f0dd60b97
7 changed files with 211 additions and 171 deletions

View file

@ -16,6 +16,7 @@ New Features:
Bug Fixes:
^^^^^^^^^^
- Fix loading model on CPU that were trained on GPU
Deprecations:
^^^^^^^^^^^^^

View file

@ -16,13 +16,12 @@ MODEL_LIST = [
SAC,
]
@pytest.mark.parametrize("model_class", MODEL_LIST)
def test_save_load(model_class):
"""
Test if 'save' and 'load' saves and loads model correctly
and if 'load_parameters' and 'get_policy_parameters' work correctly
''warning does not test function of optimizer parameter load
:param model_class: (BaseRLModel) A RL model
@ -38,16 +37,15 @@ def test_save_load(model_class):
observations = np.squeeze(observations)
# Get dictionary of current parameters
params = deepcopy(model.get_policy_parameters())
opt_params = deepcopy(model.get_opt_parameters())
params = deepcopy(model.policy.state_dict())
# Modify all parameters to be random values
random_params = dict((param_name, th.rand_like(param)) for param_name, param in params.items())
# Update model parameters with the new random values
model.load_parameters(random_params, opt_params)
model.policy.load_state_dict(random_params)
new_params = model.get_policy_parameters()
new_params = model.policy.state_dict()
# Check that all params are different now
for k in params:
assert not th.allclose(params[k], new_params[k]), "Parameters did not change as expected."
@ -63,33 +61,14 @@ def test_save_load(model_class):
model = model_class.load("test_save", env=env)
# check if params are still the same after load
new_params = model.get_policy_parameters()
new_params = model.policy.state_dict()
# Check that all params are the same as before save load procedure now
for k in params:
assert th.allclose(params[k], new_params[k]), "Model parameters not the same after save and load."
# check if optimizer params are still the same after load
new_opt_params = model.get_opt_parameters()
# check if keys are the same
assert opt_params.keys() == new_opt_params.keys()
# check if values are the same: only tested for Adam and RMSProp so far
# comparing states not implemented so far. hashes of state_entries are not the same for equal tensors
# comparing every sub_entry does not work because of bool value of Tensor with more than one value is ambiguous
# so far only comparing param_lists
for optimizer, opt_state in opt_params.items():
for param_group_idx, param_group in enumerate(opt_state['param_groups']):
for param_key, param_value in param_group.items():
# don't know how to handle params correctly, therefore only check if we have the same amount
if param_key == 'params':
assert len(param_value) == len(
new_opt_params[optimizer]['param_groups'][param_group_idx][param_key])
else:
assert param_value == new_opt_params[optimizer]['param_groups'][param_group_idx][param_key]
for key in params:
assert th.allclose(params[key], new_params[key]), "Model parameters not the same after save and load."
# check if model still selects the same actions
new_selected_actions = [model.predict(observation, deterministic=True) for observation in observations]
# for i in range(len(selected_actionsselected_actions)):
assert np.allclose(selected_actions, new_selected_actions)
# check if learn still works

View file

@ -17,7 +17,7 @@ from torchy_baselines.common.utils import set_random_seed, get_schedule_fn, upda
from torchy_baselines.common.vec_env import DummyVecEnv, VecEnv, unwrap_vec_normalize, sync_envs_normalization
from torchy_baselines.common.monitor import Monitor
from torchy_baselines.common.evaluation import evaluate_policy
from torchy_baselines.common.save_util import data_to_json, json_to_data
from torchy_baselines.common.save_util import data_to_json, json_to_data, recursive_getattr, recursive_setattr
# TODO: define aliases, ex GymEnv = Union[gym.Env, VecEnv]
if typing.TYPE_CHECKING:
@ -49,11 +49,12 @@ class BaseRLModel(ABC):
:param sde_sample_freq: Sample a new noise matrix every n steps when using SDE
Default: -1 (only sample at the beginning of the rollout)
"""
def __init__(self,
policy: Type[BasePolicy],
env: Union[gym.Env, VecEnv, str],
policy_base: Type[BasePolicy],
policy_kwargs : Dict[str, Any] = None,
policy_kwargs: Dict[str, Any] = None,
verbose: int = 0,
device: Union[th.device, str] = 'auto',
support_multi_env: bool = False,
@ -129,6 +130,14 @@ class BaseRLModel(ABC):
raise ValueError("Error: the model does not support multiple envs requires a single vectorized"
" environment.")
@abstractmethod
def _setup_model(self) -> None:
"""
Setup model so state_dict can be loaded
"""
raise NotImplementedError()
def _get_eval_env(self, eval_env: Union[gym.Env, VecEnv, None]) -> Union[gym.Env, VecEnv, None]:
"""
Return the environment that will be used for evaluation.
@ -252,29 +261,18 @@ class BaseRLModel(ABC):
self.n_envs = env.num_envs
self.env = env
def get_parameters(self) -> Tuple[Dict[str, Any], Dict[str, Any]]:
def get_torch_variables(self) -> Tuple[List[str], List[str]]:
"""
Returns policy and optimizer parameters as a tuple
Get the name of the torch variable that will be saved.
`th.save` and `th.load` will be used with the right device
instead of the default pickling strategy.
:return: policy_parameters, opt_parameters
:return: (Tuple[List[str], List[str]])
name of the variables with state dicts to save, name of additional torch tensors,
"""
return self.get_policy_parameters(), self.get_opt_parameters()
state_dicts = ["policy"]
def get_policy_parameters(self) -> Dict[str, Any]:
"""
Get current model policy parameters as dictionary of variable name -> tensors.
:return: Dictionary of variable name -> tensor of model's policy parameters.
"""
return self.policy.state_dict()
@abstractmethod
def get_opt_parameters(self)-> Dict[str, Any]:
"""
Get current model optimizer parameters as dictionary of variable names -> tensors
:return: (dict) Dictionary of variable name -> tensor of model's optimizer parameters
"""
raise NotImplementedError()
return state_dicts, []
@abstractmethod
def learn(self, total_timesteps: int,
@ -316,21 +314,6 @@ class BaseRLModel(ABC):
"""
raise NotImplementedError()
def load_parameters(self, load_dict: Dict[str, Any], opt_params: Dict[str, Any]) -> None:
"""
Load model parameters from a dictionary
load_dict should contain all keys from torch.model.state_dict()
If opt_params are given this does also load agent's optimizer-parameters,
but can only be handled in child classes.
:param load_dict: dict of parameters from model.state_dict()
:param opt_params: dict of optimizer state_dicts should be handled in child_class
"""
if opt_params is not None:
raise ValueError("Optimizer Parameters where given but no overloaded load function exists for this class")
self.policy.load_state_dict(load_dict)
@classmethod
def load(cls, load_path: str, env: Union[gym.Env, VecEnv, None] = None, **kwargs):
"""
@ -341,7 +324,7 @@ class BaseRLModel(ABC):
(can be None if you only need prediction from a trained model) has priority over any saved environment
:param kwargs: extra arguments to change the model when loading
"""
data, params, opt_params = cls._load_from_file(load_path)
data, params, tensors = cls._load_from_file(load_path)
if 'policy_kwargs' in kwargs and kwargs['policy_kwargs'] != data['policy_kwargs']:
raise ValueError(f"The specified policy kwargs do not equal the stored policy kwargs."
@ -359,12 +342,25 @@ class BaseRLModel(ABC):
# first create model, but only setup if a env was given
# noinspection PyArgumentList
model = cls(policy=data["policy_class"], env=env, _init_setup_model=env is not None)
model = cls(policy=data["policy_class"], env=env, device='auto', _init_setup_model=env is not None)
# load parameters
model.__dict__.update(data)
model.__dict__.update(kwargs)
model.load_parameters(params, opt_params)
if not hasattr(model, "_setup_model") and len(params) > 0:
raise NotImplementedError("loading was executed on a model that has no means to create the policies")
model._setup_model()
# put state_dicts back in place
for name in params:
attr = recursive_getattr(model, name)
attr.load_state_dict(params[name])
# put tensors back in place
if tensors is not None:
for name in tensors:
recursive_setattr(model, name, tensors[name])
return model
@staticmethod
@ -374,8 +370,8 @@ class BaseRLModel(ABC):
:param load_path: Where to load the model from
:param load_data: Whether we should load and return data
(class parameters). Mainly used by 'load_parameters' to only load model parameters (weights)
:return: (dict),(dict),(dict) Class parameters, model parameters (state_dict)
and dict of optimizer parameters (dict of state_dict)
:return: (dict),(dict),(dict) Class parameters, model state_dicts (dict of state_dict)
and dict of extra tensors
"""
# Check if file exists if load_path is a string
if isinstance(load_path, str):
@ -385,6 +381,12 @@ class BaseRLModel(ABC):
else:
raise ValueError(f"Error: the file {load_path} could not be found")
# set device to cpu if cuda is not available
if th.cuda.is_available():
device = th.device('cuda')
else:
device = th.device('cpu')
# Open the zip archive and load data
try:
with zipfile.ZipFile(load_path, "r") as archive:
@ -393,31 +395,32 @@ class BaseRLModel(ABC):
# zip archive, assume they were stored
# as None (_save_to_file_zip allows this).
data = None
params = None
opt_params = None
tensors = None
params = {}
if "data" in namelist and load_data:
# Load class parameters and convert to string
json_data = archive.read("data").decode()
data = json_to_data(json_data)
data = json_to_data(json_data, device)
if "params.pth" in namelist:
# Load parameters with build in torch function
with archive.open("params.pth", mode="r") as param_file:
# File has to be seekable, but param_file is not, so load in BytesIO first
if "tensors.pth" in namelist and load_data:
# Load extra tensors
with archive.open('tensors.pth', mode="r") as tensor_file:
# File has to be seekable, but opt_param_file is not, so load in BytesIO first
# fixed in python >= 3.7
file_content = io.BytesIO()
file_content.write(param_file.read())
file_content.write(tensor_file.read())
# go to start of file
file_content.seek(0)
params = th.load(file_content)
# load the parameters with the right `map_location`
tensors = th.load(file_content, map_location=device)
# check for all other .pth files
other_files = [file_name for file_name in namelist if
os.path.splitext(file_name)[1] == ".pth" and file_name != "params.pth"]
os.path.splitext(file_name)[1] == ".pth" and file_name != "tensors.pth"]
# if there are any other files which end with .pth and aren't "params.pth"
# assume that they each are optimizer parameters
if len(other_files) > 0:
opt_params = dict()
for file_path in other_files:
with archive.open(file_path, mode="r") as opt_param_file:
# File has to be seekable, but opt_param_file is not, so load in BytesIO first
@ -426,13 +429,27 @@ class BaseRLModel(ABC):
file_content.write(opt_param_file.read())
# go to start of file
file_content.seek(0)
# save the parameters in dict with file name but trim file ending
opt_params[os.path.splitext(file_path)[0]] = th.load(file_content)
# load the parameters with the right `map_location`
params[os.path.splitext(file_path)[0]] = th.load(file_content, map_location=device)
# for backward compatibility
if params.get('params') is not None:
params_copy = {}
for name in params:
if name == 'params':
params_copy['policy'] = params[name]
elif name == 'opt':
params_copy['policy.optimizer'] = params[name]
# Special case for SAC
elif name == 'ent_coef_optimizer':
params_copy[name] = params[name]
else:
params_copy[name + '.optimizer'] = params[name]
params = params_copy
except zipfile.BadZipFile:
# load_path wasn't a zip file
raise ValueError(f"Error: the file {load_path} wasn't a zip-file")
return data, params, opt_params
return data, params, tensors
def set_random_seed(self, seed: Optional[int] = None) -> None:
"""
@ -654,15 +671,15 @@ class BaseRLModel(ABC):
@staticmethod
def _save_to_file_zip(save_path: str, data: Dict[str, Any] = None,
params: Dict[str, Any] = None, opt_params: Dict[str, Any] = None) -> None:
params: Dict[str, Any] = None, tensors: Dict[str, Any] = None) -> None:
"""
Save model to a zip archive.
:param save_path: Where to store the model
:param data: Class parameters being stored
:param params: Model parameters being stored expected to be state_dict
:param opt_params: Optimizer parameters being stored expected to contain an entry for every
optimizer with its name and the state_dict
:param params: Model parameters being stored expected to contain an entry for every
state_dict with its name and the state_dict
:param tensors: Extra tensor variables expected to contain name and value of tensors
"""
# data/params can be None, so do not
@ -683,23 +700,22 @@ class BaseRLModel(ABC):
# Do not try to save "None" elements
if data is not None:
archive.writestr("data", serialized_data)
if tensors is not None:
with archive.open('tensors.pth', mode="w") as tensors_file:
th.save(tensors, tensors_file)
if params is not None:
with archive.open('params.pth', mode="w") as param_file:
th.save(params, param_file)
if opt_params is not None:
for file_name, dict_ in opt_params.items():
with archive.open(file_name + '.pth', mode="w") as opt_param_file:
th.save(dict_, opt_param_file)
for file_name, dict_ in params.items():
with archive.open(file_name + '.pth', mode="w") as param_file:
th.save(dict_, param_file)
@staticmethod
def excluded_save_params() -> List[str]:
def excluded_save_params(self) -> List[str]:
"""
Returns the names of the parameters that should be excluded by default
when saving the model.
:return: ([str]) List of parameters that should be excluded from save
"""
return ["env", "eval_env", "replay_buffer", "rollout_buffer", "_vec_normalize_env"]
return ["policy", "device", "env", "eval_env", "replay_buffer", "rollout_buffer", "_vec_normalize_env"]
def save(self, path: str, exclude: Optional[List[str]] = None, include: Optional[List[str]] = None) -> None:
"""
@ -717,18 +733,40 @@ class BaseRLModel(ABC):
else:
# append standard exclude params to the given params
exclude.extend([param for param in self.excluded_save_params() if param not in exclude])
# do not exclude params if they are specifically included
if include is not None:
exclude = [param_name for param_name in exclude if param_name not in include]
# remove parameter entries of parameters which are to be excluded
state_dicts_names, tensors_names = self.get_torch_variables()
# any params that are in the save vars must not be saved by data
torch_variables = state_dicts_names + tensors_names
for torch_var in torch_variables:
# we need to get only the name of the top most module as we'll remove that
var_name = torch_var.split('.')[0]
exclude.append(var_name)
# Remove parameter entries of parameters which are to be excluded
for param_name in exclude:
if param_name in data:
data.pop(param_name, None)
params_to_save = self.get_policy_parameters()
opt_params_to_save = self.get_opt_parameters()
self._save_to_file_zip(path, data=data, params=params_to_save, opt_params=opt_params_to_save)
# Build dict of tensor variables
tensors = None
if tensors_names is not None:
tensors = {}
for name in tensors_names:
attr = recursive_getattr(self, name)
tensors[name] = attr
# Build dict of state_dicts
params_to_save = {}
for name in state_dicts_names:
attr = recursive_getattr(self, name)
# Retrieve state dict
params_to_save[name] = attr.state_dict()
self._save_to_file_zip(path, data=data, params=params_to_save, tensors=tensors)
def _eval_policy(self, eval_freq: int, eval_env: int, n_eval_episodes: int,
timesteps_since_eval: int, render: bool = False, deterministic: bool = True) -> int:

View file

@ -2,15 +2,51 @@
Save util taken from stable_baselines
used to serialize data (class parameters) of model classes
"""
import json
import base64
import pickle
import functools
from typing import Dict, Any, Optional, Union
import torch as th
import cloudpickle
import warnings
def is_json_serializable(item):
def recursive_getattr(obj: Any, attr: str, *args) -> Any:
    """
    Recursive version of getattr
    taken from https://stackoverflow.com/questions/31174295

    Ex:
    > MyObject.sub_object = SubObject(name='test')
    > recursive_getattr(MyObject, 'sub_object.name') # return test

    :param obj: (Any) Object whose dotted attribute path is traversed
    :param attr: (str) Attribute to retrieve (may contain dots)
    :return: (Any) The attribute
    """
    # Walk the dotted path one attribute at a time; an optional default
    # supplied via *args is forwarded to getattr at every step.
    value = obj
    for name in attr.split('.'):
        value = getattr(value, name, *args)
    return value
def recursive_setattr(obj: Any, attr: str, val: Any) -> None:
    """
    Recursive version of setattr
    taken from https://stackoverflow.com/questions/31174295

    Ex:
    > MyObject.sub_object = SubObject(name='test')
    > recursive_setattr(MyObject, 'sub_object.name', 'hello')

    :param obj: (Any) Object whose dotted attribute path is set
    :param attr: (str) Attribute to set (may contain dots)
    :param val: (Any) New value of the attribute
    """
    parent_path, _, leaf = attr.rpartition('.')
    # Resolve the parent object by walking the dotted prefix (inlined
    # traversal; a missing intermediate attribute raises AttributeError).
    target = obj
    if parent_path:
        for name in parent_path.split('.'):
            target = getattr(target, name)
    return setattr(target, leaf, val)
def is_json_serializable(item: Any) -> bool:
"""
Test if an object is serializable into JSON
@ -26,11 +62,11 @@ def is_json_serializable(item):
return json_serializable
def data_to_json(data):
def data_to_json(data: Dict[str, Any]) -> str:
"""
Turn data (class parameters) into a JSON string for storing
:param data: (Dict) Dictionary of class parameters to be
:param data: (Dict[str, Any]) Dictionary of class parameters to be
stored. Items that are not JSON serializable will be
pickled with Cloudpickle and stored as bytearray in
the JSON file
@ -85,12 +121,15 @@ def data_to_json(data):
return json_string
def json_to_data(json_string, custom_objects=None):
def json_to_data(json_string: str,
device: Union[th.device, str] = 'cpu',
custom_objects: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
"""
Turn JSON serialization of class-parameters back into dictionary.
:param json_string: (str) JSON serialization of the class-parameters
that should be loaded.
:param device: torch.device device to which the data should be mapped if errors occur
:param custom_objects: (dict) Dictionary of objects to replace
upon loading. If a variable is present in this dictionary as a
key, it will not be deserialized and the corresponding item
@ -118,15 +157,12 @@ def json_to_data(json_string, custom_objects=None):
# errors. If so, we can tell bit more information to
# user.
try:
deserialized_object = cloudpickle.loads(
base64.b64decode(serialization.encode())
)
except pickle.UnpicklingError:
raise RuntimeError(
f"Could not deserialize object {data_key}. " +
"Consider using `custom_objects` argument to replace " +
"this object."
)
base64_object = base64.b64decode(serialization.encode())
deserialized_object = cloudpickle.loads(base64_object)
except RuntimeError:
warnings.warn(f"Could not deserialize object {data_key}. " +
"Consider using `custom_objects` argument to replace " +
"this object.")
return_data[data_key] = deserialized_object
else:
# Read as it is

View file

@ -1,5 +1,6 @@
import os
import time
from typing import List, Tuple
import gym
from gym import spaces
@ -69,6 +70,7 @@ class PPO(BaseRLModel):
Setting it to auto, the code will be run on the GPU if possible.
:param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance
"""
def __init__(self, policy, env, learning_rate=3e-4,
n_steps=2048, batch_size=64, n_epochs=10,
gamma=0.99, gae_lambda=0.95, clip_range=0.2, clip_range_vf=None,
@ -310,22 +312,10 @@ class PPO(BaseRLModel):
return self
def get_opt_parameters(self):
def get_torch_variables(self) -> Tuple[List[str], List[str]]:
"""
Returns a dict of all the optimizers and their parameters
cf base class
"""
state_dicts = ["policy", "policy.optimizer"]
:return: (dict) of optimizer names and their state_dict
"""
return {"opt": self.policy.optimizer.state_dict()}
def load_parameters(self, load_dict, opt_params):
"""
Load model parameters and optimizer parameters from a dictionary
load_dict should contain all keys from torch.model.state_dict()
This does not load agent's hyper-parameters.
:param load_dict: (dict) dict of parameters from model.state_dict()
:param opt_params: (dict of dicts) dict of optimizer state_dicts
"""
self.policy.optimizer.load_state_dict(opt_params["opt"])
self.policy.load_state_dict(load_dict)
return state_dicts, []

View file

@ -1,3 +1,5 @@
from typing import List, Tuple
import torch as th
import torch.nn.functional as F
import numpy as np
@ -128,7 +130,7 @@ class SAC(BaseRLModel):
# Force conversion to float
# this will throw an error if a malformed string (different from 'auto')
# is passed
self.ent_coef = th.tensor(float(self.ent_coef)).to(self.device)
self.ent_coef_tensor = th.tensor(float(self.ent_coef)).to(self.device)
self.replay_buffer = ReplayBuffer(self.buffer_size, obs_dim, action_dim, self.device)
self.policy = self.policy_class(self.observation_space, self.action_space,
@ -199,7 +201,7 @@ class SAC(BaseRLModel):
ent_coef = th.exp(self.log_ent_coef.detach())
ent_coef_loss = -(self.log_ent_coef * (log_prob + self.target_entropy).detach()).mean()
else:
ent_coef = self.ent_coef
ent_coef = self.ent_coef_tensor
# Optimize entropy coefficient, also called
# entropy temperature or alpha in the paper
@ -294,28 +296,24 @@ class SAC(BaseRLModel):
return self
def get_opt_parameters(self):
def excluded_save_params(self) -> List[str]:
"""
Returns a dict of all the optimizers and their parameters
Returns the names of the parameters that should be excluded by default
when saving the model.
:return: (Dict) of optimizer names and their state_dict
:return: (List[str]) List of parameters that should be excluded from save
"""
opt_dict = {"actor": self.actor.optimizer.state_dict(), "critic": self.critic.optimizer.state_dict()}
# Exclude aliases
return super(SAC, self).excluded_save_params() + ["actor", "critic", "critic_target"]
def get_torch_variables(self) -> Tuple[List[str], List[str]]:
"""
cf base class
"""
state_dicts = ["policy", "actor.optimizer", "critic.optimizer"]
saved_tensors = ['log_ent_coef']
if self.ent_coef_optimizer is not None:
opt_dict.update({"ent_coef_optimizer": self.ent_coef_optimizer.state_dict()})
return opt_dict
def load_parameters(self, load_dict, opt_params):
"""
Load model parameters and optimizer parameters from a dictionary
load_dict should contain all keys from torch.model.state_dict()
This does not load agent's hyper-parameters.
:param load_dict: (dict) dict of parameters from model.state_dict()
:param opt_params: (dict of dicts) dict of optimizer state_dicts should be handled in child_class
"""
self.actor.optimizer.load_state_dict(opt_params["actor"])
self.critic.optimizer.load_state_dict(opt_params["critic"])
if "ent_coef_optimizer" in opt_params:
self.ent_coef_optimizer.load_state_dict(opt_params["ent_coef_optimizer"])
self.policy.load_state_dict(load_dict)
state_dicts.append('ent_coef_optimizer')
else:
saved_tensors.append('ent_coef_tensor')
return state_dicts, saved_tensors

View file

@ -1,3 +1,5 @@
from typing import List, Tuple
import torch as th
import torch.nn.functional as F
import numpy as np
@ -296,23 +298,19 @@ class TD3(BaseRLModel):
return self
def get_opt_parameters(self):
def excluded_save_params(self) -> List[str]:
"""
Returns a dict of all the optimizers and their parameters
Returns the names of the parameters that should be excluded by default
when saving the model.
:return: (Dict) of optimizer names and their state_dict
:return: (List[str]) List of parameters that should be excluded from save
"""
return {"actor": self.actor.optimizer.state_dict(), "critic": self.critic.optimizer.state_dict()}
# Exclude aliases
return super(TD3, self).excluded_save_params() + ["actor", "critic", "vf_net", "actor_target", "critic_target"]
def load_parameters(self, load_dict, opt_params):
def get_torch_variables(self) -> Tuple[List[str], List[str]]:
"""
Load model parameters and optimizer parameters from a dictionary
load_dict should contain all keys from torch.model.state_dict()
This does not load agent's hyper-parameters.
:param load_dict: (dict) dict of parameters from model.state_dict()
:param opt_params: (dict of dicts) dict of optimizer state_dicts should be handled in child_class
cf base class
"""
self.actor.optimizer.load_state_dict(opt_params["actor"])
self.critic.optimizer.load_state_dict(opt_params["critic"])
self.policy.load_state_dict(load_dict)
state_dicts = ["policy", "actor.optimizer", "critic.optimizer"]
return state_dicts, []