mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-06-03 23:49:57 +00:00
Add Test for exclude/include feature of save
This commit is contained in:
parent
ee6f938ddc
commit
c82025e673
1 changed files with 34 additions and 0 deletions
|
|
@ -97,3 +97,37 @@ def test_save_load(model_class):
|
|||
|
||||
# clear file from os
|
||||
os.remove("test_save.zip")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", MODEL_LIST)
|
||||
def test_exclude_include_saved_params(model_class):
|
||||
"""
|
||||
Test if exclude and include parameters of save() work
|
||||
|
||||
:param model_class: (BaseRLModel) A RL model
|
||||
"""
|
||||
env = DummyVecEnv([lambda: IdentityEnvBox(10)])
|
||||
|
||||
# create model
|
||||
model = model_class('MlpPolicy', env, policy_kwargs=dict(net_arch=[16]), verbose=1, create_eval_env=True)
|
||||
# set verbose as something different then standard settings
|
||||
model.verbose = 2
|
||||
|
||||
# Check if exclude works
|
||||
model.save("test_save.zip", exclude=["verbose"])
|
||||
del model
|
||||
model = model_class.load("test_save")
|
||||
# check if verbose was not saved
|
||||
assert not model.verbose == 2
|
||||
|
||||
# set verbose as something different then standard settings
|
||||
model.verbose = 2
|
||||
# Check if include works
|
||||
model.save("test_save.zip", exclude=["verbose"], include=["verbose"])
|
||||
del model
|
||||
model = model_class.load("test_save")
|
||||
assert model.verbose == 2
|
||||
|
||||
|
||||
# clear file from os
|
||||
os.remove("test_save.zip")
|
||||
|
|
|
|||
Loading…
Reference in a new issue