Add test for SAC with different entropy temperature

This commit is contained in:
Antonin Raffin 2019-12-02 11:47:52 +01:00
parent 03a84f97ea
commit 21e655ecbf

View file

@ -37,8 +37,9 @@ def test_onpolicy(model_class, env_id):
# model.load("test_save")
# os.remove("test_save.pth")
def test_sac():
@pytest.mark.parametrize("ent_coef", ['auto', 0.01])
def test_sac(ent_coef):
model = SAC('MlpPolicy', 'Pendulum-v0', policy_kwargs=dict(net_arch=[64, 64]),
learning_starts=100, verbose=1, create_eval_env=True, ent_coef='auto',
learning_starts=100, verbose=1, create_eval_env=True, ent_coef=ent_coef,
action_noise=NormalActionNoise(np.zeros(1), np.zeros(1)))
model.learn(total_timesteps=1000, eval_freq=500)