stable-baselines3/tests/test_custom_policy.py
Anssi 2cd6a4f93b
Match performance with stable-baselines (discrete case) (#110)
* Fix storing correct episode dones

* Fix number of filters in NatureCNN network

* Add TF-like RMSprop for matching performance with sb2

* Remove stuff that was accidentally included

* Reformat

* Clarify variable naming

* Update changelog

* Add comment on RMSprop implementations to A2C

* Add test for RMSpropTFLike

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
2020-08-03 22:22:51 +02:00

40 lines
1.4 KiB
Python

import pytest
import torch as th
from stable_baselines3 import A2C, PPO, SAC, TD3
from stable_baselines3.common.sb2_compat.rmsprop_tf_like import RMSpropTFLike
@pytest.mark.parametrize(
    "net_arch",
    [
        [12, {"vf": [16], "pi": [8]}],
        [4],
        [],
        [4, 4],
        [12, {"vf": [8, 4], "pi": [8]}],
        [12, {"vf": [8], "pi": [8, 4]}],
        [12, {"pi": [8]}],
    ],
)
@pytest.mark.parametrize("model_class", [A2C, PPO])
def test_flexible_mlp(model_class, net_arch):
    """Smoke-test ``MlpPolicy`` with a variety of ``net_arch`` layouts.

    For every combination of algorithm (A2C, PPO) and ``net_arch`` value, a
    model is built on ``"CartPole-v1"`` with ``n_steps=100`` and ``learn(1000)``
    is invoked. There are no assertions: the test passes when construction and
    the ``learn`` call complete without raising.

    :param model_class: algorithm class under test (parametrized).
    :param net_arch: value forwarded as ``policy_kwargs["net_arch"]``; the
        cases cover flat lists, an empty list, and lists ending in a dict
        with ``vf`` and/or ``pi`` entries.
    """
    arch_kwargs = {"net_arch": net_arch}
    model = model_class("MlpPolicy", "CartPole-v1", policy_kwargs=arch_kwargs, n_steps=100)
    model.learn(1000)
@pytest.mark.parametrize("net_arch", [[4], [4, 4]])
@pytest.mark.parametrize("model_class", [SAC, TD3])
def test_custom_offpolicy(model_class, net_arch):
    """Smoke-test SAC and TD3 with a custom ``net_arch``.

    Builds the model on ``"Pendulum-v0"`` with ``MlpPolicy`` and the given
    ``net_arch``, then calls ``learn(1000)``. There are no assertions: the
    test passes when nothing raises.

    :param model_class: algorithm class under test (parametrized).
    :param net_arch: list forwarded as ``policy_kwargs["net_arch"]``.
    """
    arch_kwargs = {"net_arch": net_arch}
    model = model_class("MlpPolicy", "Pendulum-v0", policy_kwargs=arch_kwargs)
    model.learn(1000)
@pytest.mark.parametrize("model_class", [A2C, PPO, SAC, TD3])
@pytest.mark.parametrize("optimizer_kwargs", [None, {"weight_decay": 0.0}])
def test_custom_optimizer(model_class, optimizer_kwargs):
    """Smoke-test each algorithm with ``th.optim.AdamW`` as the optimizer class.

    The policy is configured with ``optimizer_class=th.optim.AdamW``, the
    parametrized ``optimizer_kwargs`` (either ``None`` or an explicit
    ``weight_decay``), and ``net_arch=[32]``. The model is built on
    ``"Pendulum-v0"`` and ``learn(1000)`` is called. There are no assertions:
    the test passes when nothing raises.

    :param model_class: algorithm class under test (parametrized).
    :param optimizer_kwargs: value forwarded as
        ``policy_kwargs["optimizer_kwargs"]``.
    """
    policy_kwargs = {
        "optimizer_class": th.optim.AdamW,
        "optimizer_kwargs": optimizer_kwargs,
        "net_arch": [32],
    }
    model = model_class("MlpPolicy", "Pendulum-v0", policy_kwargs=policy_kwargs)
    model.learn(1000)
def test_tf_like_rmsprop_optimizer():
    """Smoke-test A2C with ``RMSpropTFLike`` as the optimizer class.

    Builds A2C on ``"Pendulum-v0"`` with ``MlpPolicy``,
    ``optimizer_class=RMSpropTFLike`` and ``net_arch=[32]``, then calls
    ``learn(1000)``. There are no assertions: the test passes when nothing
    raises.
    """
    policy_kwargs = {"optimizer_class": RMSpropTFLike, "net_arch": [32]}
    model = A2C("MlpPolicy", "Pendulum-v0", policy_kwargs=policy_kwargs)
    model.learn(1000)