added set_env test and set_env wrapping

This commit is contained in:
Noah Dormann 2019-12-05 13:59:07 +01:00
parent cf1d7118a5
commit 88d4f44d55
2 changed files with 34 additions and 8 deletions

View file

@ -7,7 +7,7 @@ import torch as th
from torchy_baselines import A2C, CEMRL, PPO, SAC, TD3
from torchy_baselines.common.vec_env import DummyVecEnv
from torchy_baselines.common.identity_env import IdentityEnvBox
from torchy_baselines.common.identity_env import IdentityEnvBox, IdentityEnv
MODEL_LIST = [
CEMRL,
@ -101,7 +101,28 @@ def test_save_load(model_class):
@pytest.mark.parametrize("model_class", MODEL_LIST)
def test_set_env(model_class):
pass
"""
Test that the set_env function works correctly
:param model_class: (BaseRLModel) A RL model
"""
env = DummyVecEnv([lambda: IdentityEnvBox(10)])
env2 = DummyVecEnv([lambda: IdentityEnvBox(10)])
env3 = IdentityEnvBox(10)
# create model
model = model_class('MlpPolicy', env, policy_kwargs=dict(net_arch=[16]), create_eval_env=True)
# learn
model.learn(total_timesteps=1000, eval_freq=500)
# change env
model.set_env(env2)
# learn again
model.learn(total_timesteps=1000, eval_freq=500)
# change env test wrapping
model.set_env(env3)
# learn again
model.learn(total_timesteps=1000, eval_freq=500)
@pytest.mark.parametrize("model_class", MODEL_LIST)

View file

@ -185,6 +185,7 @@ class BaseRLModel(object):
def set_env(self, env):
"""
Checks the validity of the environment, and if it is coherent, set it as the current environment.
Furthermore, wrap any non-vectorized env into a vectorized one (DummyVecEnv).
checked parameters:
- observation_space
- action_space
@ -193,13 +194,16 @@ class BaseRLModel(object):
"""
if self.check_env(env, self.observation_space, self.action_space) is False:
raise ValueError("Given environment is not compatible with model")
# if all fits save new env
# it must be coherent now
# if it is not a VecEnv, make it a VecEnv
if not isinstance(env, VecEnv):
if self.verbose >= 1:
print("Wrapping the env in a DummyVecEnv.")
env = DummyVecEnv([lambda: env])
self.n_envs = env.num_envs
self.env = env
# and update observation and action space
self.observation_space = env.observation_space
self.action_space = env.action_space
def get_parameter_list(self):
def get_parameters(self):
"""
Returns policy and optimizer parameters as a tuple
:return: (dict,dict) policy_parameters, opt_parameters
@ -540,7 +544,8 @@ class BaseRLModel(object):
with archive.open(file_name + '.pth', mode="w") as opt_param_file:
th.save(dict, opt_param_file)
def excluded_save_params(self):
@staticmethod
def excluded_save_params():
"""
returns the names of the parameters that should be excluded from save
:return: ([str]) List of parameters that should be excluded from save