diff --git a/tests/test_run.py b/tests/test_run.py index 445947b..a5569e5 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -15,7 +15,7 @@ def test_td3(): model.learn(total_timesteps=1000, eval_freq=500) model.save("test_save") model.load("test_save") - os.remove("test_save.pth") + os.remove("test_save.zip") def test_cemrl(): @@ -24,7 +24,7 @@ def test_cemrl(): model.learn(total_timesteps=1000, eval_freq=500) model.save("test_save") model.load("test_save") - os.remove("test_save.pth") + os.remove("test_save.zip") @pytest.mark.parametrize("model_class", [A2C, PPO]) @@ -32,9 +32,9 @@ def test_cemrl(): def test_onpolicy(model_class, env_id): model = model_class('MlpPolicy', env_id, policy_kwargs=dict(net_arch=[16]), verbose=1, create_eval_env=True) model.learn(total_timesteps=1000, eval_freq=500) - #model.save("test_save") - #model.load("test_save") - #os.remove("test_save.pth") + model.save("test_save") + model.load("test_save") + os.remove("test_save.zip") def test_sac(): @@ -42,3 +42,6 @@ def test_sac(): learning_starts=100, verbose=1, create_eval_env=True, ent_coef='auto', action_noise=NormalActionNoise(np.zeros(1), np.zeros(1))) model.learn(total_timesteps=1000, eval_freq=500) + model.save("test_save") + model.load("test_save") + os.remove("test_save.zip")