stable-baselines3/tests/test_custom_policy.py
Alex Pasquali 2cfcec4f50
Modified ActorCriticPolicy to support non-shared features extractor (#1148)
* Modified ActorCriticPolicy to support non-shared features extractor

* Refactored features extraction with non-shared features extractor in ActorCriticPolicy and updated doc

Doc update: added a 'warning' to the custom policy docs stating that, if the features extractor is non-shared, it is not possible to have shared layers in the mlp_extractor

* Moved attrib share_features_extractor in class

* Updated custom policy doc for non-shared features extractor

* Updated changelog

* Made some if-statements more readable in policies.py

The if-statements are related to the shared/non-shared features extractor in ActorCritic policies

* Simplify implementation and add run test

* Keep order in module gain to keep previous results consistent

* Fix test

* Improved docstring in policies.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Added some tests

* feature extractor -> features extractor

* Fix test

* Fix env_id in test

* Make features extractor parameter explicit

* Remove duplicate

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Antonin Raffin <antonin.raffin@dlr.de>
2022-12-20 15:12:05 +01:00

64 lines
2.3 KiB
Python

import pytest
import torch as th
from stable_baselines3 import A2C, DQN, PPO, SAC, TD3
from stable_baselines3.common.sb2_compat.rmsprop_tf_like import RMSpropTFLike
@pytest.mark.parametrize(
    "net_arch",
    [
        [12, {"vf": [16], "pi": [8]}],
        [4],
        [],
        [4, 4],
        [12, {"vf": [8, 4], "pi": [8]}],
        [12, {"vf": [8], "pi": [8, 4]}],
        [12, {"pi": [8]}],
    ],
)
@pytest.mark.parametrize("model_class", [A2C, PPO])
def test_flexible_mlp(model_class, net_arch):
    """Smoke-test A2C and PPO with a range of ``net_arch`` layouts.

    Each parametrized case builds the model with ``MlpPolicy`` on
    ``CartPole-v1`` and calls ``learn(300)``; the test passes as long as
    construction and the short run raise no exception.

    :param model_class: Algorithm class under test (A2C or PPO).
    :param net_arch: Network architecture forwarded through ``policy_kwargs``.
    """
    model = model_class(
        "MlpPolicy",
        "CartPole-v1",
        policy_kwargs={"net_arch": net_arch},
        n_steps=64,
    )
    model.learn(300)
@pytest.mark.parametrize("net_arch", [[], [4], [4, 4], {"qf": [8], "pi": [8, 4]}])
@pytest.mark.parametrize("model_class", [SAC, TD3])
def test_custom_offpolicy(model_class, net_arch):
    """Smoke-test SAC and TD3 with list-style and dict-style ``net_arch`` values.

    Each case builds the model with ``MlpPolicy`` on ``Pendulum-v1`` and calls
    ``learn(300)``; no exception means success.

    :param model_class: Algorithm class under test (SAC or TD3).
    :param net_arch: Network architecture forwarded through ``policy_kwargs``.
    """
    model = model_class(
        "MlpPolicy",
        "Pendulum-v1",
        policy_kwargs={"net_arch": net_arch},
        learning_starts=100,
    )
    model.learn(300)
@pytest.mark.parametrize("model_class", [A2C, DQN, PPO, SAC, TD3])
@pytest.mark.parametrize("optimizer_kwargs", [None, dict(weight_decay=0.0)])
def test_custom_optimizer(model_class, optimizer_kwargs):
    """Check that every algorithm accepts a custom optimizer class and kwargs.

    The policy is configured with ``th.optim.AdamW`` and the parametrized
    ``optimizer_kwargs``, then the model runs ``learn(300)``.

    :param model_class: Algorithm class under test.
    :param optimizer_kwargs: Extra optimizer arguments, or ``None``.
    """
    # DQN gets its own environment; every other algorithm uses Pendulum-v1.
    env_id = "CartPole-v1" if model_class is DQN else "Pendulum-v1"

    # Per-family constructor arguments passed alongside the policy kwargs.
    if model_class in {DQN, SAC, TD3}:
        algo_kwargs = {"learning_starts": 100}
    elif model_class in {A2C, PPO}:
        algo_kwargs = {"n_steps": 64}
    else:
        algo_kwargs = {}

    policy_kwargs = {
        "optimizer_class": th.optim.AdamW,
        "optimizer_kwargs": optimizer_kwargs,
        "net_arch": [32],
    }
    model = model_class("MlpPolicy", env_id, policy_kwargs=policy_kwargs, **algo_kwargs)
    model.learn(300)
def test_tf_like_rmsprop_optimizer():
    """Run A2C on ``Pendulum-v1`` with ``RMSpropTFLike`` as the policy optimizer.

    The test passes if building the model and calling ``learn(500)`` raise no
    exception.
    """
    model = A2C(
        "MlpPolicy",
        "Pendulum-v1",
        policy_kwargs={"optimizer_class": RMSpropTFLike, "net_arch": [32]},
    )
    model.learn(500)
def test_dqn_custom_policy():
    """Run DQN on ``CartPole-v1`` with custom policy kwargs.

    The policy uses ``RMSpropTFLike`` as optimizer and ``net_arch=[32]``; the
    test passes if building the model and calling ``learn(300)`` raise no
    exception.
    """
    model = DQN(
        "MlpPolicy",
        "CartPole-v1",
        policy_kwargs={"optimizer_class": RMSpropTFLike, "net_arch": [32]},
        learning_starts=100,
    )
    model.learn(300)
@pytest.mark.parametrize("model_class", [A2C, PPO])
def test_not_shared_features_extractor(model_class):
    """Expect a ``ValueError`` for this ``net_arch`` with a non-shared features extractor.

    The model is only constructed, never trained: the error must be raised at
    construction time when ``share_features_extractor=False`` is combined with
    the ``net_arch`` below.

    :param model_class: Algorithm class under test (A2C or PPO).
    """
    # NOTE(review): the leading ``12`` is presumably a shared layer, which is
    # what makes this configuration invalid -- confirm against policies.py.
    net_arch = [12, {"vf": [16], "pi": [8]}]
    policy_kwargs = {"net_arch": net_arch, "share_features_extractor": False}

    with pytest.raises(ValueError):
        model_class("MlpPolicy", "Pendulum-v1", policy_kwargs=policy_kwargs)