From 7c8d375bcb3d30bac3861a69680ffe0da986b4f0 Mon Sep 17 00:00:00 2001 From: Noah Dormann Date: Thu, 5 Dec 2019 08:50:11 +0100 Subject: [PATCH] added get_parameter_list function --- torchy_baselines/common/base_class.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py index a195580..680ac17 100644 --- a/torchy_baselines/common/base_class.py +++ b/torchy_baselines/common/base_class.py @@ -174,6 +174,13 @@ class BaseRLModel(object): """ pass + def get_parameter_list(self): + """ + Returns policy and optimizer parameters as a tuple + :return: (dict,dict) policy_parameters, opt_parameters + """ + return self.get_policy_parameters(),self.get_opt_parameters() + def get_policy_parameters(self): """ Get current model policy parameters as dictionary of variable name -> tensors.