mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-26 22:45:15 +00:00
Don't save replay_buffer by default
This commit is contained in:
parent
cfb822aa91
commit
4f8f936451
1 changed files with 16 additions and 13 deletions
|
|
@ -270,15 +270,6 @@ class BaseRLModel(object):
|
|||
raise ValueError("Optimizer Parameters where given but no overloaded load function exists for this class")
|
||||
self.policy.load_state_dict(load_dict)
|
||||
|
||||
@abstractmethod
|
||||
def save(self, save_path):
|
||||
"""
|
||||
Save the current parameters to file
|
||||
|
||||
:param save_path: (str or file-like object) the save location
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@classmethod
|
||||
def load(cls, load_path, env=None, **kwargs):
|
||||
"""
|
||||
|
|
@ -512,15 +503,27 @@ class BaseRLModel(object):
|
|||
with file_.open(file_name + '.pth', mode="w") as opt_param_file:
|
||||
th.save(dict, opt_param_file)
|
||||
|
||||
|
||||
def save(self, path):
|
||||
"""
|
||||
saves all the params from init and pytorch params in a file for continous learning
|
||||
saves all the params from init and pytorch params in a file for continuous learning
|
||||
|
||||
:param path: path to the file where the data should be safed
|
||||
:param path: path to the file where the data should be saved
|
||||
:return:
|
||||
"""
|
||||
data = self.__dict__
|
||||
data.pop("replay_buffer")
|
||||
params_to_save = self.get_policy_parameters()
|
||||
opt_params_to_save = self.get_opt_parameters()
|
||||
self._save_to_file_zip(path, data=data, params=params_to_save, opt_params=opt_params_to_save)
|
||||
|
||||
def save_with_replay_buffer(self, path):
|
||||
"""
|
||||
saves all the params from init and pytorch params in a file for continuous learning
|
||||
|
||||
:param path: path to the file where the data should be saved
|
||||
:return:
|
||||
"""
|
||||
data = self.__dict__
|
||||
params_to_save = self.get_policy_parameters()
|
||||
opt_params_to_save = self.get_opt_parameters()
|
||||
self._save_to_file_zip(path, data=data, params=params_to_save,opt_params=opt_params_to_save)
|
||||
self._save_to_file_zip(path, data=data, params=params_to_save, opt_params=opt_params_to_save)
|
||||
|
|
|
|||
Loading…
Reference in a new issue