mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-07-04 04:07:27 +00:00
Fix custom optimizer
This commit is contained in:
parent
041f2bc59a
commit
f3cb0688c4
1 changed files with 1 additions and 1 deletions
|
|
@ -30,5 +30,5 @@ def test_custom_offpolicy(model_class, net_arch):
|
|||
@pytest.mark.parametrize('model_class', [A2C, PPO, SAC, TD3])
|
||||
@pytest.mark.parametrize('optimizer_kwargs', [None, dict(weight_decay=0.0)])
|
||||
def test_custom_optimizer(model_class, optimizer_kwargs):
|
||||
policy_kwargs = dict(optimizer=th.optim.AdamW, optimizer_kwargs=optimizer_kwargs, net_arch=[32])
|
||||
policy_kwargs = dict(optimizer_class=th.optim.AdamW, optimizer_kwargs=optimizer_kwargs, net_arch=[32])
|
||||
_ = model_class('MlpPolicy', 'Pendulum-v0', policy_kwargs=policy_kwargs).learn(1000)
|
||||
|
|
|
|||
Loading…
Reference in a new issue