diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py index b2c6056..57c27b0 100644 --- a/torchy_baselines/common/base_class.py +++ b/torchy_baselines/common/base_class.py @@ -270,15 +270,6 @@ class BaseRLModel(object): raise ValueError("Optimizer Parameters where given but no overloaded load function exists for this class") self.policy.load_state_dict(load_dict) - @abstractmethod - def save(self, save_path): - """ - Save the current parameters to file - - :param save_path: (str or file-like object) the save location - """ - raise NotImplementedError() - @classmethod def load(cls, load_path, env=None, **kwargs): """ @@ -512,15 +503,27 @@ class BaseRLModel(object): with file_.open(file_name + '.pth', mode="w") as opt_param_file: th.save(dict, opt_param_file) - def save(self, path): """ - saves all the params from init and pytorch params in a file for continous learning + saves all the params from init and pytorch params in a file for continuous learning - :param path: path to the file where the data should be safed + :param path: path to the file where the data should be saved + :return: + """ + data = self.__dict__ + data.pop("replay_buffer") + params_to_save = self.get_policy_parameters() + opt_params_to_save = self.get_opt_parameters() + self._save_to_file_zip(path, data=data, params=params_to_save, opt_params=opt_params_to_save) + + def save_with_replay_buffer(self, path): + """ + saves all the params from init and pytorch params in a file for continuous learning + + :param path: path to the file where the data should be saved :return: """ data = self.__dict__ params_to_save = self.get_policy_parameters() opt_params_to_save = self.get_opt_parameters() - self._save_to_file_zip(path, data=data, params=params_to_save,opt_params=opt_params_to_save) + self._save_to_file_zip(path, data=data, params=params_to_save, opt_params=opt_params_to_save)