From f3cb0688c4e9fafb7349348303b09b7de9cf44c1 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Wed, 22 Apr 2020 13:21:11 +0200 Subject: [PATCH] Fix custom optimizer --- tests/test_custom_policy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_custom_policy.py b/tests/test_custom_policy.py index 94cbf17..dac69dc 100644 --- a/tests/test_custom_policy.py +++ b/tests/test_custom_policy.py @@ -30,5 +30,5 @@ def test_custom_offpolicy(model_class, net_arch): @pytest.mark.parametrize('model_class', [A2C, PPO, SAC, TD3]) @pytest.mark.parametrize('optimizer_kwargs', [None, dict(weight_decay=0.0)]) def test_custom_optimizer(model_class, optimizer_kwargs): - policy_kwargs = dict(optimizer=th.optim.AdamW, optimizer_kwargs=optimizer_kwargs, net_arch=[32]) + policy_kwargs = dict(optimizer_class=th.optim.AdamW, optimizer_kwargs=optimizer_kwargs, net_arch=[32]) _ = model_class('MlpPolicy', 'Pendulum-v0', policy_kwargs=policy_kwargs).learn(1000)