From 4f8f9364516d2fffc7616fe42434e0e2d0c1402b Mon Sep 17 00:00:00 2001 From: Noah Dormann Date: Thu, 21 Nov 2019 17:27:46 +0100 Subject: [PATCH] Don't save replay_buffer by default --- torchy_baselines/common/base_class.py | 29 +++++++++++++++------------ 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py index b2c6056..57c27b0 100644 --- a/torchy_baselines/common/base_class.py +++ b/torchy_baselines/common/base_class.py @@ -270,15 +270,6 @@ class BaseRLModel(object): raise ValueError("Optimizer Parameters where given but no overloaded load function exists for this class") self.policy.load_state_dict(load_dict) - @abstractmethod - def save(self, save_path): - """ - Save the current parameters to file - - :param save_path: (str or file-like object) the save location - """ - raise NotImplementedError() - @classmethod def load(cls, load_path, env=None, **kwargs): """ @@ -512,15 +503,27 @@ class BaseRLModel(object): with file_.open(file_name + '.pth', mode="w") as opt_param_file: th.save(dict, opt_param_file) - def save(self, path): """ - saves all the params from init and pytorch params in a file for continous learning + saves all the params from init and pytorch params in a file for continuous learning - :param path: path to the file where the data should be safed + :param path: path to the file where the data should be saved + :return: + """ + data = self.__dict__ + data.pop("replay_buffer") + params_to_save = self.get_policy_parameters() + opt_params_to_save = self.get_opt_parameters() + self._save_to_file_zip(path, data=data, params=params_to_save, opt_params=opt_params_to_save) + + def save_with_replay_buffer(self, path): + """ + saves all the params from init and pytorch params in a file for continuous learning + + :param path: path to the file where the data should be saved :return: """ data = self.__dict__ params_to_save = self.get_policy_parameters() opt_params_to_save = self.get_opt_parameters() - self._save_to_file_zip(path, data=data, params=params_to_save,opt_params=opt_params_to_save) + self._save_to_file_zip(path, data=data, params=params_to_save, opt_params=opt_params_to_save)