Corrected test_run.py

This commit is contained in:
Noah Dormann 2019-11-21 16:54:30 +01:00
parent 924ba9aea6
commit cfb822aa91

View file

@ -15,7 +15,7 @@ def test_td3():
model.learn(total_timesteps=1000, eval_freq=500)
model.save("test_save")
model.load("test_save")
os.remove("test_save.pth")
os.remove("test_save.zip")
def test_cemrl():
@ -24,7 +24,7 @@ def test_cemrl():
model.learn(total_timesteps=1000, eval_freq=500)
model.save("test_save")
model.load("test_save")
os.remove("test_save.pth")
os.remove("test_save.zip")
@pytest.mark.parametrize("model_class", [A2C, PPO])
@ -32,9 +32,9 @@ def test_cemrl():
def test_onpolicy(model_class, env_id):
model = model_class('MlpPolicy', env_id, policy_kwargs=dict(net_arch=[16]), verbose=1, create_eval_env=True)
model.learn(total_timesteps=1000, eval_freq=500)
#model.save("test_save")
#model.load("test_save")
#os.remove("test_save.pth")
model.save("test_save")
model.load("test_save")
os.remove("test_save.zip")
def test_sac():
@ -42,3 +42,6 @@ def test_sac():
learning_starts=100, verbose=1, create_eval_env=True, ent_coef='auto',
action_noise=NormalActionNoise(np.zeros(1), np.zeros(1)))
model.learn(total_timesteps=1000, eval_freq=500)
model.save("test_save")
model.load("test_save")
os.remove("test_save.zip")