From 4b1bab7f858befa36aef0231b67e62c8f2f7d221 Mon Sep 17 00:00:00 2001 From: Noah Dormann Date: Thu, 5 Dec 2019 09:11:30 +0100 Subject: [PATCH] implemented set_env method --- torchy_baselines/common/base_class.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py index 09fe798..384ad27 100644 --- a/torchy_baselines/common/base_class.py +++ b/torchy_baselines/common/base_class.py @@ -169,10 +169,20 @@ class BaseRLModel(object): def set_env(self, env): """ Checks the validity of the environment, and if it is coherent, set it as the current environment. + checked parameters: + - observation_space + - action_space :param env: (Gym Environment) The environment for learning a policy """ - pass + + if self.observation_space != env.observation_space: + raise ValueError("The given environment has a observation_space that doesn't fit the current model") + + if self.action_space != env.action_space: + raise ValueError("The given environment has a action_space that doesn't fit the current model") + # if all fits save new env + self.env = env def get_parameter_list(self): """