diff --git a/tests/test_custom_policy.py b/tests/test_custom_policy.py
index 94cbf17..dac69dc 100644
--- a/tests/test_custom_policy.py
+++ b/tests/test_custom_policy.py
@@ -30,5 +30,5 @@ def test_custom_offpolicy(model_class, net_arch):
 @pytest.mark.parametrize('model_class', [A2C, PPO, SAC, TD3])
 @pytest.mark.parametrize('optimizer_kwargs', [None, dict(weight_decay=0.0)])
 def test_custom_optimizer(model_class, optimizer_kwargs):
-    policy_kwargs = dict(optimizer=th.optim.AdamW, optimizer_kwargs=optimizer_kwargs, net_arch=[32])
+    policy_kwargs = dict(optimizer_class=th.optim.AdamW, optimizer_kwargs=optimizer_kwargs, net_arch=[32])
     _ = model_class('MlpPolicy', 'Pendulum-v0', policy_kwargs=policy_kwargs).learn(1000)