stable-baselines3/tests/test_custom_policy.py
Antonin RAFFIN 000544cc1f
Add support for pre and post linear modules in create_mlp (#1975)
* Add support for pre and post linear modules in `create_mlp`

* Disable mypy for python 3.8

* Reformat toml file

* Update docstring

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Add some comments

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2024-07-22 13:42:33 +02:00

120 lines
4.5 KiB
Python

import pytest
import torch as th
import torch.nn as nn
from stable_baselines3 import A2C, DQN, PPO, SAC, TD3
from stable_baselines3.common.sb2_compat.rmsprop_tf_like import RMSpropTFLike
from stable_baselines3.common.torch_layers import create_mlp
@pytest.mark.parametrize(
    "net_arch",
    [
        [],
        [4],
        [4, 4],
        dict(vf=[16], pi=[8]),
        dict(vf=[8, 4], pi=[8]),
        dict(vf=[8], pi=[8, 4]),
        dict(pi=[8]),
        # Old format, emits a warning
        [dict(vf=[8])],
        [dict(vf=[8], pi=[4])],
    ],
)
@pytest.mark.parametrize("model_class", [A2C, PPO])
def test_flexible_mlp(model_class, net_arch):
    """Briefly train an on-policy agent on CartPole-v1 with a given ``net_arch`` layout.

    The legacy layout (a non-empty list whose first element is a dict) is expected
    to emit a ``UserWarning``; every other layout must train without that check.
    """

    def train_briefly():
        model = model_class("MlpPolicy", "CartPole-v1", policy_kwargs=dict(net_arch=net_arch), n_steps=64)
        model.learn(300)

    is_legacy_format = isinstance(net_arch, list) and bool(net_arch) and isinstance(net_arch[0], dict)
    if not is_legacy_format:
        train_briefly()
        return

    with pytest.warns(UserWarning):
        train_briefly()
@pytest.mark.parametrize(
    "net_arch",
    [
        [],
        [4],
        [4, 4],
        dict(qf=[8], pi=[8, 4]),
    ],
)
@pytest.mark.parametrize("model_class", [SAC, TD3])
def test_custom_offpolicy(model_class, net_arch):
    """Briefly train SAC/TD3 on Pendulum-v1 with list and dict (``qf``/``pi``) ``net_arch`` specs."""
    policy_kwargs = {"net_arch": net_arch}
    agent = model_class("MlpPolicy", "Pendulum-v1", policy_kwargs=policy_kwargs, learning_starts=100)
    agent.learn(300)
@pytest.mark.parametrize("model_class", [A2C, DQN, PPO, SAC, TD3])
@pytest.mark.parametrize("optimizer_kwargs", [None, dict(weight_decay=0.0)])
def test_custom_optimizer(model_class, optimizer_kwargs):
    """Check that every algorithm accepts a custom optimizer class (AdamW) plus its kwargs."""
    # DQN runs on CartPole-v1 (presumably because it needs a discrete action space);
    # every other algorithm uses Pendulum-v1.
    env_id = "CartPole-v1" if model_class is DQN else "Pendulum-v1"

    # Keep the run short: off-policy algorithms start learning after 100 steps,
    # on-policy algorithms collect small 64-step rollouts.
    if model_class in (DQN, SAC, TD3):
        extra_kwargs = dict(learning_starts=100)
    elif model_class in (A2C, PPO):
        extra_kwargs = dict(n_steps=64)
    else:
        extra_kwargs = {}

    policy_kwargs = {
        "optimizer_class": th.optim.AdamW,
        "optimizer_kwargs": optimizer_kwargs,
        "net_arch": [32],
    }
    model = model_class("MlpPolicy", env_id, policy_kwargs=policy_kwargs, **extra_kwargs)
    model.learn(300)
def test_tf_like_rmsprop_optimizer():
    """Briefly train A2C on Pendulum-v1 using the SB2-compatible TF-like RMSprop optimizer."""
    model = A2C(
        "MlpPolicy",
        "Pendulum-v1",
        policy_kwargs={"optimizer_class": RMSpropTFLike, "net_arch": [32]},
    )
    model.learn(500)
def test_dqn_custom_policy():
    """Briefly train DQN on CartPole-v1 with a custom optimizer (RMSpropTFLike) and one 32-unit layer."""
    custom_policy_kwargs = {"optimizer_class": RMSpropTFLike, "net_arch": [32]}
    dqn = DQN("MlpPolicy", "CartPole-v1", policy_kwargs=custom_policy_kwargs, learning_starts=100)
    dqn.learn(300)
def test_create_mlp():
    """Check the module layout returned by ``create_mlp`` for several configurations."""
    # We cannot compare the networks directly because the modules have different ids,
    # so each layer's type and dimensions are checked instead.

    # Hidden layers, output layer and output squashing:
    # [nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 2), nn.Tanh()]
    net = create_mlp(4, 2, net_arch=[16, 8], squash_output=True)
    assert len(net) == 6
    assert isinstance(net[0], nn.Linear)
    assert net[0].in_features == 4
    assert net[0].out_features == 16
    assert isinstance(net[1], nn.ReLU)
    assert isinstance(net[2], nn.Linear)
    assert net[2].in_features == 16
    assert net[2].out_features == 8
    assert isinstance(net[3], nn.ReLU)
    assert isinstance(net[4], nn.Linear)
    assert net[4].in_features == 8
    assert net[4].out_features == 2
    assert isinstance(net[5], nn.Tanh)

    # No hidden layers and no output layer (output_dim=-1): empty module list
    net = create_mlp(4, -1, net_arch=[])
    assert net == []

    # No output layer, with custom activation function: [nn.Linear(6, 8), nn.Tanh()]
    net = create_mlp(6, -1, net_arch=[8], activation_fn=nn.Tanh)
    assert len(net) == 2
    assert isinstance(net[0], nn.Linear)
    assert net[0].in_features == 6
    assert net[0].out_features == 8
    assert isinstance(net[1], nn.Tanh)

    # Using pre-linear and post-linear modules:
    # [nn.BatchNorm1d(6), nn.Linear(6, 8), nn.LayerNorm(8), nn.ReLU(),
    #  nn.BatchNorm1d(8), nn.Linear(8, 12), nn.LayerNorm(12), nn.ReLU(),
    #  nn.BatchNorm1d(12), nn.Linear(12, 2)]  # The output layer has no post-linear module
    pre_linear = [nn.BatchNorm1d]
    post_linear = [nn.LayerNorm]
    net = create_mlp(6, 2, net_arch=[8, 12], pre_linear_modules=pre_linear, post_linear_modules=post_linear)
    assert len(net) == 10
    assert isinstance(net[0], nn.BatchNorm1d)
    assert net[0].num_features == 6
    assert isinstance(net[1], nn.Linear)
    assert net[1].in_features == 6
    assert net[1].out_features == 8
    assert isinstance(net[2], nn.LayerNorm)
    assert net[2].normalized_shape == (8,)
    assert isinstance(net[3], nn.ReLU)
    assert isinstance(net[4], nn.BatchNorm1d)
    assert net[4].num_features == 8
    assert isinstance(net[5], nn.Linear)
    assert net[5].in_features == 8
    assert net[5].out_features == 12
    assert isinstance(net[6], nn.LayerNorm)
    assert net[6].normalized_shape == (12,)
    assert isinstance(net[7], nn.ReLU)
    assert isinstance(net[8], nn.BatchNorm1d)
    assert net[8].num_features == 12
    assert isinstance(net[-1], nn.Linear)
    assert net[-1].in_features == 12
    assert net[-1].out_features == 2

    # The modules must chain: a batch of 6-dim inputs maps to 2-dim outputs.
    # Batch size > 1 because BatchNorm1d in training mode needs more than one sample.
    output = nn.Sequential(*net)(th.rand(5, 6))
    assert output.shape == (5, 2)