diff --git a/tests/test_run.py b/tests/test_run.py index 32a4b30..b29a526 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -37,8 +37,9 @@ def test_onpolicy(model_class, env_id): # model.load("test_save") # os.remove("test_save.pth") -def test_sac(): +@pytest.mark.parametrize("ent_coef", ['auto', 0.01]) +def test_sac(ent_coef): model = SAC('MlpPolicy', 'Pendulum-v0', policy_kwargs=dict(net_arch=[64, 64]), - learning_starts=100, verbose=1, create_eval_env=True, ent_coef='auto', + learning_starts=100, verbose=1, create_eval_env=True, ent_coef=ent_coef, action_noise=NormalActionNoise(np.zeros(1), np.zeros(1))) model.learn(total_timesteps=1000, eval_freq=500)