From 88d4f44d554f31b4d668a53b1ecc47ada4a0b1be Mon Sep 17 00:00:00 2001 From: Noah Dormann Date: Thu, 5 Dec 2019 13:59:07 +0100 Subject: [PATCH] added set_env test and set_env wrapping --- tests/test_save_load.py | 25 +++++++++++++++++++++++-- torchy_baselines/common/base_class.py | 17 +++++++++++------ 2 files changed, 34 insertions(+), 8 deletions(-) diff --git a/tests/test_save_load.py b/tests/test_save_load.py index eef023d..3a2cb42 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -7,7 +7,7 @@ import torch as th from torchy_baselines import A2C, CEMRL, PPO, SAC, TD3 from torchy_baselines.common.vec_env import DummyVecEnv -from torchy_baselines.common.identity_env import IdentityEnvBox +from torchy_baselines.common.identity_env import IdentityEnvBox, IdentityEnv MODEL_LIST = [ CEMRL, @@ -101,7 +101,28 @@ def test_save_load(model_class): @pytest.mark.parametrize("model_class", MODEL_LIST) def test_set_env(model_class): - pass + """ + Test that the set_env function works correctly + :param model_class: (BaseRLModel) An RL model + """ + env = DummyVecEnv([lambda: IdentityEnvBox(10)]) + env2 = DummyVecEnv([lambda: IdentityEnvBox(10)]) + env3 = IdentityEnvBox(10) + + # create model + model = model_class('MlpPolicy', env, policy_kwargs=dict(net_arch=[16]), create_eval_env=True) + # learn + model.learn(total_timesteps=1000, eval_freq=500) + + # change env + model.set_env(env2) + # learn again + model.learn(total_timesteps=1000, eval_freq=500) + + # change env to test wrapping + model.set_env(env3) + # learn again + model.learn(total_timesteps=1000, eval_freq=500) @pytest.mark.parametrize("model_class", MODEL_LIST) diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py index bb98a27..4a08917 100644 --- a/torchy_baselines/common/base_class.py +++ b/torchy_baselines/common/base_class.py @@ -185,6 +185,7 @@ class BaseRLModel(object): def set_env(self, env): """ Checks the validity of the environment, and if it is coherent, 
set it as the current environment. + Furthermore, wrap any non-vectorized env into a vectorized one (DummyVecEnv). checked parameters: - observation_space - action_space @@ -193,13 +194,16 @@ class BaseRLModel(object): """ if self.check_env(env, self.observation_space, self.action_space) is False: raise ValueError("Given environment is not compatible with model") - # if all fits save new env + # it must be coherent now + # if it is not a VecEnv, make it a VecEnv + if not isinstance(env, VecEnv): + if self.verbose >= 1: + print("Wrapping the env in a DummyVecEnv.") + env = DummyVecEnv([lambda: env]) + self.n_envs = env.num_envs self.env = env - # and update observation and action space - self.observation_space = env.observation_space - self.action_space = env.action_space - def get_parameter_list(self): + def get_parameters(self): """ Returns policy and optimizer parameters as a tuple :return: (dict,dict) policy_parameters, opt_parameters @@ -540,7 +544,8 @@ class BaseRLModel(object): with archive.open(file_name + '.pth', mode="w") as opt_param_file: th.save(dict, opt_param_file) - def excluded_save_params(self): + @staticmethod + def excluded_save_params(): """ returns the names of the parameters that should be excluded from save :return: ([str]) List of parameters that should be excluded from save