mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-06-08 00:23:22 +00:00
added set_env test and set_env wrapping
This commit is contained in:
parent
cf1d7118a5
commit
88d4f44d55
2 changed files with 34 additions and 8 deletions
|
|
@ -7,7 +7,7 @@ import torch as th
|
|||
|
||||
from torchy_baselines import A2C, CEMRL, PPO, SAC, TD3
|
||||
from torchy_baselines.common.vec_env import DummyVecEnv
|
||||
from torchy_baselines.common.identity_env import IdentityEnvBox
|
||||
from torchy_baselines.common.identity_env import IdentityEnvBox, IdentityEnv
|
||||
|
||||
MODEL_LIST = [
|
||||
CEMRL,
|
||||
|
|
@ -101,7 +101,28 @@ def test_save_load(model_class):
|
|||
|
||||
@pytest.mark.parametrize("model_class", MODEL_LIST)
|
||||
def test_set_env(model_class):
|
||||
pass
|
||||
"""
|
||||
Test if set_env function does work correct
|
||||
:param model_class: (BaseRLModel) A RL model
|
||||
"""
|
||||
env = DummyVecEnv([lambda: IdentityEnvBox(10)])
|
||||
env2 = DummyVecEnv([lambda: IdentityEnvBox(10)])
|
||||
env3 = IdentityEnvBox(10)
|
||||
|
||||
# create model
|
||||
model = model_class('MlpPolicy', env, policy_kwargs=dict(net_arch=[16]), create_eval_env=True)
|
||||
# learn
|
||||
model.learn(total_timesteps=1000, eval_freq=500)
|
||||
|
||||
# change env
|
||||
model.set_env(env2)
|
||||
# learn again
|
||||
model.learn(total_timesteps=1000, eval_freq=500)
|
||||
|
||||
# change env test wrapping
|
||||
model.set_env(env3)
|
||||
# learn again
|
||||
model.learn(total_timesteps=1000, eval_freq=500)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", MODEL_LIST)
|
||||
|
|
|
|||
|
|
@ -185,6 +185,7 @@ class BaseRLModel(object):
|
|||
def set_env(self, env):
|
||||
"""
|
||||
Checks the validity of the environment, and if it is coherent, set it as the current environment.
|
||||
Furthermore wrap any non vectorized env into a vectorized
|
||||
checked parameters:
|
||||
- observation_space
|
||||
- action_space
|
||||
|
|
@ -193,13 +194,16 @@ class BaseRLModel(object):
|
|||
"""
|
||||
if self.check_env(env, self.observation_space, self.action_space) is False:
|
||||
raise ValueError("Given environment is not compatible with model")
|
||||
# if all fits save new env
|
||||
# it must be coherent now
|
||||
# if it is not a VecEnv, make it a VecEnv
|
||||
if not isinstance(env, VecEnv):
|
||||
if self.verbose >= 1:
|
||||
print("Wrapping the env in a DummyVecEnv.")
|
||||
env = DummyVecEnv([lambda: env])
|
||||
self.n_envs = env.num_envs
|
||||
self.env = env
|
||||
# and update observation and action space
|
||||
self.observation_space = env.observation_space
|
||||
self.action_space = env.action_space
|
||||
|
||||
def get_parameter_list(self):
|
||||
def get_parameters(self):
|
||||
"""
|
||||
Returns policy and optimizer parameters as a tuple
|
||||
:return: (dict,dict) policy_parameters, opt_parameters
|
||||
|
|
@ -540,7 +544,8 @@ class BaseRLModel(object):
|
|||
with archive.open(file_name + '.pth', mode="w") as opt_param_file:
|
||||
th.save(dict, opt_param_file)
|
||||
|
||||
def excluded_save_params(self):
|
||||
@staticmethod
|
||||
def excluded_save_params():
|
||||
"""
|
||||
returns the names of the parameters that should be excluded from save
|
||||
:return: ([str]) List of parameters that should be excluded from save
|
||||
|
|
|
|||
Loading…
Reference in a new issue