diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py index 09fe798..384ad27 100644 --- a/torchy_baselines/common/base_class.py +++ b/torchy_baselines/common/base_class.py @@ -169,10 +169,20 @@ class BaseRLModel(object): def set_env(self, env): """ Checks the validity of the environment, and if it is coherent, set it as the current environment. + checked parameters: + - observation_space + - action_space :param env: (Gym Environment) The environment for learning a policy """ - pass + + if self.observation_space != env.observation_space: + raise ValueError("The given environment has a observation_space that doesn't fit the current model") + + if self.action_space != env.action_space: + raise ValueError("The given environment has a action_space that doesn't fit the current model") + # if all fits save new env + self.env = env def get_parameter_list(self): """