Mirror of https://github.com/saymrwulf/stable-baselines3.git (synced 2026-05-16 21:10:08 +00:00)
* Modified actor-critic policies & MlpExtractor

  ActorCriticPolicy:
  - changed the type hint of the net_arch param: it is now a dict
  - removed the check on whether the features extractor is shared: shared layers are no longer allowed in the mlp_extractor, regardless of the features extractor

  ActorCriticCnnPolicy:
  - changed the type hint of the net_arch param: it is now a dict

  MultiInputActorCriticPolicy:
  - changed the type hint of the net_arch param: it is now a dict

  MlpExtractor:
  - changed the type hint of the net_arch param: it is now a dict
  - adapted network creation
  - adapted the forward, forward_actor & forward_critic methods

* Removed shared layers in mlp_extractor
* Updated docs and changelog + reformat
* Updated custom policy tests
* Removed test on deprecation warning for shared layers in mlp_extractor (shared layers are now removed)
* Update version
* Update RL Zoo doc
* Fix linter warnings
* Add ruff to Makefile (experimental)
* Add backward compat code and minor updates
* Update tests
* Add backward compatibility
* Fix test
* Improve compat code

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
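For reference, a minimal sketch of what the net_arch change means for users (the [64, 64] layer sizes and the choice of PPO here are illustrative assumptions, not taken from this commit); the test file below exercises both formats:

from stable_baselines3 import PPO

# New format: net_arch is a plain dict mapping "pi" and "vf" to layer sizes.
# Shared layers in the mlp_extractor are no longer supported.
model = PPO("MlpPolicy", "CartPole-v1", policy_kwargs=dict(net_arch=dict(pi=[64, 64], vf=[64, 64])))

# Old format (a list containing a dict): kept for backward compatibility,
# but it now emits a UserWarning (see test_flexible_mlp below).
model = PPO("MlpPolicy", "CartPole-v1", policy_kwargs=dict(net_arch=[dict(pi=[64, 64], vf=[64, 64])]))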
64 lines · 2.3 KiB · Python
import pytest
import torch as th

from stable_baselines3 import A2C, DQN, PPO, SAC, TD3
from stable_baselines3.common.sb2_compat.rmsprop_tf_like import RMSpropTFLike
@pytest.mark.parametrize(
    "net_arch",
    [
        [],
        [4],
        [4, 4],
        dict(vf=[16], pi=[8]),
        dict(vf=[8, 4], pi=[8]),
        dict(vf=[8], pi=[8, 4]),
        dict(pi=[8]),
        # Old format, emits a warning
        [dict(vf=[8])],
        [dict(vf=[8], pi=[4])],
    ],
)
@pytest.mark.parametrize("model_class", [A2C, PPO])
def test_flexible_mlp(model_class, net_arch):
    # The legacy list-of-dict format must still work, but should emit a UserWarning
    if isinstance(net_arch, list) and len(net_arch) > 0 and isinstance(net_arch[0], dict):
        with pytest.warns(UserWarning):
            _ = model_class("MlpPolicy", "CartPole-v1", policy_kwargs=dict(net_arch=net_arch), n_steps=64).learn(300)
    else:
        _ = model_class("MlpPolicy", "CartPole-v1", policy_kwargs=dict(net_arch=net_arch), n_steps=64).learn(300)
@pytest.mark.parametrize("net_arch", [[], [4], [4, 4], dict(qf=[8], pi=[8, 4])])
@pytest.mark.parametrize("model_class", [SAC, TD3])
def test_custom_offpolicy(model_class, net_arch):
    _ = model_class("MlpPolicy", "Pendulum-v1", policy_kwargs=dict(net_arch=net_arch), learning_starts=100).learn(300)
@pytest.mark.parametrize("model_class", [A2C, DQN, PPO, SAC, TD3])
@pytest.mark.parametrize("optimizer_kwargs", [None, dict(weight_decay=0.0)])
def test_custom_optimizer(model_class, optimizer_kwargs):
    # Use different environment for DQN
    if model_class is DQN:
        env_id = "CartPole-v1"
    else:
        env_id = "Pendulum-v1"

    kwargs = {}
    if model_class in {DQN, SAC, TD3}:
        kwargs = dict(learning_starts=100)
    elif model_class in {A2C, PPO}:
        kwargs = dict(n_steps=64)

    policy_kwargs = dict(optimizer_class=th.optim.AdamW, optimizer_kwargs=optimizer_kwargs, net_arch=[32])
    _ = model_class("MlpPolicy", env_id, policy_kwargs=policy_kwargs, **kwargs).learn(300)
def test_tf_like_rmsprop_optimizer():
    policy_kwargs = dict(optimizer_class=RMSpropTFLike, net_arch=[32])
    _ = A2C("MlpPolicy", "Pendulum-v1", policy_kwargs=policy_kwargs).learn(500)
def test_dqn_custom_policy():
    policy_kwargs = dict(optimizer_class=RMSpropTFLike, net_arch=[32])
    _ = DQN("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs, learning_starts=100).learn(300)