mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-28 22:56:53 +00:00
Added option to explicitly specify excluded parameters
This commit is contained in:
parent
e95858784a
commit
ee6f938ddc
1 changed files with 5 additions and 3 deletions
|
|
@ -517,17 +517,19 @@ class BaseRLModel(object):
|
|||
"""
|
||||
return ["replay_buffer"]
|
||||
|
||||
def save(self, path, include=None):
|
||||
def save(self, path, exclude=None, include=None):
|
||||
"""
|
||||
saves all the params from init and pytorch params in a file for continuous learning
|
||||
|
||||
:param path: (str) path to the file where the data should be saved
|
||||
:param exclude: (list) name of parameters that should be excluded, use standard exclude params if None
|
||||
:param include: (list) name of parameters that might be excluded but should be included anyway
|
||||
:return:
|
||||
"""
|
||||
data = self.__dict__
|
||||
# get list of params to be excluded
|
||||
exclude = self.excluded_save_params()
|
||||
# use standard list of excluded parameters if none given
|
||||
if exclude is None:
|
||||
exclude = self.excluded_save_params()
|
||||
# do not exclude params if they are specifically included
|
||||
if include is not None:
|
||||
exclude = [param_name for param_name in exclude if param_name not in include]
|
||||
|
|
|
|||
Loading…
Reference in a new issue