Don't save replay_buffer by default

This commit is contained in:
Noah Dormann 2019-11-21 17:27:46 +01:00
parent cfb822aa91
commit 4f8f936451

View file

@ -270,15 +270,6 @@ class BaseRLModel(object):
raise ValueError("Optimizer Parameters where given but no overloaded load function exists for this class")
self.policy.load_state_dict(load_dict)
@abstractmethod
def save(self, save_path):
"""
Save the current parameters to file
:param save_path: (str or file-like object) the save location
"""
raise NotImplementedError()
@classmethod
def load(cls, load_path, env=None, **kwargs):
"""
@ -512,15 +503,27 @@ class BaseRLModel(object):
with file_.open(file_name + '.pth', mode="w") as opt_param_file:
th.save(dict, opt_param_file)
def save(self, path):
"""
saves all the params from init and pytorch params in a file for continous learning
saves all the params from init and pytorch params in a file for continuous learning
:param path: path to the file where the data should be safed
:param path: path to the file where the data should be saved
:return:
"""
data = self.__dict__
data.pop("replay_buffer")
params_to_save = self.get_policy_parameters()
opt_params_to_save = self.get_opt_parameters()
self._save_to_file_zip(path, data=data, params=params_to_save, opt_params=opt_params_to_save)
def save_with_replay_buffer(self, path):
"""
saves all the params from init and pytorch params in a file for continuous learning
:param path: path to the file where the data should be saved
:return:
"""
data = self.__dict__
params_to_save = self.get_policy_parameters()
opt_params_to_save = self.get_opt_parameters()
self._save_to_file_zip(path, data=data, params=params_to_save,opt_params=opt_params_to_save)
self._save_to_file_zip(path, data=data, params=params_to_save, opt_params=opt_params_to_save)