From ee6f938ddc0da2e3cec56fcea1d9d3dba794fb74 Mon Sep 17 00:00:00 2001 From: Noah Dormann Date: Thu, 28 Nov 2019 15:42:53 +0100 Subject: [PATCH] Added option to explicitly specify excluded parameters --- torchy_baselines/common/base_class.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py index 8c363c9..ecfb31e 100644 --- a/torchy_baselines/common/base_class.py +++ b/torchy_baselines/common/base_class.py @@ -517,17 +517,19 @@ class BaseRLModel(object): """ return ["replay_buffer"] - def save(self, path, include=None): + def save(self, path, exclude=None, include=None): """ saves all the params from init and pytorch params in a file for continuous learning :param path: (str) path to the file where the data should be saved + :param exclude: (list) name of parameters that should be excluded, use standard exclude params if None :param include: (list) name of parameters that might be excluded but should be included anyway :return: """ data = self.__dict__ - # get list of params to be excluded - exclude = self.excluded_save_params() + # use standard list of excluded parameters if none given + if exclude is None: + exclude = self.excluded_save_params() # do not exclude params if they are specifically included if include is not None: exclude = [param_name for param_name in exclude if param_name not in include]