mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-18 21:30:19 +00:00
Remove sde_net_arch + Simplify policy (#584)
* Remove `sde_net_arch` + Simplify policy
* Add warning at load time
This commit is contained in:
parent
89af49ca91
commit
201fbffa8c
6 changed files with 59 additions and 93 deletions
|
|
@ -4,16 +4,19 @@ Changelog
|
|||
==========
|
||||
|
||||
|
||||
Release 1.2.1a1 (WIP)
|
||||
Release 1.2.1a2 (WIP)
|
||||
---------------------------
|
||||
|
||||
|
||||
Breaking Changes:
|
||||
^^^^^^^^^^^^^^^^^
|
||||
- ``sde_net_arch`` argument in policies is deprecated and now ignored (the separate gSDE features extractor was removed); it will be removed in a future version.
|
||||
- ``_get_latent`` (``ActorCriticPolicy``) was removed
|
||||
|
||||
New Features:
|
||||
^^^^^^^^^^^^^
|
||||
- Added methods ``get_distribution`` and ``predict_values`` for ``ActorCriticPolicy`` for A2C/PPO/TRPO (@cyprienc)
|
||||
- Added methods ``forward_actor`` and ``forward_critic`` for ``MlpExtractor``
|
||||
|
||||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
import collections
|
||||
import copy
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from functools import partial
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
||||
|
|
@ -171,6 +172,15 @@ class BaseModel(nn.Module, ABC):
|
|||
"""
|
||||
device = get_device(device)
|
||||
saved_variables = th.load(path, map_location=device)
|
||||
|
||||
# Allow loading policies saved with older versions of SB3
|
||||
if "sde_net_arch" in saved_variables["data"]:
|
||||
warnings.warn(
|
||||
"sde_net_arch is deprecated, please downgrade to SB3 v1.2.0 if you need such parameter.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
del saved_variables["data"]["sde_net_arch"]
|
||||
|
||||
# Create policy object
|
||||
model = cls(**saved_variables["data"]) # pytype: disable=not-instantiable
|
||||
# Load weights
|
||||
|
|
@ -458,11 +468,12 @@ class ActorCriticPolicy(BasePolicy):
|
|||
"full_std": full_std,
|
||||
"squash_output": squash_output,
|
||||
"use_expln": use_expln,
|
||||
"learn_features": sde_net_arch is not None,
|
||||
"learn_features": False,
|
||||
}
|
||||
|
||||
self.sde_features_extractor = None
|
||||
self.sde_net_arch = sde_net_arch
|
||||
if sde_net_arch is not None:
|
||||
warnings.warn("sde_net_arch is deprecated and will be removed in SB3 v2.4.0.", DeprecationWarning)
|
||||
|
||||
self.use_sde = use_sde
|
||||
self.dist_kwargs = dist_kwargs
|
||||
|
||||
|
|
@ -484,7 +495,6 @@ class ActorCriticPolicy(BasePolicy):
|
|||
log_std_init=self.log_std_init,
|
||||
squash_output=default_none_kwargs["squash_output"],
|
||||
full_std=default_none_kwargs["full_std"],
|
||||
sde_net_arch=self.sde_net_arch,
|
||||
use_expln=default_none_kwargs["use_expln"],
|
||||
lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone
|
||||
ortho_init=self.ortho_init,
|
||||
|
|
@ -531,26 +541,15 @@ class ActorCriticPolicy(BasePolicy):
|
|||
|
||||
latent_dim_pi = self.mlp_extractor.latent_dim_pi
|
||||
|
||||
# Separate features extractor for gSDE
|
||||
if self.sde_net_arch is not None:
|
||||
self.sde_features_extractor, latent_sde_dim = create_sde_features_extractor(
|
||||
self.features_dim, self.sde_net_arch, self.activation_fn
|
||||
)
|
||||
|
||||
if isinstance(self.action_dist, DiagGaussianDistribution):
|
||||
self.action_net, self.log_std = self.action_dist.proba_distribution_net(
|
||||
latent_dim=latent_dim_pi, log_std_init=self.log_std_init
|
||||
)
|
||||
elif isinstance(self.action_dist, StateDependentNoiseDistribution):
|
||||
latent_sde_dim = latent_dim_pi if self.sde_net_arch is None else latent_sde_dim
|
||||
self.action_net, self.log_std = self.action_dist.proba_distribution_net(
|
||||
latent_dim=latent_dim_pi, latent_sde_dim=latent_sde_dim, log_std_init=self.log_std_init
|
||||
latent_dim=latent_dim_pi, latent_sde_dim=latent_dim_pi, log_std_init=self.log_std_init
|
||||
)
|
||||
elif isinstance(self.action_dist, CategoricalDistribution):
|
||||
self.action_net = self.action_dist.proba_distribution_net(latent_dim=latent_dim_pi)
|
||||
elif isinstance(self.action_dist, MultiCategoricalDistribution):
|
||||
self.action_net = self.action_dist.proba_distribution_net(latent_dim=latent_dim_pi)
|
||||
elif isinstance(self.action_dist, BernoulliDistribution):
|
||||
elif isinstance(self.action_dist, (CategoricalDistribution, MultiCategoricalDistribution, BernoulliDistribution)):
|
||||
self.action_net = self.action_dist.proba_distribution_net(latent_dim=latent_dim_pi)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported distribution '{self.action_dist}'.")
|
||||
|
|
@ -583,39 +582,21 @@ class ActorCriticPolicy(BasePolicy):
|
|||
:param deterministic: Whether to sample or use deterministic actions
|
||||
:return: action, value and log probability of the action
|
||||
"""
|
||||
latent_pi, latent_vf, latent_sde = self._get_latent(obs)
|
||||
# Preprocess the observation if needed
|
||||
features = self.extract_features(obs)
|
||||
latent_pi, latent_vf = self.mlp_extractor(features)
|
||||
# Evaluate the values for the given observations
|
||||
values = self.value_net(latent_vf)
|
||||
distribution = self._get_action_dist_from_latent(latent_pi, latent_sde=latent_sde)
|
||||
distribution = self._get_action_dist_from_latent(latent_pi)
|
||||
actions = distribution.get_actions(deterministic=deterministic)
|
||||
log_prob = distribution.log_prob(actions)
|
||||
return actions, values, log_prob
|
||||
|
||||
def _get_latent(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
|
||||
"""
|
||||
Get the latent code (i.e., activations of the last layer of each network)
|
||||
for the different networks.
|
||||
|
||||
:param obs: Observation
|
||||
:return: Latent codes
|
||||
for the actor, the value function and for gSDE function
|
||||
"""
|
||||
# Preprocess the observation if needed
|
||||
features = self.extract_features(obs)
|
||||
latent_pi, latent_vf = self.mlp_extractor(features)
|
||||
|
||||
# Features for sde
|
||||
latent_sde = latent_pi
|
||||
if self.sde_features_extractor is not None:
|
||||
latent_sde = self.sde_features_extractor(features)
|
||||
return latent_pi, latent_vf, latent_sde
|
||||
|
||||
def _get_action_dist_from_latent(self, latent_pi: th.Tensor, latent_sde: Optional[th.Tensor] = None) -> Distribution:
|
||||
def _get_action_dist_from_latent(self, latent_pi: th.Tensor) -> Distribution:
|
||||
"""
|
||||
Retrieve action distribution given the latent codes.
|
||||
|
||||
:param latent_pi: Latent code for the actor
|
||||
:param latent_sde: Latent code for the gSDE exploration function
|
||||
:return: Action distribution
|
||||
"""
|
||||
mean_actions = self.action_net(latent_pi)
|
||||
|
|
@ -632,7 +613,7 @@ class ActorCriticPolicy(BasePolicy):
|
|||
# Here mean_actions are the logits (before rounding to get the binary actions)
|
||||
return self.action_dist.proba_distribution(action_logits=mean_actions)
|
||||
elif isinstance(self.action_dist, StateDependentNoiseDistribution):
|
||||
return self.action_dist.proba_distribution(mean_actions, self.log_std, latent_sde)
|
||||
return self.action_dist.proba_distribution(mean_actions, self.log_std, latent_pi)
|
||||
else:
|
||||
raise ValueError("Invalid action distribution")
|
||||
|
||||
|
|
@ -644,9 +625,7 @@ class ActorCriticPolicy(BasePolicy):
|
|||
:param deterministic: Whether to use stochastic or deterministic actions
|
||||
:return: Taken action according to the policy
|
||||
"""
|
||||
latent_pi, _, latent_sde = self._get_latent(observation)
|
||||
distribution = self._get_action_dist_from_latent(latent_pi, latent_sde)
|
||||
return distribution.get_actions(deterministic=deterministic)
|
||||
return self.get_distribution(observation).get_actions(deterministic=deterministic)
|
||||
|
||||
def evaluate_actions(self, obs: th.Tensor, actions: th.Tensor) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
|
||||
"""
|
||||
|
|
@ -658,8 +637,10 @@ class ActorCriticPolicy(BasePolicy):
|
|||
:return: estimated value, log likelihood of taking those actions
|
||||
and entropy of the action distribution.
|
||||
"""
|
||||
latent_pi, latent_vf, latent_sde = self._get_latent(obs)
|
||||
distribution = self._get_action_dist_from_latent(latent_pi, latent_sde)
|
||||
# Preprocess the observation if needed
|
||||
features = self.extract_features(obs)
|
||||
latent_pi, latent_vf = self.mlp_extractor(features)
|
||||
distribution = self._get_action_dist_from_latent(latent_pi)
|
||||
log_prob = distribution.log_prob(actions)
|
||||
values = self.value_net(latent_vf)
|
||||
return values, log_prob, distribution.entropy()
|
||||
|
|
@ -671,8 +652,9 @@ class ActorCriticPolicy(BasePolicy):
|
|||
:param obs:
|
||||
:return: the action distribution.
|
||||
"""
|
||||
latent_pi, _, latent_sde = self._get_latent(obs)
|
||||
return self._get_action_dist_from_latent(latent_pi, latent_sde)
|
||||
features = self.extract_features(obs)
|
||||
latent_pi = self.mlp_extractor.forward_actor(features)
|
||||
return self._get_action_dist_from_latent(latent_pi)
|
||||
|
||||
def predict_values(self, obs: th.Tensor) -> th.Tensor:
|
||||
"""
|
||||
|
|
@ -681,7 +663,8 @@ class ActorCriticPolicy(BasePolicy):
|
|||
:param obs:
|
||||
:return: the estimated values.
|
||||
"""
|
||||
_, latent_vf, _ = self._get_latent(obs)
|
||||
features = self.extract_features(obs)
|
||||
latent_vf = self.mlp_extractor.forward_critic(features)
|
||||
return self.value_net(latent_vf)
|
||||
|
||||
|
||||
|
|
@ -911,27 +894,6 @@ class ContinuousCritic(BaseModel):
|
|||
return self.q_networks[0](th.cat([features, actions], dim=1))
|
||||
|
||||
|
||||
def create_sde_features_extractor(
|
||||
features_dim: int, sde_net_arch: List[int], activation_fn: Type[nn.Module]
|
||||
) -> Tuple[nn.Sequential, int]:
|
||||
"""
|
||||
Create the neural network that will be used to extract features
|
||||
for the gSDE exploration function.
|
||||
|
||||
:param features_dim:
|
||||
:param sde_net_arch:
|
||||
:param activation_fn:
|
||||
:return:
|
||||
"""
|
||||
# Special case: when using states as features (i.e. sde_net_arch is an empty list)
|
||||
# don't use any activation function
|
||||
sde_activation = activation_fn if len(sde_net_arch) > 0 else None
|
||||
latent_sde_net = create_mlp(features_dim, -1, sde_net_arch, activation_fn=sde_activation, squash_output=False)
|
||||
latent_sde_dim = sde_net_arch[-1] if len(sde_net_arch) > 0 else features_dim
|
||||
sde_features_extractor = nn.Sequential(*latent_sde_net)
|
||||
return sde_features_extractor, latent_sde_dim
|
||||
|
||||
|
||||
_policy_registry = dict() # type: Dict[Type[BasePolicy], Dict[str, Type[BasePolicy]]]
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -227,6 +227,12 @@ class MlpExtractor(nn.Module):
|
|||
shared_latent = self.shared_net(features)
|
||||
return self.policy_net(shared_latent), self.value_net(shared_latent)
|
||||
|
||||
def forward_actor(self, features: th.Tensor) -> th.Tensor:
|
||||
return self.policy_net(self.shared_net(features))
|
||||
|
||||
def forward_critic(self, features: th.Tensor) -> th.Tensor:
|
||||
return self.value_net(self.shared_net(features))
|
||||
|
||||
|
||||
class CombinedExtractor(BaseFeaturesExtractor):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import warnings
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
||||
|
||||
import gym
|
||||
|
|
@ -5,7 +6,7 @@ import torch as th
|
|||
from torch import nn
|
||||
|
||||
from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution
|
||||
from stable_baselines3.common.policies import BasePolicy, ContinuousCritic, create_sde_features_extractor, register_policy
|
||||
from stable_baselines3.common.policies import BasePolicy, ContinuousCritic, register_policy
|
||||
from stable_baselines3.common.preprocessing import get_action_dim
|
||||
from stable_baselines3.common.torch_layers import (
|
||||
BaseFeaturesExtractor,
|
||||
|
|
@ -75,7 +76,6 @@ class Actor(BasePolicy):
|
|||
# Save arguments to re-create object at loading
|
||||
self.use_sde = use_sde
|
||||
self.sde_features_extractor = None
|
||||
self.sde_net_arch = sde_net_arch
|
||||
self.net_arch = net_arch
|
||||
self.features_dim = features_dim
|
||||
self.activation_fn = activation_fn
|
||||
|
|
@ -85,24 +85,20 @@ class Actor(BasePolicy):
|
|||
self.full_std = full_std
|
||||
self.clip_mean = clip_mean
|
||||
|
||||
if sde_net_arch is not None:
|
||||
warnings.warn("sde_net_arch is deprecated and will be removed in SB3 v2.4.0.", DeprecationWarning)
|
||||
|
||||
action_dim = get_action_dim(self.action_space)
|
||||
latent_pi_net = create_mlp(features_dim, -1, net_arch, activation_fn)
|
||||
self.latent_pi = nn.Sequential(*latent_pi_net)
|
||||
last_layer_dim = net_arch[-1] if len(net_arch) > 0 else features_dim
|
||||
|
||||
if self.use_sde:
|
||||
latent_sde_dim = last_layer_dim
|
||||
# Separate features extractor for gSDE
|
||||
if sde_net_arch is not None:
|
||||
self.sde_features_extractor, latent_sde_dim = create_sde_features_extractor(
|
||||
features_dim, sde_net_arch, activation_fn
|
||||
)
|
||||
|
||||
self.action_dist = StateDependentNoiseDistribution(
|
||||
action_dim, full_std=full_std, use_expln=use_expln, learn_features=True, squash_output=True
|
||||
)
|
||||
self.mu, self.log_std = self.action_dist.proba_distribution_net(
|
||||
latent_dim=last_layer_dim, latent_sde_dim=latent_sde_dim, log_std_init=log_std_init
|
||||
latent_dim=last_layer_dim, latent_sde_dim=last_layer_dim, log_std_init=log_std_init
|
||||
)
|
||||
# Avoid numerical issues by limiting the mean of the Gaussian
|
||||
# to be in [-clip_mean, clip_mean]
|
||||
|
|
@ -124,7 +120,6 @@ class Actor(BasePolicy):
|
|||
use_sde=self.use_sde,
|
||||
log_std_init=self.log_std_init,
|
||||
full_std=self.full_std,
|
||||
sde_net_arch=self.sde_net_arch,
|
||||
use_expln=self.use_expln,
|
||||
features_extractor=self.features_extractor,
|
||||
clip_mean=self.clip_mean,
|
||||
|
|
@ -169,10 +164,7 @@ class Actor(BasePolicy):
|
|||
mean_actions = self.mu(latent_pi)
|
||||
|
||||
if self.use_sde:
|
||||
latent_sde = latent_pi
|
||||
if self.sde_features_extractor is not None:
|
||||
latent_sde = self.sde_features_extractor(features)
|
||||
return mean_actions, self.log_std, dict(latent_sde=latent_sde)
|
||||
return mean_actions, self.log_std, dict(latent_sde=latent_pi)
|
||||
# Unstructured exploration (Original implementation)
|
||||
log_std = self.log_std(latent_pi)
|
||||
# Original Implementation to cap the standard deviation
|
||||
|
|
@ -273,10 +265,13 @@ class SACPolicy(BasePolicy):
|
|||
"normalize_images": normalize_images,
|
||||
}
|
||||
self.actor_kwargs = self.net_args.copy()
|
||||
|
||||
if sde_net_arch is not None:
|
||||
warnings.warn("sde_net_arch is deprecated and will be removed in SB3 v2.4.0.", DeprecationWarning)
|
||||
|
||||
sde_kwargs = {
|
||||
"use_sde": use_sde,
|
||||
"log_std_init": log_std_init,
|
||||
"sde_net_arch": sde_net_arch,
|
||||
"use_expln": use_expln,
|
||||
"clip_mean": clip_mean,
|
||||
}
|
||||
|
|
@ -329,7 +324,6 @@ class SACPolicy(BasePolicy):
|
|||
activation_fn=self.net_args["activation_fn"],
|
||||
use_sde=self.actor_kwargs["use_sde"],
|
||||
log_std_init=self.actor_kwargs["log_std_init"],
|
||||
sde_net_arch=self.actor_kwargs["sde_net_arch"],
|
||||
use_expln=self.actor_kwargs["use_expln"],
|
||||
clip_mean=self.actor_kwargs["clip_mean"],
|
||||
n_critics=self.critic_kwargs["n_critics"],
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
1.2.1a1
|
||||
1.2.1a2
|
||||
|
|
|
|||
|
|
@ -60,9 +60,9 @@ def test_sde_check():
|
|||
|
||||
|
||||
@pytest.mark.parametrize("model_class", [SAC, A2C, PPO])
|
||||
@pytest.mark.parametrize("sde_net_arch", [None, [32, 16], []])
|
||||
@pytest.mark.parametrize("use_expln", [False, True])
|
||||
def test_state_dependent_offpolicy_noise(model_class, sde_net_arch, use_expln):
|
||||
def test_state_dependent_noise(model_class, use_expln):
|
||||
kwargs = {"learning_starts": 0} if model_class == SAC else {"n_steps": 64}
|
||||
model = model_class(
|
||||
"MlpPolicy",
|
||||
"Pendulum-v0",
|
||||
|
|
@ -70,9 +70,10 @@ def test_state_dependent_offpolicy_noise(model_class, sde_net_arch, use_expln):
|
|||
seed=None,
|
||||
create_eval_env=True,
|
||||
verbose=1,
|
||||
policy_kwargs=dict(log_std_init=-2, sde_net_arch=sde_net_arch, use_expln=use_expln, net_arch=[64]),
|
||||
policy_kwargs=dict(log_std_init=-2, use_expln=use_expln, net_arch=[64]),
|
||||
**kwargs,
|
||||
)
|
||||
model.learn(total_timesteps=int(300), eval_freq=250)
|
||||
model.learn(total_timesteps=255, eval_freq=250)
|
||||
model.policy.reset_noise()
|
||||
if model_class == SAC:
|
||||
model.policy.actor.get_std()
|
||||
|
|
|
|||
Loading…
Reference in a new issue