From 0fc0dd1b21877fd6aef961cd2f48926b8cd38354 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Tue, 27 Oct 2020 14:24:59 +0100 Subject: [PATCH] Fix off policy features extractor (#198) * Faster tests * Fix feature extractor bug + add check * Add missing check * Allow TD3 features extractor to be separate * Add share features extractor option for SAC * Bug fixes * Apply suggestions from code review Co-authored-by: Adam Gleave Co-authored-by: Adam Gleave --- docs/misc/changelog.rst | 3 + stable_baselines3/common/cmd_util.py | 2 +- stable_baselines3/common/distributions.py | 2 +- stable_baselines3/common/policies.py | 43 +++++-- stable_baselines3/common/torch_layers.py | 2 +- stable_baselines3/dqn/policies.py | 9 +- stable_baselines3/sac/policies.py | 56 ++++++--- stable_baselines3/td3/policies.py | 61 +++++++--- tests/test_cnn.py | 139 ++++++++++++++++++++++ tests/test_run.py | 16 ++- tests/test_save_load.py | 10 +- tests/test_utils.py | 2 +- 12 files changed, 286 insertions(+), 59 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 09957f4..965c9da 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -17,6 +17,7 @@ New Features: - Added Hindsight Experience Replay ``HER``. (@megan-klaiber) - ``VecNormalize`` now supports ``gym.spaces.Dict`` observation spaces - Support logging videos to Tensorboard (@SwamyDev) +- Added ``share_features_extractor`` argument to ``SAC`` and ``TD3`` policies Bug Fixes: ^^^^^^^^^^ @@ -27,6 +28,8 @@ Bug Fixes: - Fix model creation initializing CUDA even when `device="cpu"` is provided - Fix ``check_env`` not checking if the env has a Dict actionspace before calling ``_check_nan`` (@wmmc88) - Update the check for spaces unsupported by Stable Baselines 3 to include checks on the action space (@wmmc88) +- Fixed feature extractor bug for target network where the same net was shared instead + of being separate. 
This bug affects ``SAC``, ``DDPG`` and ``TD3`` when using ``CnnPolicy`` (or custom feature extractor) Deprecations: ^^^^^^^^^^^^^ diff --git a/stable_baselines3/common/cmd_util.py b/stable_baselines3/common/cmd_util.py index e5a30e5..dea9e79 100644 --- a/stable_baselines3/common/cmd_util.py +++ b/stable_baselines3/common/cmd_util.py @@ -1,6 +1,6 @@ import warnings -from stable_baselines3.common.env_util import * # noqa: F403 +from stable_baselines3.common.env_util import * # noqa: F403,F401 warnings.warn( "Module ``common.cmd_util`` has been renamed to ``common.env_util`` and will be removed in the future.", FutureWarning diff --git a/stable_baselines3/common/distributions.py b/stable_baselines3/common/distributions.py index 0f43ca8..7a3506c 100644 --- a/stable_baselines3/common/distributions.py +++ b/stable_baselines3/common/distributions.py @@ -505,7 +505,7 @@ class StateDependentNoiseDistribution(Distribution): :param latent_dim: Dimension of the last layer of the policy (before the action layer) :param log_std_init: Initial value for the log standard deviation - :param latent_sde_dim: Dimension of the last layer of the feature extractor + :param latent_sde_dim: Dimension of the last layer of the features extractor for gSDE. By default, it is shared with the policy network. :return: """ diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py index ef7ed05..d07f115 100644 --- a/stable_baselines3/common/policies.py +++ b/stable_baselines3/common/policies.py @@ -37,7 +37,7 @@ class BaseModel(nn.Module, ABC): :param action_space: The action space of the environment :param features_extractor_class: Features extractor to use. :param features_extractor_kwargs: Keyword arguments - to pass to the feature extractor. + to pass to the features extractor. 
:param features_extractor: Network to extract features (a CNN when using images, a nn.Flatten() layer otherwise) :param normalize_images: Whether to normalize images or not, @@ -83,6 +83,30 @@ class BaseModel(nn.Module, ABC): def forward(self, *args, **kwargs): del args, kwargs + def _update_features_extractor( + self, net_kwargs: Dict[str, Any], features_extractor: Optional[BaseFeaturesExtractor] = None + ) -> Dict[str, Any]: + """ + Update the network keyword arguments and create a new features extractor object if needed. + If a ``features_extractor`` object is passed, then it will be shared. + + :param net_kwargs: the base network keyword arguments, without the ones + related to features extractor + :param features_extractor: a features extractor object. + If None, a new object will be created. + :return: The updated keyword arguments + """ + net_kwargs = net_kwargs.copy() + if features_extractor is None: + # The features extractor is not shared, create a new one + features_extractor = self.make_features_extractor() + net_kwargs.update(dict(features_extractor=features_extractor, features_dim=features_extractor.features_dim)) + return net_kwargs + + def make_features_extractor(self) -> BaseFeaturesExtractor: + """ Helper method to create a features extractor.""" + return self.features_extractor_class(self.observation_space, **self.features_extractor_kwargs) + def extract_features(self, obs: th.Tensor) -> th.Tensor: """ Preprocess the observation if needed and extract features. @@ -90,7 +114,7 @@ class BaseModel(nn.Module, ABC): :param obs: :return: """ - assert self.features_extractor is not None, "No feature extractor was set" + assert self.features_extractor is not None, "No features extractor was set" preprocessed_obs = preprocess_obs(obs, self.observation_space, normalize_images=self.normalize_images) return self.features_extractor(preprocessed_obs) @@ -327,7 +351,7 @@ class ActorCriticPolicy(BasePolicy): this allows to ensure boundaries when using gSDE. 
:param features_extractor_class: Features extractor to use. :param features_extractor_kwargs: Keyword arguments - to pass to the feature extractor. + to pass to the features extractor. :param normalize_images: Whether to normalize images or not, dividing by 255.0 (True by default) :param optimizer_class: The optimizer to use, @@ -466,7 +490,7 @@ class ActorCriticPolicy(BasePolicy): latent_dim_pi = self.mlp_extractor.latent_dim_pi - # Separate feature extractor for gSDE + # Separate features extractor for gSDE if self.sde_net_arch is not None: self.sde_features_extractor, latent_sde_dim = create_sde_features_extractor( self.features_dim, self.sde_net_arch, self.activation_fn @@ -496,7 +520,7 @@ class ActorCriticPolicy(BasePolicy): if self.ortho_init: # TODO: check for features_extractor # Values from stable-baselines. - # feature_extractor/mlp values are + # features_extractor/mlp values are # originally from openai/baselines (default gains/init_scales). module_gains = { self.features_extractor: np.sqrt(2), @@ -625,7 +649,7 @@ class ActorCriticCnnPolicy(ActorCriticPolicy): this allows to ensure boundaries when using gSDE. :param features_extractor_class: Features extractor to use. :param features_extractor_kwargs: Keyword arguments - to pass to the feature extractor. + to pass to the features extractor. :param normalize_images: Whether to normalize images or not, dividing by 255.0 (True by default) :param optimizer_class: The optimizer to use, @@ -698,6 +722,8 @@ class ContinuousCritic(BaseModel): :param normalize_images: Whether to normalize images or not, dividing by 255.0 (True by default) :param n_critics: Number of critic networks to create. 
+ :param share_features_extractor: Whether the features extractor is shared or not + between the actor and the critic (this saves computation time) """ def __init__( @@ -710,6 +736,7 @@ class ContinuousCritic(BaseModel): activation_fn: Type[nn.Module] = nn.ReLU, normalize_images: bool = True, n_critics: int = 2, + share_features_extractor: bool = True, ): super().__init__( observation_space, @@ -720,6 +747,7 @@ class ContinuousCritic(BaseModel): action_dim = get_action_dim(self.action_space) + self.share_features_extractor = share_features_extractor self.n_critics = n_critics self.q_networks = [] for idx in range(n_critics): @@ -730,7 +758,8 @@ class ContinuousCritic(BaseModel): def forward(self, obs: th.Tensor, actions: th.Tensor) -> Tuple[th.Tensor, ...]: # Learn the features extractor using the policy loss only - with th.no_grad(): + # when the features_extractor is shared with the actor + with th.set_grad_enabled(not self.share_features_extractor): features = self.extract_features(obs) qvalue_input = th.cat([features, actions], dim=1) return tuple(q_net(qvalue_input) for q_net in self.q_networks) diff --git a/stable_baselines3/common/torch_layers.py b/stable_baselines3/common/torch_layers.py index 62539d8..e73a684 100644 --- a/stable_baselines3/common/torch_layers.py +++ b/stable_baselines3/common/torch_layers.py @@ -244,7 +244,7 @@ def get_actor_critic_arch(net_arch: Union[List[int], Dict[str, List[int]]]) -> T then you can specify ``net_arch=dict(qf=[400, 300], pi=[64, 64])``. .. note:: - Compared to their on-policy counterparts, no shared layers (other than the feature extractor) + Compared to their on-policy counterparts, no shared layers (other than the features extractor) between the actor and the critic are allowed (to prevent issues with target networks). :param net_arch: The specification of the actor and critic networks. 
diff --git a/stable_baselines3/dqn/policies.py b/stable_baselines3/dqn/policies.py index 890d3ed..7416f02 100644 --- a/stable_baselines3/dqn/policies.py +++ b/stable_baselines3/dqn/policies.py @@ -90,7 +90,7 @@ class DQNPolicy(BasePolicy): :param activation_fn: Activation function :param features_extractor_class: Features extractor to use. :param features_extractor_kwargs: Keyword arguments - to pass to the feature extractor. + to pass to the features extractor. :param normalize_images: Whether to normalize images or not, dividing by 255.0 (True by default) :param optimizer_class: The optimizer to use, @@ -158,10 +158,9 @@ class DQNPolicy(BasePolicy): self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs) def make_q_net(self) -> QNetwork: - # Make sure we always have separate networks for feature extractors etc - features_extractor = self.features_extractor_class(self.observation_space, **self.features_extractor_kwargs) - features_dim = features_extractor.features_dim - return QNetwork(features_extractor=features_extractor, features_dim=features_dim, **self.net_args).to(self.device) + # Make sure we always have separate networks for features extractors etc + net_args = self._update_features_extractor(self.net_args, features_extractor=None) + return QNetwork(**net_args).to(self.device) def forward(self, obs: th.Tensor, deterministic: bool = True) -> th.Tensor: return self._predict(obs, deterministic=deterministic) diff --git a/stable_baselines3/sac/policies.py b/stable_baselines3/sac/policies.py index 744b0d0..c79d679 100644 --- a/stable_baselines3/sac/policies.py +++ b/stable_baselines3/sac/policies.py @@ -90,7 +90,7 @@ class Actor(BasePolicy): if self.use_sde: latent_sde_dim = last_layer_dim - # Separate feature extractor for gSDE + # Separate features extractor for gSDE if sde_net_arch is not None: self.sde_features_extractor, latent_sde_dim = create_sde_features_extractor( features_dim, sde_net_arch, activation_fn @@ 
-211,7 +211,7 @@ class SACPolicy(BasePolicy): :param clip_mean: Clip the mean output when using gSDE to avoid numerical instability. :param features_extractor_class: Features extractor to use. :param features_extractor_kwargs: Keyword arguments - to pass to the feature extractor. + to pass to the features extractor. :param normalize_images: Whether to normalize images or not, dividing by 255.0 (True by default) :param optimizer_class: The optimizer to use, @@ -219,6 +219,8 @@ class SACPolicy(BasePolicy): :param optimizer_kwargs: Additional keyword arguments, excluding the learning rate, to pass to the optimizer :param n_critics: Number of critic networks to create. + :param share_features_extractor: Whether to share or not the features extractor + between the actor and the critic (this saves computation time) """ def __init__( @@ -239,6 +241,7 @@ class SACPolicy(BasePolicy): optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, optimizer_kwargs: Optional[Dict[str, Any]] = None, n_critics: int = 2, + share_features_extractor: bool = True, ): super(SACPolicy, self).__init__( observation_space, @@ -258,17 +261,11 @@ class SACPolicy(BasePolicy): actor_arch, critic_arch = get_actor_critic_arch(net_arch) - # Create shared features extractor - self.features_extractor = features_extractor_class(self.observation_space, **self.features_extractor_kwargs) - self.features_dim = self.features_extractor.features_dim - self.net_arch = net_arch self.activation_fn = activation_fn self.net_args = { "observation_space": self.observation_space, "action_space": self.action_space, - "features_extractor": self.features_extractor, - "features_dim": self.features_dim, "net_arch": actor_arch, "activation_fn": self.activation_fn, "normalize_images": normalize_images, @@ -283,10 +280,17 @@ class SACPolicy(BasePolicy): } self.actor_kwargs.update(sde_kwargs) self.critic_kwargs = self.net_args.copy() - self.critic_kwargs.update({"n_critics": n_critics, "net_arch": critic_arch}) + 
self.critic_kwargs.update( + { + "n_critics": n_critics, + "net_arch": critic_arch, + "share_features_extractor": share_features_extractor, + } + ) self.actor, self.actor_target = None, None self.critic, self.critic_target = None, None + self.share_features_extractor = share_features_extractor self._build(lr_schedule) @@ -294,13 +298,21 @@ class SACPolicy(BasePolicy): self.actor = self.make_actor() self.actor.optimizer = self.optimizer_class(self.actor.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs) - self.critic = self.make_critic() - self.critic_target = self.make_critic() + if self.share_features_extractor: + self.critic = self.make_critic(features_extractor=self.actor.features_extractor) + # Do not optimize the shared features extractor with the critic loss + # otherwise, there are gradient computation issues + critic_parameters = [param for name, param in self.critic.named_parameters() if "features_extractor" not in name] + else: + # Create a separate features extractor for the critic + # this requires more memory and computation + self.critic = self.make_critic(features_extractor=None) + critic_parameters = self.critic.parameters() + + # Critic target should not share the features extractor with critic + self.critic_target = self.make_critic(features_extractor=None) self.critic_target.load_state_dict(self.critic.state_dict()) - # Do not optimize the shared feature extractor with the critic loss - # otherwise, there are gradient computation issues - # Another solution: having duplicated features extractor but requires more memory and computation - critic_parameters = [param for name, param in self.critic.named_parameters() if "features_extractor" not in name] + self.critic.optimizer = self.optimizer_class(critic_parameters, lr=lr_schedule(1), **self.optimizer_kwargs) def _get_data(self) -> Dict[str, Any]: @@ -333,11 +345,13 @@ class SACPolicy(BasePolicy): """ self.actor.reset_noise(batch_size=batch_size) - def make_actor(self) -> Actor: - return 
Actor(**self.actor_kwargs).to(self.device) + def make_actor(self, features_extractor: Optional[BaseFeaturesExtractor] = None) -> Actor: + actor_kwargs = self._update_features_extractor(self.actor_kwargs, features_extractor) + return Actor(**actor_kwargs).to(self.device) - def make_critic(self) -> ContinuousCritic: - return ContinuousCritic(**self.critic_kwargs).to(self.device) + def make_critic(self, features_extractor: Optional[BaseFeaturesExtractor] = None) -> ContinuousCritic: + critic_kwargs = self._update_features_extractor(self.critic_kwargs, features_extractor) + return ContinuousCritic(**critic_kwargs).to(self.device) def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor: return self._predict(obs, deterministic=deterministic) @@ -375,6 +389,8 @@ class CnnPolicy(SACPolicy): :param optimizer_kwargs: Additional keyword arguments, excluding the learning rate, to pass to the optimizer :param n_critics: Number of critic networks to create. + :param share_features_extractor: Whether to share or not the features extractor + between the actor and the critic (this saves computation time) """ def __init__( @@ -395,6 +411,7 @@ class CnnPolicy(SACPolicy): optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, optimizer_kwargs: Optional[Dict[str, Any]] = None, n_critics: int = 2, + share_features_extractor: bool = True, ): super(CnnPolicy, self).__init__( observation_space, @@ -413,6 +430,7 @@ class CnnPolicy(SACPolicy): optimizer_class, optimizer_kwargs, n_critics, + share_features_extractor, ) diff --git a/stable_baselines3/td3/policies.py b/stable_baselines3/td3/policies.py index a4d5cea..0083a49 100644 --- a/stable_baselines3/td3/policies.py +++ b/stable_baselines3/td3/policies.py @@ -92,7 +92,7 @@ class TD3Policy(BasePolicy): :param activation_fn: Activation function :param features_extractor_class: Features extractor to use. :param features_extractor_kwargs: Keyword arguments - to pass to the feature extractor. 
+ to pass to the features extractor. :param normalize_images: Whether to normalize images or not, dividing by 255.0 (True by default) :param optimizer_class: The optimizer to use, @@ -100,6 +100,8 @@ class TD3Policy(BasePolicy): :param optimizer_kwargs: Additional keyword arguments, excluding the learning rate, to pass to the optimizer :param n_critics: Number of critic networks to create. + :param share_features_extractor: Whether to share or not the features extractor + between the actor and the critic (this saves computation time) """ def __init__( @@ -115,6 +117,7 @@ class TD3Policy(BasePolicy): optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, optimizer_kwargs: Optional[Dict[str, Any]] = None, n_critics: int = 2, + share_features_extractor: bool = True, ): super(TD3Policy, self).__init__( observation_space, @@ -135,35 +138,54 @@ class TD3Policy(BasePolicy): actor_arch, critic_arch = get_actor_critic_arch(net_arch) - self.features_extractor = features_extractor_class(self.observation_space, **self.features_extractor_kwargs) - self.features_dim = self.features_extractor.features_dim - self.net_arch = net_arch self.activation_fn = activation_fn self.net_args = { "observation_space": self.observation_space, "action_space": self.action_space, - "features_extractor": self.features_extractor, - "features_dim": self.features_dim, "net_arch": actor_arch, "activation_fn": self.activation_fn, "normalize_images": normalize_images, } self.actor_kwargs = self.net_args.copy() self.critic_kwargs = self.net_args.copy() - self.critic_kwargs.update({"n_critics": n_critics, "net_arch": critic_arch}) + self.critic_kwargs.update( + { + "n_critics": n_critics, + "net_arch": critic_arch, + "share_features_extractor": share_features_extractor, + } + ) + self.actor, self.actor_target = None, None self.critic, self.critic_target = None, None + self.share_features_extractor = share_features_extractor self._build(lr_schedule) def _build(self, lr_schedule: Callable) -> None: - 
self.actor = self.make_actor() - self.actor_target = self.make_actor() + # Create actor and target + # the features extractor should not be shared + self.actor = self.make_actor(features_extractor=None) + self.actor_target = self.make_actor(features_extractor=None) + # Initialize the target to have the same weights as the actor self.actor_target.load_state_dict(self.actor.state_dict()) + self.actor.optimizer = self.optimizer_class(self.actor.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs) - self.critic = self.make_critic() - self.critic_target = self.make_critic() + + if self.share_features_extractor: + self.critic = self.make_critic(features_extractor=self.actor.features_extractor) + # Critic target should not share the features extractor with critic + # but it can share it with the actor target as actor and critic are sharing + # the same features_extractor too + # NOTE: as a result the effective polyak (soft-copy) coefficient for the features extractor + # will be 2 * tau instead of tau (updated one time with the actor, a second time with the critic) + self.critic_target = self.make_critic(features_extractor=self.actor_target.features_extractor) + else: + # Create new features extractor for each network + self.critic = self.make_critic(features_extractor=None) + self.critic_target = self.make_critic(features_extractor=None) + self.critic_target.load_state_dict(self.critic.state_dict()) self.critic.optimizer = self.optimizer_class(self.critic.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs) @@ -180,15 +202,18 @@ class TD3Policy(BasePolicy): optimizer_kwargs=self.optimizer_kwargs, features_extractor_class=self.features_extractor_class, features_extractor_kwargs=self.features_extractor_kwargs, + share_features_extractor=self.share_features_extractor, ) ) return data - def make_actor(self) -> Actor: - return Actor(**self.actor_kwargs).to(self.device) + def make_actor(self, features_extractor: Optional[BaseFeaturesExtractor] = None) -> Actor: +
actor_kwargs = self._update_features_extractor(self.actor_kwargs, features_extractor) + return Actor(**actor_kwargs).to(self.device) - def make_critic(self) -> ContinuousCritic: - return ContinuousCritic(**self.critic_kwargs).to(self.device) + def make_critic(self, features_extractor: Optional[BaseFeaturesExtractor] = None) -> ContinuousCritic: + critic_kwargs = self._update_features_extractor(self.critic_kwargs, features_extractor) + return ContinuousCritic(**critic_kwargs).to(self.device) def forward(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor: return self._predict(observation, deterministic=deterministic) @@ -211,7 +236,7 @@ class CnnPolicy(TD3Policy): :param activation_fn: Activation function :param features_extractor_class: Features extractor to use. :param features_extractor_kwargs: Keyword arguments - to pass to the feature extractor. + to pass to the features extractor. :param normalize_images: Whether to normalize images or not, dividing by 255.0 (True by default) :param optimizer_class: The optimizer to use, @@ -219,6 +244,8 @@ class CnnPolicy(TD3Policy): :param optimizer_kwargs: Additional keyword arguments, excluding the learning rate, to pass to the optimizer :param n_critics: Number of critic networks to create. 
+ :param share_features_extractor: Whether to share or not the features extractor + between the actor and the critic (this saves computation time) """ def __init__( @@ -234,6 +261,7 @@ class CnnPolicy(TD3Policy): optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, optimizer_kwargs: Optional[Dict[str, Any]] = None, n_critics: int = 2, + share_features_extractor: bool = True, ): super(CnnPolicy, self).__init__( observation_space, @@ -247,6 +275,7 @@ class CnnPolicy(TD3Policy): optimizer_class, optimizer_kwargs, n_critics, + share_features_extractor, ) diff --git a/tests/test_cnn.py b/tests/test_cnn.py index 58c80b6..1c1ef35 100644 --- a/tests/test_cnn.py +++ b/tests/test_cnn.py @@ -1,10 +1,13 @@ import os +from copy import deepcopy import numpy as np import pytest +import torch as th from stable_baselines3 import A2C, DQN, PPO, SAC, TD3 from stable_baselines3.common.identity_env import FakeImageEnv +from stable_baselines3.common.utils import zip_strict @pytest.mark.parametrize("model_class", [A2C, PPO, SAC, TD3, DQN]) @@ -35,3 +38,139 @@ def test_cnn(tmp_path, model_class): assert np.allclose(action, model.predict(obs, deterministic=True)[0]) os.remove(str(tmp_path / SAVE_NAME)) + + +def patch_dqn_names_(model): + # Small hack to make the test work with DQN + if isinstance(model, DQN): + model.critic = model.q_net + model.critic_target = model.q_net_target + + +def params_should_match(params, other_params): + for param, other_param in zip_strict(params, other_params): + assert th.allclose(param, other_param) + + +def params_should_differ(params, other_params): + for param, other_param in zip_strict(params, other_params): + assert not th.allclose(param, other_param) + + +def check_td3_feature_extractor_match(model): + for (key, actor_param), critic_param in zip(model.actor_target.named_parameters(), model.critic_target.parameters()): + if "features_extractor" in key: + assert th.allclose(actor_param, critic_param), key + + +def 
check_td3_feature_extractor_differ(model): + for (key, actor_param), critic_param in zip(model.actor_target.named_parameters(), model.critic_target.parameters()): + if "features_extractor" in key: + assert not th.allclose(actor_param, critic_param), key + + +@pytest.mark.parametrize("model_class", [SAC, TD3, DQN]) +@pytest.mark.parametrize("share_features_extractor", [True, False]) +def test_features_extractor_target_net(model_class, share_features_extractor): + if model_class == DQN and share_features_extractor: + pytest.skip() + + env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=1, discrete=model_class not in {SAC, TD3}) + # Avoid memory error when using replay buffer + # Reduce the size of the features + kwargs = dict(buffer_size=250, learning_starts=100, policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32))) + if model_class != DQN: + kwargs["policy_kwargs"]["share_features_extractor"] = share_features_extractor + + model = model_class("CnnPolicy", env, seed=0, **kwargs) + + patch_dqn_names_(model) + + if share_features_extractor: + # Check that the objects are the same and not just copied + assert id(model.policy.actor.features_extractor) == id(model.policy.critic.features_extractor) + if model_class == TD3: + assert id(model.policy.actor_target.features_extractor) == id(model.policy.critic_target.features_extractor) + # Actor and critic feature extractor should be the same + td3_features_extractor_check = check_td3_feature_extractor_match + else: + # Actor and critic feature extractor should differ + td3_features_extractor_check = check_td3_feature_extractor_differ + # Check that the objects differ + if model_class != DQN: + assert id(model.policy.actor.features_extractor) != id(model.policy.critic.features_extractor) + + if model_class == TD3: + assert id(model.policy.actor_target.features_extractor) != id(model.policy.critic_target.features_extractor) + + # Critic and target should be equal at the beginning of training +
params_should_match(model.critic.parameters(), model.critic_target.parameters()) + + # TD3 has also a target actor net + if model_class == TD3: + params_should_match(model.actor.parameters(), model.actor_target.parameters()) + + model.learn(200) + + # Critic and target should differ + params_should_differ(model.critic.parameters(), model.critic_target.parameters()) + + if model_class == TD3: + params_should_differ(model.actor.parameters(), model.actor_target.parameters()) + td3_features_extractor_check(model) + + # Re-initialize and collect some random data (without doing gradient steps, + # since 10 < learning_starts = 100) + model = model_class("CnnPolicy", env, seed=0, **kwargs).learn(10) + + patch_dqn_names_(model) + + original_param = deepcopy(list(model.critic.parameters())) + original_target_param = deepcopy(list(model.critic_target.parameters())) + if model_class == TD3: + original_actor_target_param = deepcopy(list(model.actor_target.parameters())) + + # Deactivate copy to target + model.tau = 0.0 + model.train(gradient_steps=1) + + # Target should be the same + params_should_match(original_target_param, model.critic_target.parameters()) + + if model_class == TD3: + params_should_match(original_actor_target_param, model.actor_target.parameters()) + td3_features_extractor_check(model) + + # not the same for critic net (updated by gradient descent) + params_should_differ(original_param, model.critic.parameters()) + + # Update the reference as it should not change in the next step + original_param = deepcopy(list(model.critic.parameters())) + + if model_class == TD3: + original_actor_param = deepcopy(list(model.actor.parameters())) + + # Deactivate learning rate + model.lr_schedule = lambda _: 0.0 + # Re-activate polyak update + model.tau = 0.01 + # Special case for DQN: target net is updated in the `collect_rollout()` + # not the `train()` method + if model_class == DQN: + model.target_update_interval = 1 + model._on_step() + + model.train(gradient_steps=1) 
+ + # Target should have changed now (due to polyak update) + params_should_differ(original_target_param, model.critic_target.parameters()) + + # Critic should be the same + params_should_match(original_param, model.critic.parameters()) + + if model_class == TD3: + params_should_differ(original_actor_target_param, model.actor_target.parameters()) + + params_should_match(original_actor_param, model.actor.parameters()) + + td3_features_extractor_check(model) diff --git a/tests/test_run.py b/tests/test_run.py index 5f5f13f..fae6782 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -20,9 +20,10 @@ def test_deterministic_pg(model_class, action_noise): learning_starts=100, verbose=1, create_eval_env=True, + buffer_size=250, action_noise=action_noise, ) - model.learn(total_timesteps=1000, eval_freq=500) + model.learn(total_timesteps=300, eval_freq=250) @pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v0"]) @@ -50,6 +51,7 @@ def test_ppo(env_id, clip_range_vf): model = PPO( "MlpPolicy", env_id, + n_steps=512, seed=0, policy_kwargs=dict(net_arch=[16]), verbose=1, @@ -68,19 +70,25 @@ def test_sac(ent_coef): learning_starts=100, verbose=1, create_eval_env=True, + buffer_size=250, ent_coef=ent_coef, action_noise=NormalActionNoise(np.zeros(1), np.zeros(1)), ) - model.learn(total_timesteps=1000, eval_freq=500) + model.learn(total_timesteps=300, eval_freq=250) @pytest.mark.parametrize("n_critics", [1, 3]) def test_n_critics(n_critics): # Test SAC with different number of critics, for TD3, n_critics=1 corresponds to DDPG model = SAC( - "MlpPolicy", "Pendulum-v0", policy_kwargs=dict(net_arch=[64, 64], n_critics=n_critics), learning_starts=100, verbose=1 + "MlpPolicy", + "Pendulum-v0", + policy_kwargs=dict(net_arch=[64, 64], n_critics=n_critics), + learning_starts=100, + buffer_size=10000, + verbose=1, ) - model.learn(total_timesteps=500) + model.learn(total_timesteps=300) def test_dqn(): diff --git a/tests/test_save_load.py b/tests/test_save_load.py index 
e6230eb..3787ebc 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -303,21 +303,23 @@ def test_save_load_policy(tmp_path, model_class, policy_str): :param model_class: (BaseAlgorithm) A RL model :param policy_str: (str) Name of the policy. """ - kwargs = {} + kwargs = dict(policy_kwargs=dict(net_arch=[16])) if policy_str == "MlpPolicy": env = select_env(model_class) else: if model_class in [SAC, TD3, DQN, DDPG]: # Avoid memory error when using replay buffer # Reduce the size of the features - kwargs = dict(buffer_size=250) + kwargs = dict( + buffer_size=250, learning_starts=100, policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)) + ) env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=2, discrete=model_class == DQN) env = DummyVecEnv([lambda: env]) # create model - model = model_class(policy_str, env, policy_kwargs=dict(net_arch=[16]), verbose=1, **kwargs) - model.learn(total_timesteps=500) + model = model_class(policy_str, env, verbose=1, **kwargs) + model.learn(total_timesteps=300) env.reset() observations = np.concatenate([env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0) diff --git a/tests/test_utils.py b/tests/test_utils.py index b3555d8..c30cb98 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -195,4 +195,4 @@ def test_zip_strict(): def test_cmd_util_rename(): """Test that importing cmd_util still works but raises warning""" with pytest.warns(FutureWarning): - from stable_baselines3.common.cmd_util import make_vec_env + from stable_baselines3.common.cmd_util import make_vec_env # noqa: F401