Fix off policy features extractor (#198)

* Faster tests

* Fix feature extractor bug + add check

* Add missing check

* Allow TD3 features extractor to be separate

* Add share features extractor option for SAC

* Bug fixes

* Apply suggestions from code review

Co-authored-by: Adam Gleave <adam@gleave.me>

Co-authored-by: Adam Gleave <adam@gleave.me>
This commit is contained in:
Antonin RAFFIN 2020-10-27 14:24:59 +01:00 committed by GitHub
parent b252f4212c
commit 0fc0dd1b21
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 286 additions and 59 deletions

View file

@ -17,6 +17,7 @@ New Features:
- Added Hindsight Experience Replay ``HER``. (@megan-klaiber)
- ``VecNormalize`` now supports ``gym.spaces.Dict`` observation spaces
- Support logging videos to Tensorboard (@SwamyDev)
- Added ``share_features_extractor`` argument to ``SAC`` and ``TD3`` policies
Bug Fixes:
^^^^^^^^^^
@ -27,6 +28,8 @@ Bug Fixes:
- Fix model creation initializing CUDA even when `device="cpu"` is provided
- Fix ``check_env`` not checking if the env has a Dict actionspace before calling ``_check_nan`` (@wmmc88)
- Update the check for spaces unsupported by Stable Baselines 3 to include checks on the action space (@wmmc88)
- Fixed feature extractor bug for target network where the same net was shared instead
of being separate. This bug affects ``SAC``, ``DDPG`` and ``TD3`` when using ``CnnPolicy`` (or custom feature extractor)
Deprecations:
^^^^^^^^^^^^^

View file

@ -1,6 +1,6 @@
import warnings
from stable_baselines3.common.env_util import * # noqa: F403
from stable_baselines3.common.env_util import * # noqa: F403,F401
warnings.warn(
"Module ``common.cmd_util`` has been renamed to ``common.env_util`` and will be removed in the future.", FutureWarning

View file

@ -505,7 +505,7 @@ class StateDependentNoiseDistribution(Distribution):
:param latent_dim: Dimension of the last layer of the policy (before the action layer)
:param log_std_init: Initial value for the log standard deviation
:param latent_sde_dim: Dimension of the last layer of the feature extractor
:param latent_sde_dim: Dimension of the last layer of the features extractor
for gSDE. By default, it is shared with the policy network.
:return:
"""

View file

@ -37,7 +37,7 @@ class BaseModel(nn.Module, ABC):
:param action_space: The action space of the environment
:param features_extractor_class: Features extractor to use.
:param features_extractor_kwargs: Keyword arguments
to pass to the feature extractor.
to pass to the features extractor.
:param features_extractor: Network to extract features
(a CNN when using images, a nn.Flatten() layer otherwise)
:param normalize_images: Whether to normalize images or not,
@ -83,6 +83,30 @@ class BaseModel(nn.Module, ABC):
def forward(self, *args, **kwargs):
del args, kwargs
def _update_features_extractor(
self, net_kwargs: Dict[str, Any], features_extractor: Optional[BaseFeaturesExtractor] = None
) -> Dict[str, Any]:
"""
Update the network keyword arguments and create a new features extractor object if needed.
If a ``features_extractor`` object is passed, then it will be shared.
:param net_kwargs: the base network keyword arguments, without the ones
related to features extractor
:param features_extractor: a features extractor object.
If None, a new object will be created.
:return: The updated keyword arguments
"""
net_kwargs = net_kwargs.copy()
if features_extractor is None:
# The features extractor is not shared, create a new one
features_extractor = self.make_features_extractor()
net_kwargs.update(dict(features_extractor=features_extractor, features_dim=features_extractor.features_dim))
return net_kwargs
def make_features_extractor(self) -> BaseFeaturesExtractor:
""" Helper method to create a features extractor."""
return self.features_extractor_class(self.observation_space, **self.features_extractor_kwargs)
def extract_features(self, obs: th.Tensor) -> th.Tensor:
"""
Preprocess the observation if needed and extract features.
@ -90,7 +114,7 @@ class BaseModel(nn.Module, ABC):
:param obs:
:return:
"""
assert self.features_extractor is not None, "No feature extractor was set"
assert self.features_extractor is not None, "No features extractor was set"
preprocessed_obs = preprocess_obs(obs, self.observation_space, normalize_images=self.normalize_images)
return self.features_extractor(preprocessed_obs)
@ -327,7 +351,7 @@ class ActorCriticPolicy(BasePolicy):
this allows to ensure boundaries when using gSDE.
:param features_extractor_class: Features extractor to use.
:param features_extractor_kwargs: Keyword arguments
to pass to the feature extractor.
to pass to the features extractor.
:param normalize_images: Whether to normalize images or not,
dividing by 255.0 (True by default)
:param optimizer_class: The optimizer to use,
@ -466,7 +490,7 @@ class ActorCriticPolicy(BasePolicy):
latent_dim_pi = self.mlp_extractor.latent_dim_pi
# Separate feature extractor for gSDE
# Separate features extractor for gSDE
if self.sde_net_arch is not None:
self.sde_features_extractor, latent_sde_dim = create_sde_features_extractor(
self.features_dim, self.sde_net_arch, self.activation_fn
@ -496,7 +520,7 @@ class ActorCriticPolicy(BasePolicy):
if self.ortho_init:
# TODO: check for features_extractor
# Values from stable-baselines.
# feature_extractor/mlp values are
# features_extractor/mlp values are
# originally from openai/baselines (default gains/init_scales).
module_gains = {
self.features_extractor: np.sqrt(2),
@ -625,7 +649,7 @@ class ActorCriticCnnPolicy(ActorCriticPolicy):
this allows to ensure boundaries when using gSDE.
:param features_extractor_class: Features extractor to use.
:param features_extractor_kwargs: Keyword arguments
to pass to the feature extractor.
to pass to the features extractor.
:param normalize_images: Whether to normalize images or not,
dividing by 255.0 (True by default)
:param optimizer_class: The optimizer to use,
@ -698,6 +722,8 @@ class ContinuousCritic(BaseModel):
:param normalize_images: Whether to normalize images or not,
dividing by 255.0 (True by default)
:param n_critics: Number of critic networks to create.
:param share_features_extractor: Whether the features extractor is shared or not
between the actor and the critic (this saves computation time)
"""
def __init__(
@ -710,6 +736,7 @@ class ContinuousCritic(BaseModel):
activation_fn: Type[nn.Module] = nn.ReLU,
normalize_images: bool = True,
n_critics: int = 2,
share_features_extractor: bool = True,
):
super().__init__(
observation_space,
@ -720,6 +747,7 @@ class ContinuousCritic(BaseModel):
action_dim = get_action_dim(self.action_space)
self.share_features_extractor = share_features_extractor
self.n_critics = n_critics
self.q_networks = []
for idx in range(n_critics):
@ -730,7 +758,8 @@ class ContinuousCritic(BaseModel):
def forward(self, obs: th.Tensor, actions: th.Tensor) -> Tuple[th.Tensor, ...]:
# Learn the features extractor using the policy loss only
with th.no_grad():
# when the features_extractor is shared with the actor
with th.set_grad_enabled(not self.share_features_extractor):
features = self.extract_features(obs)
qvalue_input = th.cat([features, actions], dim=1)
return tuple(q_net(qvalue_input) for q_net in self.q_networks)

View file

@ -244,7 +244,7 @@ def get_actor_critic_arch(net_arch: Union[List[int], Dict[str, List[int]]]) -> T
then you can specify ``net_arch=dict(qf=[400, 300], pi=[64, 64])``.
.. note::
Compared to their on-policy counterparts, no shared layers (other than the feature extractor)
Compared to their on-policy counterparts, no shared layers (other than the features extractor)
between the actor and the critic are allowed (to prevent issues with target networks).
:param net_arch: The specification of the actor and critic networks.

View file

@ -90,7 +90,7 @@ class DQNPolicy(BasePolicy):
:param activation_fn: Activation function
:param features_extractor_class: Features extractor to use.
:param features_extractor_kwargs: Keyword arguments
to pass to the feature extractor.
to pass to the features extractor.
:param normalize_images: Whether to normalize images or not,
dividing by 255.0 (True by default)
:param optimizer_class: The optimizer to use,
@ -158,10 +158,9 @@ class DQNPolicy(BasePolicy):
self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
def make_q_net(self) -> QNetwork:
# Make sure we always have separate networks for feature extractors etc
features_extractor = self.features_extractor_class(self.observation_space, **self.features_extractor_kwargs)
features_dim = features_extractor.features_dim
return QNetwork(features_extractor=features_extractor, features_dim=features_dim, **self.net_args).to(self.device)
# Make sure we always have separate networks for features extractors etc
net_args = self._update_features_extractor(self.net_args, features_extractor=None)
return QNetwork(**net_args).to(self.device)
def forward(self, obs: th.Tensor, deterministic: bool = True) -> th.Tensor:
return self._predict(obs, deterministic=deterministic)

View file

@ -90,7 +90,7 @@ class Actor(BasePolicy):
if self.use_sde:
latent_sde_dim = last_layer_dim
# Separate feature extractor for gSDE
# Separate features extractor for gSDE
if sde_net_arch is not None:
self.sde_features_extractor, latent_sde_dim = create_sde_features_extractor(
features_dim, sde_net_arch, activation_fn
@ -211,7 +211,7 @@ class SACPolicy(BasePolicy):
:param clip_mean: Clip the mean output when using gSDE to avoid numerical instability.
:param features_extractor_class: Features extractor to use.
:param features_extractor_kwargs: Keyword arguments
to pass to the feature extractor.
to pass to the features extractor.
:param normalize_images: Whether to normalize images or not,
dividing by 255.0 (True by default)
:param optimizer_class: The optimizer to use,
@ -219,6 +219,8 @@ class SACPolicy(BasePolicy):
:param optimizer_kwargs: Additional keyword arguments,
excluding the learning rate, to pass to the optimizer
:param n_critics: Number of critic networks to create.
:param share_features_extractor: Whether to share or not the features extractor
between the actor and the critic (this saves computation time)
"""
def __init__(
@ -239,6 +241,7 @@ class SACPolicy(BasePolicy):
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
n_critics: int = 2,
share_features_extractor: bool = True,
):
super(SACPolicy, self).__init__(
observation_space,
@ -258,17 +261,11 @@ class SACPolicy(BasePolicy):
actor_arch, critic_arch = get_actor_critic_arch(net_arch)
# Create shared features extractor
self.features_extractor = features_extractor_class(self.observation_space, **self.features_extractor_kwargs)
self.features_dim = self.features_extractor.features_dim
self.net_arch = net_arch
self.activation_fn = activation_fn
self.net_args = {
"observation_space": self.observation_space,
"action_space": self.action_space,
"features_extractor": self.features_extractor,
"features_dim": self.features_dim,
"net_arch": actor_arch,
"activation_fn": self.activation_fn,
"normalize_images": normalize_images,
@ -283,10 +280,17 @@ class SACPolicy(BasePolicy):
}
self.actor_kwargs.update(sde_kwargs)
self.critic_kwargs = self.net_args.copy()
self.critic_kwargs.update({"n_critics": n_critics, "net_arch": critic_arch})
self.critic_kwargs.update(
{
"n_critics": n_critics,
"net_arch": critic_arch,
"share_features_extractor": share_features_extractor,
}
)
self.actor, self.actor_target = None, None
self.critic, self.critic_target = None, None
self.share_features_extractor = share_features_extractor
self._build(lr_schedule)
@ -294,13 +298,21 @@ class SACPolicy(BasePolicy):
self.actor = self.make_actor()
self.actor.optimizer = self.optimizer_class(self.actor.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
self.critic = self.make_critic()
self.critic_target = self.make_critic()
if self.share_features_extractor:
self.critic = self.make_critic(features_extractor=self.actor.features_extractor)
# Do not optimize the shared features extractor with the critic loss
# otherwise, there are gradient computation issues
critic_parameters = [param for name, param in self.critic.named_parameters() if "features_extractor" not in name]
else:
# Create a separate features extractor for the critic
# this requires more memory and computation
self.critic = self.make_critic(features_extractor=None)
critic_parameters = self.critic.parameters()
# Critic target should not share the features extractor with critic
self.critic_target = self.make_critic(features_extractor=None)
self.critic_target.load_state_dict(self.critic.state_dict())
# Do not optimize the shared feature extractor with the critic loss
# otherwise, there are gradient computation issues
# Another solution: having duplicated features extractor but requires more memory and computation
critic_parameters = [param for name, param in self.critic.named_parameters() if "features_extractor" not in name]
self.critic.optimizer = self.optimizer_class(critic_parameters, lr=lr_schedule(1), **self.optimizer_kwargs)
def _get_data(self) -> Dict[str, Any]:
@ -333,11 +345,13 @@ class SACPolicy(BasePolicy):
"""
self.actor.reset_noise(batch_size=batch_size)
def make_actor(self) -> Actor:
return Actor(**self.actor_kwargs).to(self.device)
def make_actor(self, features_extractor: Optional[BaseFeaturesExtractor] = None) -> Actor:
actor_kwargs = self._update_features_extractor(self.actor_kwargs, features_extractor)
return Actor(**actor_kwargs).to(self.device)
def make_critic(self) -> ContinuousCritic:
return ContinuousCritic(**self.critic_kwargs).to(self.device)
def make_critic(self, features_extractor: Optional[BaseFeaturesExtractor] = None) -> ContinuousCritic:
critic_kwargs = self._update_features_extractor(self.critic_kwargs, features_extractor)
return ContinuousCritic(**critic_kwargs).to(self.device)
def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor:
return self._predict(obs, deterministic=deterministic)
@ -375,6 +389,8 @@ class CnnPolicy(SACPolicy):
:param optimizer_kwargs: Additional keyword arguments,
excluding the learning rate, to pass to the optimizer
:param n_critics: Number of critic networks to create.
:param share_features_extractor: Whether to share or not the features extractor
between the actor and the critic (this saves computation time)
"""
def __init__(
@ -395,6 +411,7 @@ class CnnPolicy(SACPolicy):
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
n_critics: int = 2,
share_features_extractor: bool = True,
):
super(CnnPolicy, self).__init__(
observation_space,
@ -413,6 +430,7 @@ class CnnPolicy(SACPolicy):
optimizer_class,
optimizer_kwargs,
n_critics,
share_features_extractor,
)

View file

@ -92,7 +92,7 @@ class TD3Policy(BasePolicy):
:param activation_fn: Activation function
:param features_extractor_class: Features extractor to use.
:param features_extractor_kwargs: Keyword arguments
to pass to the feature extractor.
to pass to the features extractor.
:param normalize_images: Whether to normalize images or not,
dividing by 255.0 (True by default)
:param optimizer_class: The optimizer to use,
@ -100,6 +100,8 @@ class TD3Policy(BasePolicy):
:param optimizer_kwargs: Additional keyword arguments,
excluding the learning rate, to pass to the optimizer
:param n_critics: Number of critic networks to create.
:param share_features_extractor: Whether to share or not the features extractor
between the actor and the critic (this saves computation time)
"""
def __init__(
@ -115,6 +117,7 @@ class TD3Policy(BasePolicy):
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
n_critics: int = 2,
share_features_extractor: bool = True,
):
super(TD3Policy, self).__init__(
observation_space,
@ -135,35 +138,54 @@ class TD3Policy(BasePolicy):
actor_arch, critic_arch = get_actor_critic_arch(net_arch)
self.features_extractor = features_extractor_class(self.observation_space, **self.features_extractor_kwargs)
self.features_dim = self.features_extractor.features_dim
self.net_arch = net_arch
self.activation_fn = activation_fn
self.net_args = {
"observation_space": self.observation_space,
"action_space": self.action_space,
"features_extractor": self.features_extractor,
"features_dim": self.features_dim,
"net_arch": actor_arch,
"activation_fn": self.activation_fn,
"normalize_images": normalize_images,
}
self.actor_kwargs = self.net_args.copy()
self.critic_kwargs = self.net_args.copy()
self.critic_kwargs.update({"n_critics": n_critics, "net_arch": critic_arch})
self.critic_kwargs.update(
{
"n_critics": n_critics,
"net_arch": critic_arch,
"share_features_extractor": share_features_extractor,
}
)
self.actor, self.actor_target = None, None
self.critic, self.critic_target = None, None
self.share_features_extractor = share_features_extractor
self._build(lr_schedule)
def _build(self, lr_schedule: Callable) -> None:
self.actor = self.make_actor()
self.actor_target = self.make_actor()
# Create actor and target
# the features extractor should not be shared
self.actor = self.make_actor(features_extractor=None)
self.actor_target = self.make_actor(features_extractor=None)
# Initialize the target to have the same weights as the actor
self.actor_target.load_state_dict(self.actor.state_dict())
self.actor.optimizer = self.optimizer_class(self.actor.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
self.critic = self.make_critic()
self.critic_target = self.make_critic()
if self.share_features_extractor:
self.critic = self.make_critic(features_extractor=self.actor.features_extractor)
# Critic target should not share the features extactor with critic
# but it can share it with the actor target as actor and critic are sharing
# the same features_extractor too
# NOTE: as a result the effective poliak (soft-copy) coefficient for the features extractor
# will be 2 * tau instead of tau (updated one time with the actor, a second time with the critic)
self.critic_target = self.make_critic(features_extractor=self.actor_target.features_extractor)
else:
# Create new features extractor for each network
self.critic = self.make_critic(features_extractor=None)
self.critic_target = self.make_critic(features_extractor=None)
self.critic_target.load_state_dict(self.critic.state_dict())
self.critic.optimizer = self.optimizer_class(self.critic.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
@ -180,15 +202,18 @@ class TD3Policy(BasePolicy):
optimizer_kwargs=self.optimizer_kwargs,
features_extractor_class=self.features_extractor_class,
features_extractor_kwargs=self.features_extractor_kwargs,
share_features_extractor=self.share_features_extractor,
)
)
return data
def make_actor(self) -> Actor:
return Actor(**self.actor_kwargs).to(self.device)
def make_actor(self, features_extractor: Optional[BaseFeaturesExtractor] = None) -> Actor:
actor_kwargs = self._update_features_extractor(self.actor_kwargs, features_extractor)
return Actor(**actor_kwargs).to(self.device)
def make_critic(self) -> ContinuousCritic:
return ContinuousCritic(**self.critic_kwargs).to(self.device)
def make_critic(self, features_extractor: Optional[BaseFeaturesExtractor] = None) -> ContinuousCritic:
critic_kwargs = self._update_features_extractor(self.critic_kwargs, features_extractor)
return ContinuousCritic(**critic_kwargs).to(self.device)
def forward(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
return self._predict(observation, deterministic=deterministic)
@ -211,7 +236,7 @@ class CnnPolicy(TD3Policy):
:param activation_fn: Activation function
:param features_extractor_class: Features extractor to use.
:param features_extractor_kwargs: Keyword arguments
to pass to the feature extractor.
to pass to the features extractor.
:param normalize_images: Whether to normalize images or not,
dividing by 255.0 (True by default)
:param optimizer_class: The optimizer to use,
@ -219,6 +244,8 @@ class CnnPolicy(TD3Policy):
:param optimizer_kwargs: Additional keyword arguments,
excluding the learning rate, to pass to the optimizer
:param n_critics: Number of critic networks to create.
:param share_features_extractor: Whether to share or not the features extractor
between the actor and the critic (this saves computation time)
"""
def __init__(
@ -234,6 +261,7 @@ class CnnPolicy(TD3Policy):
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
n_critics: int = 2,
share_features_extractor: bool = True,
):
super(CnnPolicy, self).__init__(
observation_space,
@ -247,6 +275,7 @@ class CnnPolicy(TD3Policy):
optimizer_class,
optimizer_kwargs,
n_critics,
share_features_extractor,
)

View file

@ -1,10 +1,13 @@
import os
from copy import deepcopy
import numpy as np
import pytest
import torch as th
from stable_baselines3 import A2C, DQN, PPO, SAC, TD3
from stable_baselines3.common.identity_env import FakeImageEnv
from stable_baselines3.common.utils import zip_strict
@pytest.mark.parametrize("model_class", [A2C, PPO, SAC, TD3, DQN])
@ -35,3 +38,139 @@ def test_cnn(tmp_path, model_class):
assert np.allclose(action, model.predict(obs, deterministic=True)[0])
os.remove(str(tmp_path / SAVE_NAME))
def patch_dqn_names_(model):
# Small hack to make the test work with DQN
if isinstance(model, DQN):
model.critic = model.q_net
model.critic_target = model.q_net_target
def params_should_match(params, other_params):
for param, other_param in zip_strict(params, other_params):
assert th.allclose(param, other_param)
def params_should_differ(params, other_params):
for param, other_param in zip_strict(params, other_params):
assert not th.allclose(param, other_param)
def check_td3_feature_extractor_match(model):
for (key, actor_param), critic_param in zip(model.actor_target.named_parameters(), model.critic_target.parameters()):
if "features_extractor" in key:
assert th.allclose(actor_param, critic_param), key
def check_td3_feature_extractor_differ(model):
for (key, actor_param), critic_param in zip(model.actor_target.named_parameters(), model.critic_target.parameters()):
if "features_extractor" in key:
assert not th.allclose(actor_param, critic_param), key
@pytest.mark.parametrize("model_class", [SAC, TD3, DQN])
@pytest.mark.parametrize("share_features_extractor", [True, False])
def test_features_extractor_target_net(model_class, share_features_extractor):
if model_class == DQN and share_features_extractor:
pytest.skip()
env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=1, discrete=model_class not in {SAC, TD3})
# Avoid memory error when using replay buffer
# Reduce the size of the features
kwargs = dict(buffer_size=250, learning_starts=100, policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)))
if model_class != DQN:
kwargs["policy_kwargs"]["share_features_extractor"] = share_features_extractor
model = model_class("CnnPolicy", env, seed=0, **kwargs)
patch_dqn_names_(model)
if share_features_extractor:
# Check that the objects are the same and not just copied
assert id(model.policy.actor.features_extractor) == id(model.policy.critic.features_extractor)
if model_class == TD3:
assert id(model.policy.actor_target.features_extractor) == id(model.policy.critic_target.features_extractor)
# Actor and critic feature extractor should be the same
td3_features_extractor_check = check_td3_feature_extractor_match
else:
# Actor and critic feature extractor should differ same
td3_features_extractor_check = check_td3_feature_extractor_differ
# Check that the object differ
if model_class != DQN:
assert id(model.policy.actor.features_extractor) != id(model.policy.critic.features_extractor)
if model_class == TD3:
assert id(model.policy.actor_target.features_extractor) != id(model.policy.critic_target.features_extractor)
# Critic and target should be equal at the begginning of training
params_should_match(model.critic.parameters(), model.critic_target.parameters())
# TD3 has also a target actor net
if model_class == TD3:
params_should_match(model.actor.parameters(), model.actor_target.parameters())
model.learn(200)
# Critic and target should differ
params_should_differ(model.critic.parameters(), model.critic_target.parameters())
if model_class == TD3:
params_should_differ(model.actor.parameters(), model.actor_target.parameters())
td3_features_extractor_check(model)
# Re-initialize and collect some random data (without doing gradient steps,
# since 10 < learning_starts = 100)
model = model_class("CnnPolicy", env, seed=0, **kwargs).learn(10)
patch_dqn_names_(model)
original_param = deepcopy(list(model.critic.parameters()))
original_target_param = deepcopy(list(model.critic_target.parameters()))
if model_class == TD3:
original_actor_target_param = deepcopy(list(model.actor_target.parameters()))
# Deactivate copy to target
model.tau = 0.0
model.train(gradient_steps=1)
# Target should be the same
params_should_match(original_target_param, model.critic_target.parameters())
if model_class == TD3:
params_should_match(original_actor_target_param, model.actor_target.parameters())
td3_features_extractor_check(model)
# not the same for critic net (updated by gradient descent)
params_should_differ(original_param, model.critic.parameters())
# Update the reference as it should not change in the next step
original_param = deepcopy(list(model.critic.parameters()))
if model_class == TD3:
original_actor_param = deepcopy(list(model.actor.parameters()))
# Deactivate learning rate
model.lr_schedule = lambda _: 0.0
# Re-activate polyak update
model.tau = 0.01
# Special case for DQN: target net is updated in the `collect_rollout()`
# not the `train()` method
if model_class == DQN:
model.target_update_interval = 1
model._on_step()
model.train(gradient_steps=1)
# Target should have changed now (due to polyak update)
params_should_differ(original_target_param, model.critic_target.parameters())
# Critic should be the same
params_should_match(original_param, model.critic.parameters())
if model_class == TD3:
params_should_differ(original_actor_target_param, model.actor_target.parameters())
params_should_match(original_actor_param, model.actor.parameters())
td3_features_extractor_check(model)

View file

@ -20,9 +20,10 @@ def test_deterministic_pg(model_class, action_noise):
learning_starts=100,
verbose=1,
create_eval_env=True,
buffer_size=250,
action_noise=action_noise,
)
model.learn(total_timesteps=1000, eval_freq=500)
model.learn(total_timesteps=300, eval_freq=250)
@pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v0"])
@ -50,6 +51,7 @@ def test_ppo(env_id, clip_range_vf):
model = PPO(
"MlpPolicy",
env_id,
n_steps=512,
seed=0,
policy_kwargs=dict(net_arch=[16]),
verbose=1,
@ -68,19 +70,25 @@ def test_sac(ent_coef):
learning_starts=100,
verbose=1,
create_eval_env=True,
buffer_size=250,
ent_coef=ent_coef,
action_noise=NormalActionNoise(np.zeros(1), np.zeros(1)),
)
model.learn(total_timesteps=1000, eval_freq=500)
model.learn(total_timesteps=300, eval_freq=250)
@pytest.mark.parametrize("n_critics", [1, 3])
def test_n_critics(n_critics):
# Test SAC with different number of critics, for TD3, n_critics=1 corresponds to DDPG
model = SAC(
"MlpPolicy", "Pendulum-v0", policy_kwargs=dict(net_arch=[64, 64], n_critics=n_critics), learning_starts=100, verbose=1
"MlpPolicy",
"Pendulum-v0",
policy_kwargs=dict(net_arch=[64, 64], n_critics=n_critics),
learning_starts=100,
buffer_size=10000,
verbose=1,
)
model.learn(total_timesteps=500)
model.learn(total_timesteps=300)
def test_dqn():

View file

@ -303,21 +303,23 @@ def test_save_load_policy(tmp_path, model_class, policy_str):
:param model_class: (BaseAlgorithm) A RL model
:param policy_str: (str) Name of the policy.
"""
kwargs = {}
kwargs = dict(policy_kwargs=dict(net_arch=[16]))
if policy_str == "MlpPolicy":
env = select_env(model_class)
else:
if model_class in [SAC, TD3, DQN, DDPG]:
# Avoid memory error when using replay buffer
# Reduce the size of the features
kwargs = dict(buffer_size=250)
kwargs = dict(
buffer_size=250, learning_starts=100, policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32))
)
env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=2, discrete=model_class == DQN)
env = DummyVecEnv([lambda: env])
# create model
model = model_class(policy_str, env, policy_kwargs=dict(net_arch=[16]), verbose=1, **kwargs)
model.learn(total_timesteps=500)
model = model_class(policy_str, env, verbose=1, **kwargs)
model.learn(total_timesteps=300)
env.reset()
observations = np.concatenate([env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0)

View file

@ -195,4 +195,4 @@ def test_zip_strict():
def test_cmd_util_rename():
"""Test that importing cmd_util still works but raises warning"""
with pytest.warns(FutureWarning):
from stable_baselines3.common.cmd_util import make_vec_env
from stable_baselines3.common.cmd_util import make_vec_env # noqa: F401