diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 620ec1c..91d110e 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -4,16 +4,19 @@ Changelog ========== -Release 1.2.1a1 (WIP) +Release 1.2.1a2 (WIP) --------------------------- Breaking Changes: ^^^^^^^^^^^^^^^^^ +- ``sde_net_arch`` argument in policies is deprecated and will be removed in a future version. +- ``_get_latent`` (``ActorCriticPolicy``) was removed New Features: ^^^^^^^^^^^^^ - Added methods ``get_distribution`` and ``predict_values`` for ``ActorCriticPolicy`` for A2C/PPO/TRPO (@cyprienc) +- Added methods ``forward_actor`` and ``forward_critic`` for ``MlpExtractor`` Bug Fixes: ^^^^^^^^^^ diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py index 48d9a5e..9b45592 100644 --- a/stable_baselines3/common/policies.py +++ b/stable_baselines3/common/policies.py @@ -2,6 +2,7 @@ import collections import copy +import warnings from abc import ABC, abstractmethod from functools import partial from typing import Any, Dict, List, Optional, Tuple, Type, Union @@ -171,6 +172,15 @@ class BaseModel(nn.Module, ABC): """ device = get_device(device) saved_variables = th.load(path, map_location=device) + + # Allow to load policy saved with older version of SB3 + if "sde_net_arch" in saved_variables["data"]: + warnings.warn( + "sde_net_arch is deprecated, please downgrade to SB3 v1.2.0 if you need such parameter.", + DeprecationWarning, + ) + del saved_variables["data"]["sde_net_arch"] + # Create policy object model = cls(**saved_variables["data"]) # pytype: disable=not-instantiable # Load weights @@ -458,11 +468,12 @@ class ActorCriticPolicy(BasePolicy): "full_std": full_std, "squash_output": squash_output, "use_expln": use_expln, - "learn_features": sde_net_arch is not None, + "learn_features": False, } - self.sde_features_extractor = None - self.sde_net_arch = sde_net_arch + if sde_net_arch is not None: + warnings.warn("sde_net_arch is deprecated and will be removed in SB3 v2.4.0.", DeprecationWarning) + self.use_sde = use_sde self.dist_kwargs = dist_kwargs @@ -484,7 +495,6 @@ class ActorCriticPolicy(BasePolicy): log_std_init=self.log_std_init, squash_output=default_none_kwargs["squash_output"], full_std=default_none_kwargs["full_std"], - sde_net_arch=self.sde_net_arch, use_expln=default_none_kwargs["use_expln"], lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone ortho_init=self.ortho_init, @@ -531,26 +541,15 @@ class ActorCriticPolicy(BasePolicy): latent_dim_pi = self.mlp_extractor.latent_dim_pi - # Separate features extractor for gSDE - if self.sde_net_arch is not None: - self.sde_features_extractor, latent_sde_dim = create_sde_features_extractor( - self.features_dim, self.sde_net_arch, self.activation_fn - ) - if isinstance(self.action_dist, DiagGaussianDistribution): self.action_net, self.log_std = self.action_dist.proba_distribution_net( latent_dim=latent_dim_pi, log_std_init=self.log_std_init ) elif isinstance(self.action_dist, StateDependentNoiseDistribution): - latent_sde_dim = latent_dim_pi if self.sde_net_arch is None else latent_sde_dim self.action_net, self.log_std = self.action_dist.proba_distribution_net( - latent_dim=latent_dim_pi, latent_sde_dim=latent_sde_dim, log_std_init=self.log_std_init + latent_dim=latent_dim_pi, latent_sde_dim=latent_dim_pi, log_std_init=self.log_std_init ) - elif isinstance(self.action_dist, CategoricalDistribution): - self.action_net = self.action_dist.proba_distribution_net(latent_dim=latent_dim_pi) - elif isinstance(self.action_dist, MultiCategoricalDistribution): - self.action_net = self.action_dist.proba_distribution_net(latent_dim=latent_dim_pi) - elif isinstance(self.action_dist, BernoulliDistribution): + elif isinstance(self.action_dist, (CategoricalDistribution, MultiCategoricalDistribution, BernoulliDistribution)): self.action_net = self.action_dist.proba_distribution_net(latent_dim=latent_dim_pi) else: raise NotImplementedError(f"Unsupported distribution '{self.action_dist}'.") @@ -583,39 +582,21 @@ class ActorCriticPolicy(BasePolicy): :param deterministic: Whether to sample or use deterministic actions :return: action, value and log probability of the action """ - latent_pi, latent_vf, latent_sde = self._get_latent(obs) + # Preprocess the observation if needed + features = self.extract_features(obs) + latent_pi, latent_vf = self.mlp_extractor(features) # Evaluate the values for the given observations values = self.value_net(latent_vf) - distribution = self._get_action_dist_from_latent(latent_pi, latent_sde=latent_sde) + distribution = self._get_action_dist_from_latent(latent_pi) actions = distribution.get_actions(deterministic=deterministic) log_prob = distribution.log_prob(actions) return actions, values, log_prob - def _get_latent(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor, th.Tensor]: - """ - Get the latent code (i.e., activations of the last layer of each network) - for the different networks. - - :param obs: Observation - :return: Latent codes - for the actor, the value function and for gSDE function - """ - # Preprocess the observation if needed - features = self.extract_features(obs) - latent_pi, latent_vf = self.mlp_extractor(features) - - # Features for sde - latent_sde = latent_pi - if self.sde_features_extractor is not None: - latent_sde = self.sde_features_extractor(features) - return latent_pi, latent_vf, latent_sde - - def _get_action_dist_from_latent(self, latent_pi: th.Tensor, latent_sde: Optional[th.Tensor] = None) -> Distribution: + def _get_action_dist_from_latent(self, latent_pi: th.Tensor) -> Distribution: """ Retrieve action distribution given the latent codes. :param latent_pi: Latent code for the actor - :param latent_sde: Latent code for the gSDE exploration function :return: Action distribution """ mean_actions = self.action_net(latent_pi) @@ -632,7 +613,7 @@ class ActorCriticPolicy(BasePolicy): # Here mean_actions are the logits (before rounding to get the binary actions) return self.action_dist.proba_distribution(action_logits=mean_actions) elif isinstance(self.action_dist, StateDependentNoiseDistribution): - return self.action_dist.proba_distribution(mean_actions, self.log_std, latent_sde) + return self.action_dist.proba_distribution(mean_actions, self.log_std, latent_pi) else: raise ValueError("Invalid action distribution") @@ -644,9 +625,7 @@ class ActorCriticPolicy(BasePolicy): :param deterministic: Whether to use stochastic or deterministic actions :return: Taken action according to the policy """ - latent_pi, _, latent_sde = self._get_latent(observation) - distribution = self._get_action_dist_from_latent(latent_pi, latent_sde) - return distribution.get_actions(deterministic=deterministic) + return self.get_distribution(observation).get_actions(deterministic=deterministic) def evaluate_actions(self, obs: th.Tensor, actions: th.Tensor) -> Tuple[th.Tensor, th.Tensor, th.Tensor]: """ @@ -658,8 +637,10 @@ class ActorCriticPolicy(BasePolicy): :return: estimated value, log likelihood of taking those actions and entropy of the action distribution. """ - latent_pi, latent_vf, latent_sde = self._get_latent(obs) - distribution = self._get_action_dist_from_latent(latent_pi, latent_sde) + # Preprocess the observation if needed + features = self.extract_features(obs) + latent_pi, latent_vf = self.mlp_extractor(features) + distribution = self._get_action_dist_from_latent(latent_pi) log_prob = distribution.log_prob(actions) values = self.value_net(latent_vf) return values, log_prob, distribution.entropy() @@ -671,8 +652,9 @@ class ActorCriticPolicy(BasePolicy): :param obs: :return: the action distribution. """ - latent_pi, _, latent_sde = self._get_latent(obs) - return self._get_action_dist_from_latent(latent_pi, latent_sde) + features = self.extract_features(obs) + latent_pi = self.mlp_extractor.forward_actor(features) + return self._get_action_dist_from_latent(latent_pi) def predict_values(self, obs: th.Tensor) -> th.Tensor: """ @@ -681,7 +663,8 @@ class ActorCriticPolicy(BasePolicy): :param obs: :return: the estimated values. """ - _, latent_vf, _ = self._get_latent(obs) + features = self.extract_features(obs) + latent_vf = self.mlp_extractor.forward_critic(features) return self.value_net(latent_vf) @@ -911,27 +894,6 @@ class ContinuousCritic(BaseModel): return self.q_networks[0](th.cat([features, actions], dim=1)) -def create_sde_features_extractor( - features_dim: int, sde_net_arch: List[int], activation_fn: Type[nn.Module] -) -> Tuple[nn.Sequential, int]: - """ - Create the neural network that will be used to extract features - for the gSDE exploration function. - - :param features_dim: - :param sde_net_arch: - :param activation_fn: - :return: - """ - # Special case: when using states as features (i.e. sde_net_arch is an empty list) - # don't use any activation function - sde_activation = activation_fn if len(sde_net_arch) > 0 else None - latent_sde_net = create_mlp(features_dim, -1, sde_net_arch, activation_fn=sde_activation, squash_output=False) - latent_sde_dim = sde_net_arch[-1] if len(sde_net_arch) > 0 else features_dim - sde_features_extractor = nn.Sequential(*latent_sde_net) - return sde_features_extractor, latent_sde_dim - - _policy_registry = dict() # type: Dict[Type[BasePolicy], Dict[str, Type[BasePolicy]]] diff --git a/stable_baselines3/common/torch_layers.py b/stable_baselines3/common/torch_layers.py index 71b647b..0644755 100644 --- a/stable_baselines3/common/torch_layers.py +++ b/stable_baselines3/common/torch_layers.py @@ -227,6 +227,12 @@ class MlpExtractor(nn.Module): shared_latent = self.shared_net(features) return self.policy_net(shared_latent), self.value_net(shared_latent) + def forward_actor(self, features: th.Tensor) -> th.Tensor: + return self.policy_net(self.shared_net(features)) + + def forward_critic(self, features: th.Tensor) -> th.Tensor: + return self.value_net(self.shared_net(features)) + class CombinedExtractor(BaseFeaturesExtractor): """ diff --git a/stable_baselines3/sac/policies.py b/stable_baselines3/sac/policies.py index 8bb67df..68133d1 100644 --- a/stable_baselines3/sac/policies.py +++ b/stable_baselines3/sac/policies.py @@ -1,3 +1,4 @@ +import warnings from typing import Any, Dict, List, Optional, Tuple, Type, Union import gym @@ -5,7 +6,7 @@ import torch as th from torch import nn from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution -from stable_baselines3.common.policies import BasePolicy, ContinuousCritic, create_sde_features_extractor, register_policy +from stable_baselines3.common.policies import BasePolicy, ContinuousCritic, register_policy from stable_baselines3.common.preprocessing import get_action_dim from stable_baselines3.common.torch_layers import ( BaseFeaturesExtractor, @@ -75,7 +76,6 @@ class Actor(BasePolicy): # Save arguments to re-create object at loading self.use_sde = use_sde self.sde_features_extractor = None - self.sde_net_arch = sde_net_arch self.net_arch = net_arch self.features_dim = features_dim self.activation_fn = activation_fn @@ -85,24 +85,20 @@ class Actor(BasePolicy): self.full_std = full_std self.clip_mean = clip_mean + if sde_net_arch is not None: + warnings.warn("sde_net_arch is deprecated and will be removed in SB3 v2.4.0.", DeprecationWarning) + action_dim = get_action_dim(self.action_space) latent_pi_net = create_mlp(features_dim, -1, net_arch, activation_fn) self.latent_pi = nn.Sequential(*latent_pi_net) last_layer_dim = net_arch[-1] if len(net_arch) > 0 else features_dim if self.use_sde: - latent_sde_dim = last_layer_dim - # Separate features extractor for gSDE - if sde_net_arch is not None: - self.sde_features_extractor, latent_sde_dim = create_sde_features_extractor( - features_dim, sde_net_arch, activation_fn - ) - self.action_dist = StateDependentNoiseDistribution( action_dim, full_std=full_std, use_expln=use_expln, learn_features=True, squash_output=True ) self.mu, self.log_std = self.action_dist.proba_distribution_net( - latent_dim=last_layer_dim, latent_sde_dim=latent_sde_dim, log_std_init=log_std_init + latent_dim=last_layer_dim, latent_sde_dim=last_layer_dim, log_std_init=log_std_init ) # Avoid numerical issues by limiting the mean of the Gaussian # to be in [-clip_mean, clip_mean] @@ -124,7 +120,6 @@ class Actor(BasePolicy): use_sde=self.use_sde, log_std_init=self.log_std_init, full_std=self.full_std, - sde_net_arch=self.sde_net_arch, use_expln=self.use_expln, features_extractor=self.features_extractor, clip_mean=self.clip_mean, @@ -169,10 +164,7 @@ class Actor(BasePolicy): mean_actions = self.mu(latent_pi) if self.use_sde: - latent_sde = latent_pi - if self.sde_features_extractor is not None: - latent_sde = self.sde_features_extractor(features) - return mean_actions, self.log_std, dict(latent_sde=latent_sde) + return mean_actions, self.log_std, dict(latent_sde=latent_pi) # Unstructured exploration (Original implementation) log_std = self.log_std(latent_pi) # Original Implementation to cap the standard deviation @@ -273,10 +265,13 @@ class SACPolicy(BasePolicy): "normalize_images": normalize_images, } self.actor_kwargs = self.net_args.copy() + + if sde_net_arch is not None: + warnings.warn("sde_net_arch is deprecated and will be removed in SB3 v2.4.0.", DeprecationWarning) + sde_kwargs = { "use_sde": use_sde, "log_std_init": log_std_init, - "sde_net_arch": sde_net_arch, "use_expln": use_expln, "clip_mean": clip_mean, } @@ -329,7 +324,6 @@ class SACPolicy(BasePolicy): activation_fn=self.net_args["activation_fn"], use_sde=self.actor_kwargs["use_sde"], log_std_init=self.actor_kwargs["log_std_init"], - sde_net_arch=self.actor_kwargs["sde_net_arch"], use_expln=self.actor_kwargs["use_expln"], clip_mean=self.actor_kwargs["clip_mean"], n_critics=self.critic_kwargs["n_critics"], diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index b8ef03e..c4baa5c 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -1.2.1a1 +1.2.1a2 diff --git a/tests/test_sde.py b/tests/test_sde.py index 74853a0..e20b01d 100644 --- a/tests/test_sde.py +++ b/tests/test_sde.py @@ -60,9 +60,9 @@ def test_sde_check(): @pytest.mark.parametrize("model_class", [SAC, A2C, PPO]) -@pytest.mark.parametrize("sde_net_arch", [None, [32, 16], []]) @pytest.mark.parametrize("use_expln", [False, True]) -def test_state_dependent_offpolicy_noise(model_class, sde_net_arch, use_expln): +def test_state_dependent_noise(model_class, use_expln): + kwargs = {"learning_starts": 0} if model_class == SAC else {"n_steps": 64} model = model_class( "MlpPolicy", "Pendulum-v0", @@ -70,9 +70,10 @@ def test_state_dependent_offpolicy_noise(model_class, sde_net_arch, use_expln): seed=None, create_eval_env=True, verbose=1, - policy_kwargs=dict(log_std_init=-2, sde_net_arch=sde_net_arch, use_expln=use_expln, net_arch=[64]), + policy_kwargs=dict(log_std_init=-2, use_expln=use_expln, net_arch=[64]), + **kwargs, ) - model.learn(total_timesteps=int(300), eval_freq=250) + model.learn(total_timesteps=255, eval_freq=250) model.policy.reset_noise() if model_class == SAC: model.policy.actor.get_std()