Added option to explicitly specify excluded parameters

This commit is contained in:
Noah Dormann 2019-11-28 15:42:53 +01:00
parent e95858784a
commit ee6f938ddc

View file

@ -517,17 +517,19 @@ class BaseRLModel(object):
"""
return ["replay_buffer"]
def save(self, path, include=None):
def save(self, path, exclude=None, include=None):
"""
saves all the params from init and pytorch params in a file for continuous learning
:param path: (str) path to the file where the data should be saved
:param exclude: (list) name of parameters that should be excluded, use standard exclude params if None
:param include: (list) name of parameters that might be excluded but should be included anyway
:return:
"""
data = self.__dict__
# get list of params to be excluded
exclude = self.excluded_save_params()
# use standard list of excluded parameters if none given
if exclude is None:
exclude = self.excluded_save_params()
# do not exclude params if they are specifically included
if include is not None:
exclude = [param_name for param_name in exclude if param_name not in include]