From c82025e673344e76f8015b2da2f0f40e6217c98a Mon Sep 17 00:00:00 2001 From: Noah Dormann Date: Thu, 28 Nov 2019 16:07:15 +0100 Subject: [PATCH] Add Test for exclude/include feature of save --- tests/test_save_load.py | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/tests/test_save_load.py b/tests/test_save_load.py index f3f8d10..36dfdc0 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -97,3 +97,37 @@ def test_save_load(model_class): # clear file from os os.remove("test_save.zip") + + +@pytest.mark.parametrize("model_class", MODEL_LIST) +def test_exclude_include_saved_params(model_class): + """ + Test if exclude and include parameters of save() work + + :param model_class: (BaseRLModel) A RL model + """ + env = DummyVecEnv([lambda: IdentityEnvBox(10)]) + + # create model + model = model_class('MlpPolicy', env, policy_kwargs=dict(net_arch=[16]), verbose=1, create_eval_env=True) + # set verbose as something different then standard settings + model.verbose = 2 + + # Check if exclude works + model.save("test_save.zip", exclude=["verbose"]) + del model + model = model_class.load("test_save") + # check if verbose was not saved + assert not model.verbose == 2 + + # set verbose as something different then standard settings + model.verbose = 2 + # Check if include works + model.save("test_save.zip", exclude=["verbose"], include=["verbose"]) + del model + model = model_class.load("test_save") + assert model.verbose == 2 + + + # clear file from os + os.remove("test_save.zip")