diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py index a195580..680ac17 100644 --- a/torchy_baselines/common/base_class.py +++ b/torchy_baselines/common/base_class.py @@ -174,6 +174,13 @@ class BaseRLModel(object): """ pass + def get_parameter_list(self): + """ + Returns policy and optimizer parameters as a tuple + :return: (dict,dict) policy_parameters, opt_parameters + """ + return self.get_policy_parameters(),self.get_opt_parameters() + def get_policy_parameters(self): """ Get current model policy parameters as dictionary of variable name -> tensors.