Remove sde_net_arch + Simplify policy (#584)

* Remove `sde_net_arch` + Simplify policy

* Add warning at load time
This commit is contained in:
Antonin RAFFIN 2021-09-28 21:32:54 +02:00 committed by GitHub
parent 89af49ca91
commit 201fbffa8c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 59 additions and 93 deletions

View file

@ -4,16 +4,19 @@ Changelog
==========
Release 1.2.1a1 (WIP)
Release 1.2.1a2 (WIP)
---------------------------
Breaking Changes:
^^^^^^^^^^^^^^^^^
- ``sde_net_arch`` argument in policies is deprecated and will be removed in a future version.
- ``_get_latent`` (``ActorCriticPolicy``) was removed
New Features:
^^^^^^^^^^^^^
- Added methods ``get_distribution`` and ``predict_values`` for ``ActorCriticPolicy`` for A2C/PPO/TRPO (@cyprienc)
- Added methods ``forward_actor`` and ``forward_critic`` for ``MlpExtractor``
Bug Fixes:
^^^^^^^^^^

View file

@ -2,6 +2,7 @@
import collections
import copy
import warnings
from abc import ABC, abstractmethod
from functools import partial
from typing import Any, Dict, List, Optional, Tuple, Type, Union
@ -171,6 +172,15 @@ class BaseModel(nn.Module, ABC):
"""
device = get_device(device)
saved_variables = th.load(path, map_location=device)
# Allow to load policy saved with older version of SB3
if "sde_net_arch" in saved_variables["data"]:
warnings.warn(
"sde_net_arch is deprecated, please downgrade to SB3 v1.2.0 if you need such parameter.",
DeprecationWarning,
)
del saved_variables["data"]["sde_net_arch"]
# Create policy object
model = cls(**saved_variables["data"]) # pytype: disable=not-instantiable
# Load weights
@ -458,11 +468,12 @@ class ActorCriticPolicy(BasePolicy):
"full_std": full_std,
"squash_output": squash_output,
"use_expln": use_expln,
"learn_features": sde_net_arch is not None,
"learn_features": False,
}
self.sde_features_extractor = None
self.sde_net_arch = sde_net_arch
if sde_net_arch is not None:
warnings.warn("sde_net_arch is deprecated and will be removed in SB3 v2.4.0.", DeprecationWarning)
self.use_sde = use_sde
self.dist_kwargs = dist_kwargs
@ -484,7 +495,6 @@ class ActorCriticPolicy(BasePolicy):
log_std_init=self.log_std_init,
squash_output=default_none_kwargs["squash_output"],
full_std=default_none_kwargs["full_std"],
sde_net_arch=self.sde_net_arch,
use_expln=default_none_kwargs["use_expln"],
lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone
ortho_init=self.ortho_init,
@ -531,26 +541,15 @@ class ActorCriticPolicy(BasePolicy):
latent_dim_pi = self.mlp_extractor.latent_dim_pi
# Separate features extractor for gSDE
if self.sde_net_arch is not None:
self.sde_features_extractor, latent_sde_dim = create_sde_features_extractor(
self.features_dim, self.sde_net_arch, self.activation_fn
)
if isinstance(self.action_dist, DiagGaussianDistribution):
self.action_net, self.log_std = self.action_dist.proba_distribution_net(
latent_dim=latent_dim_pi, log_std_init=self.log_std_init
)
elif isinstance(self.action_dist, StateDependentNoiseDistribution):
latent_sde_dim = latent_dim_pi if self.sde_net_arch is None else latent_sde_dim
self.action_net, self.log_std = self.action_dist.proba_distribution_net(
latent_dim=latent_dim_pi, latent_sde_dim=latent_sde_dim, log_std_init=self.log_std_init
latent_dim=latent_dim_pi, latent_sde_dim=latent_dim_pi, log_std_init=self.log_std_init
)
elif isinstance(self.action_dist, CategoricalDistribution):
self.action_net = self.action_dist.proba_distribution_net(latent_dim=latent_dim_pi)
elif isinstance(self.action_dist, MultiCategoricalDistribution):
self.action_net = self.action_dist.proba_distribution_net(latent_dim=latent_dim_pi)
elif isinstance(self.action_dist, BernoulliDistribution):
elif isinstance(self.action_dist, (CategoricalDistribution, MultiCategoricalDistribution, BernoulliDistribution)):
self.action_net = self.action_dist.proba_distribution_net(latent_dim=latent_dim_pi)
else:
raise NotImplementedError(f"Unsupported distribution '{self.action_dist}'.")
@ -583,39 +582,21 @@ class ActorCriticPolicy(BasePolicy):
:param deterministic: Whether to sample or use deterministic actions
:return: action, value and log probability of the action
"""
latent_pi, latent_vf, latent_sde = self._get_latent(obs)
# Preprocess the observation if needed
features = self.extract_features(obs)
latent_pi, latent_vf = self.mlp_extractor(features)
# Evaluate the values for the given observations
values = self.value_net(latent_vf)
distribution = self._get_action_dist_from_latent(latent_pi, latent_sde=latent_sde)
distribution = self._get_action_dist_from_latent(latent_pi)
actions = distribution.get_actions(deterministic=deterministic)
log_prob = distribution.log_prob(actions)
return actions, values, log_prob
def _get_latent(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
"""
Get the latent code (i.e., activations of the last layer of each network)
for the different networks.
:param obs: Observation
:return: Latent codes
for the actor, the value function and for gSDE function
"""
# Preprocess the observation if needed
features = self.extract_features(obs)
latent_pi, latent_vf = self.mlp_extractor(features)
# Features for sde
latent_sde = latent_pi
if self.sde_features_extractor is not None:
latent_sde = self.sde_features_extractor(features)
return latent_pi, latent_vf, latent_sde
def _get_action_dist_from_latent(self, latent_pi: th.Tensor, latent_sde: Optional[th.Tensor] = None) -> Distribution:
def _get_action_dist_from_latent(self, latent_pi: th.Tensor) -> Distribution:
"""
Retrieve action distribution given the latent codes.
:param latent_pi: Latent code for the actor
:param latent_sde: Latent code for the gSDE exploration function
:return: Action distribution
"""
mean_actions = self.action_net(latent_pi)
@ -632,7 +613,7 @@ class ActorCriticPolicy(BasePolicy):
# Here mean_actions are the logits (before rounding to get the binary actions)
return self.action_dist.proba_distribution(action_logits=mean_actions)
elif isinstance(self.action_dist, StateDependentNoiseDistribution):
return self.action_dist.proba_distribution(mean_actions, self.log_std, latent_sde)
return self.action_dist.proba_distribution(mean_actions, self.log_std, latent_pi)
else:
raise ValueError("Invalid action distribution")
@ -644,9 +625,7 @@ class ActorCriticPolicy(BasePolicy):
:param deterministic: Whether to use stochastic or deterministic actions
:return: Taken action according to the policy
"""
latent_pi, _, latent_sde = self._get_latent(observation)
distribution = self._get_action_dist_from_latent(latent_pi, latent_sde)
return distribution.get_actions(deterministic=deterministic)
return self.get_distribution(observation).get_actions(deterministic=deterministic)
def evaluate_actions(self, obs: th.Tensor, actions: th.Tensor) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
"""
@ -658,8 +637,10 @@ class ActorCriticPolicy(BasePolicy):
:return: estimated value, log likelihood of taking those actions
and entropy of the action distribution.
"""
latent_pi, latent_vf, latent_sde = self._get_latent(obs)
distribution = self._get_action_dist_from_latent(latent_pi, latent_sde)
# Preprocess the observation if needed
features = self.extract_features(obs)
latent_pi, latent_vf = self.mlp_extractor(features)
distribution = self._get_action_dist_from_latent(latent_pi)
log_prob = distribution.log_prob(actions)
values = self.value_net(latent_vf)
return values, log_prob, distribution.entropy()
@ -671,8 +652,9 @@ class ActorCriticPolicy(BasePolicy):
:param obs:
:return: the action distribution.
"""
latent_pi, _, latent_sde = self._get_latent(obs)
return self._get_action_dist_from_latent(latent_pi, latent_sde)
features = self.extract_features(obs)
latent_pi = self.mlp_extractor.forward_actor(features)
return self._get_action_dist_from_latent(latent_pi)
def predict_values(self, obs: th.Tensor) -> th.Tensor:
"""
@ -681,7 +663,8 @@ class ActorCriticPolicy(BasePolicy):
:param obs:
:return: the estimated values.
"""
_, latent_vf, _ = self._get_latent(obs)
features = self.extract_features(obs)
latent_vf = self.mlp_extractor.forward_critic(features)
return self.value_net(latent_vf)
@ -911,27 +894,6 @@ class ContinuousCritic(BaseModel):
return self.q_networks[0](th.cat([features, actions], dim=1))
def create_sde_features_extractor(
features_dim: int, sde_net_arch: List[int], activation_fn: Type[nn.Module]
) -> Tuple[nn.Sequential, int]:
"""
Create the neural network that will be used to extract features
for the gSDE exploration function.
:param features_dim:
:param sde_net_arch:
:param activation_fn:
:return:
"""
# Special case: when using states as features (i.e. sde_net_arch is an empty list)
# don't use any activation function
sde_activation = activation_fn if len(sde_net_arch) > 0 else None
latent_sde_net = create_mlp(features_dim, -1, sde_net_arch, activation_fn=sde_activation, squash_output=False)
latent_sde_dim = sde_net_arch[-1] if len(sde_net_arch) > 0 else features_dim
sde_features_extractor = nn.Sequential(*latent_sde_net)
return sde_features_extractor, latent_sde_dim
_policy_registry = dict() # type: Dict[Type[BasePolicy], Dict[str, Type[BasePolicy]]]

View file

@ -227,6 +227,12 @@ class MlpExtractor(nn.Module):
shared_latent = self.shared_net(features)
return self.policy_net(shared_latent), self.value_net(shared_latent)
def forward_actor(self, features: th.Tensor) -> th.Tensor:
return self.policy_net(self.shared_net(features))
def forward_critic(self, features: th.Tensor) -> th.Tensor:
return self.value_net(self.shared_net(features))
class CombinedExtractor(BaseFeaturesExtractor):
"""

View file

@ -1,3 +1,4 @@
import warnings
from typing import Any, Dict, List, Optional, Tuple, Type, Union
import gym
@ -5,7 +6,7 @@ import torch as th
from torch import nn
from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution
from stable_baselines3.common.policies import BasePolicy, ContinuousCritic, create_sde_features_extractor, register_policy
from stable_baselines3.common.policies import BasePolicy, ContinuousCritic, register_policy
from stable_baselines3.common.preprocessing import get_action_dim
from stable_baselines3.common.torch_layers import (
BaseFeaturesExtractor,
@ -75,7 +76,6 @@ class Actor(BasePolicy):
# Save arguments to re-create object at loading
self.use_sde = use_sde
self.sde_features_extractor = None
self.sde_net_arch = sde_net_arch
self.net_arch = net_arch
self.features_dim = features_dim
self.activation_fn = activation_fn
@ -85,24 +85,20 @@ class Actor(BasePolicy):
self.full_std = full_std
self.clip_mean = clip_mean
if sde_net_arch is not None:
warnings.warn("sde_net_arch is deprecated and will be removed in SB3 v2.4.0.", DeprecationWarning)
action_dim = get_action_dim(self.action_space)
latent_pi_net = create_mlp(features_dim, -1, net_arch, activation_fn)
self.latent_pi = nn.Sequential(*latent_pi_net)
last_layer_dim = net_arch[-1] if len(net_arch) > 0 else features_dim
if self.use_sde:
latent_sde_dim = last_layer_dim
# Separate features extractor for gSDE
if sde_net_arch is not None:
self.sde_features_extractor, latent_sde_dim = create_sde_features_extractor(
features_dim, sde_net_arch, activation_fn
)
self.action_dist = StateDependentNoiseDistribution(
action_dim, full_std=full_std, use_expln=use_expln, learn_features=True, squash_output=True
)
self.mu, self.log_std = self.action_dist.proba_distribution_net(
latent_dim=last_layer_dim, latent_sde_dim=latent_sde_dim, log_std_init=log_std_init
latent_dim=last_layer_dim, latent_sde_dim=last_layer_dim, log_std_init=log_std_init
)
# Avoid numerical issues by limiting the mean of the Gaussian
# to be in [-clip_mean, clip_mean]
@ -124,7 +120,6 @@ class Actor(BasePolicy):
use_sde=self.use_sde,
log_std_init=self.log_std_init,
full_std=self.full_std,
sde_net_arch=self.sde_net_arch,
use_expln=self.use_expln,
features_extractor=self.features_extractor,
clip_mean=self.clip_mean,
@ -169,10 +164,7 @@ class Actor(BasePolicy):
mean_actions = self.mu(latent_pi)
if self.use_sde:
latent_sde = latent_pi
if self.sde_features_extractor is not None:
latent_sde = self.sde_features_extractor(features)
return mean_actions, self.log_std, dict(latent_sde=latent_sde)
return mean_actions, self.log_std, dict(latent_sde=latent_pi)
# Unstructured exploration (Original implementation)
log_std = self.log_std(latent_pi)
# Original Implementation to cap the standard deviation
@ -273,10 +265,13 @@ class SACPolicy(BasePolicy):
"normalize_images": normalize_images,
}
self.actor_kwargs = self.net_args.copy()
if sde_net_arch is not None:
warnings.warn("sde_net_arch is deprecated and will be removed in SB3 v2.4.0.", DeprecationWarning)
sde_kwargs = {
"use_sde": use_sde,
"log_std_init": log_std_init,
"sde_net_arch": sde_net_arch,
"use_expln": use_expln,
"clip_mean": clip_mean,
}
@ -329,7 +324,6 @@ class SACPolicy(BasePolicy):
activation_fn=self.net_args["activation_fn"],
use_sde=self.actor_kwargs["use_sde"],
log_std_init=self.actor_kwargs["log_std_init"],
sde_net_arch=self.actor_kwargs["sde_net_arch"],
use_expln=self.actor_kwargs["use_expln"],
clip_mean=self.actor_kwargs["clip_mean"],
n_critics=self.critic_kwargs["n_critics"],

View file

@ -1 +1 @@
1.2.1a1
1.2.1a2

View file

@ -60,9 +60,9 @@ def test_sde_check():
@pytest.mark.parametrize("model_class", [SAC, A2C, PPO])
@pytest.mark.parametrize("sde_net_arch", [None, [32, 16], []])
@pytest.mark.parametrize("use_expln", [False, True])
def test_state_dependent_offpolicy_noise(model_class, sde_net_arch, use_expln):
def test_state_dependent_noise(model_class, use_expln):
kwargs = {"learning_starts": 0} if model_class == SAC else {"n_steps": 64}
model = model_class(
"MlpPolicy",
"Pendulum-v0",
@ -70,9 +70,10 @@ def test_state_dependent_offpolicy_noise(model_class, sde_net_arch, use_expln):
seed=None,
create_eval_env=True,
verbose=1,
policy_kwargs=dict(log_std_init=-2, sde_net_arch=sde_net_arch, use_expln=use_expln, net_arch=[64]),
policy_kwargs=dict(log_std_init=-2, use_expln=use_expln, net_arch=[64]),
**kwargs,
)
model.learn(total_timesteps=int(300), eval_freq=250)
model.learn(total_timesteps=255, eval_freq=250)
model.policy.reset_noise()
if model_class == SAC:
model.policy.actor.get_std()