mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-07-04 04:07:27 +00:00
Corrected test_run.py
This commit is contained in:
parent
924ba9aea6
commit
cfb822aa91
1 changed files with 8 additions and 5 deletions
|
|
@ -15,7 +15,7 @@ def test_td3():
|
|||
model.learn(total_timesteps=1000, eval_freq=500)
|
||||
model.save("test_save")
|
||||
model.load("test_save")
|
||||
os.remove("test_save.pth")
|
||||
os.remove("test_save.zip")
|
||||
|
||||
|
||||
def test_cemrl():
|
||||
|
|
@ -24,7 +24,7 @@ def test_cemrl():
|
|||
model.learn(total_timesteps=1000, eval_freq=500)
|
||||
model.save("test_save")
|
||||
model.load("test_save")
|
||||
os.remove("test_save.pth")
|
||||
os.remove("test_save.zip")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", [A2C, PPO])
|
||||
|
|
@ -32,9 +32,9 @@ def test_cemrl():
|
|||
def test_onpolicy(model_class, env_id):
|
||||
model = model_class('MlpPolicy', env_id, policy_kwargs=dict(net_arch=[16]), verbose=1, create_eval_env=True)
|
||||
model.learn(total_timesteps=1000, eval_freq=500)
|
||||
#model.save("test_save")
|
||||
#model.load("test_save")
|
||||
#os.remove("test_save.pth")
|
||||
model.save("test_save")
|
||||
model.load("test_save")
|
||||
os.remove("test_save.zip")
|
||||
|
||||
|
||||
def test_sac():
|
||||
|
|
@ -42,3 +42,6 @@ def test_sac():
|
|||
learning_starts=100, verbose=1, create_eval_env=True, ent_coef='auto',
|
||||
action_noise=NormalActionNoise(np.zeros(1), np.zeros(1)))
|
||||
model.learn(total_timesteps=1000, eval_freq=500)
|
||||
model.save("test_save")
|
||||
model.load("test_save")
|
||||
os.remove("test_save.zip")
|
||||
|
|
|
|||
Loading…
Reference in a new issue