mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-07-03 03:59:13 +00:00
Refactor: enable sde net arch for TD3 and SAC
This commit is contained in:
parent
a2a8bbdf11
commit
4e39a0627c
4 changed files with 141 additions and 48 deletions
|
|
@ -93,9 +93,30 @@ def create_mlp(input_dim, output_dim, net_arch,
|
|||
return modules
|
||||
|
||||
|
||||
class BaseNetwork(nn.Module):
|
||||
"""docstring for BaseNetwork."""
|
||||
def create_sde_feature_extractor(features_dim, sde_net_arch, activation_fn):
    """
    Build the network that extracts the features used by the
    State Dependent Exploration (SDE) noise.

    :param features_dim: (int) Input dimension given to ``create_mlp``
    :param sde_net_arch: ([int]) Architecture of the extractor, forwarded
        to ``create_mlp``. Its last entry is the dimension of the extracted
        features. An empty list means the inputs are used as features.
    :param activation_fn: (nn.Module) Activation function forwarded to
        ``create_mlp``; ignored when ``sde_net_arch`` is empty
    :return: (nn.Sequential, int) The feature extractor and the dimension
        of the features it outputs
    """
    has_layers = len(sde_net_arch) > 0
    if has_layers:
        sde_activation = activation_fn
        latent_sde_dim = sde_net_arch[-1]
    else:
        # Special case: the states are used directly as features,
        # so no activation function is applied
        sde_activation = None
        latent_sde_dim = features_dim

    layers = create_mlp(features_dim, -1, sde_net_arch,
                        activation_fn=sde_activation, squash_out=False)
    return nn.Sequential(*layers), latent_sde_dim
|
||||
|
||||
|
||||
class BaseNetwork(nn.Module):
|
||||
"""
|
||||
Abstract class for the different networks (actor/critic)
|
||||
that implements two helpers for using CEM with their weights.
|
||||
"""
|
||||
def __init__(self):
|
||||
super(BaseNetwork, self).__init__()
|
||||
|
||||
|
|
|
|||
|
|
@ -25,6 +25,9 @@ class PPOPolicy(BasePolicy):
|
|||
:param log_std_init: (float) Initial value for the log standard deviation
|
||||
:param full_std: (bool) Whether to use (n_features x n_actions) parameters
|
||||
for the std instead of only (n_features,) when using SDE
|
||||
:param sde_net_arch: ([int]) Network architecture for extracting features
|
||||
when using SDE. If None, the latent features from the policy will be used.
|
||||
Pass an empty list to use the states as features.
|
||||
"""
|
||||
def __init__(self, observation_space, action_space,
|
||||
learning_rate, net_arch=None, device='cpu',
|
||||
|
|
@ -34,7 +37,6 @@ class PPOPolicy(BasePolicy):
|
|||
super(PPOPolicy, self).__init__(observation_space, action_space, device)
|
||||
self.obs_dim = self.observation_space.shape[0]
|
||||
|
||||
|
||||
# Default network architecture, from stable-baselines
|
||||
if net_arch is None:
|
||||
net_arch = [dict(pi=[64, 64], vf=[64, 64])]
|
||||
|
|
@ -87,13 +89,8 @@ class PPOPolicy(BasePolicy):
|
|||
|
||||
# Separate feature extractor for SDE
|
||||
if self.sde_net_arch is not None:
|
||||
# Special case: when using states as features (i.e. sde_net_arch is an empty list)
|
||||
# don't use any activation function
|
||||
sde_activation = self.activation_fn if len(self.sde_net_arch) > 0 else None
|
||||
latent_sde = create_mlp(self.features_dim, -1, self.sde_net_arch,
|
||||
activation_fn=sde_activation, squash_out=False)
|
||||
self.sde_feature_extractor = nn.Sequential(*latent_sde)
|
||||
latent_sde_dim = self.sde_net_arch[-1] if len(self.sde_net_arch) > 0 else self.features_dim
|
||||
self.sde_feature_extractor, latent_sde_dim = create_sde_feature_extractor(self.features_dim, self.sde_net_arch,
|
||||
self.activation_fn)
|
||||
|
||||
if isinstance(self.action_dist, DiagGaussianDistribution):
|
||||
self.action_net, self.log_std = self.action_dist.proba_distribution_net(latent_dim=self.mlp_extractor.latent_dim_pi,
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import torch as th
|
||||
import torch.nn as nn
|
||||
|
||||
from torchy_baselines.common.policies import BasePolicy, register_policy, create_mlp, BaseNetwork
|
||||
from torchy_baselines.common.policies import BasePolicy, register_policy, create_mlp, BaseNetwork, create_sde_feature_extractor
|
||||
from torchy_baselines.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution
|
||||
|
||||
# CAP the standard deviation of the actor
|
||||
|
|
@ -21,20 +21,30 @@ class Actor(BaseNetwork):
|
|||
:param log_std_init: (float) Initial value for the log standard deviation
|
||||
:param full_std: (bool) Whether to use (n_features x n_actions) parameters
|
||||
for the std instead of only (n_features,) when using SDE.
|
||||
:param sde_net_arch: ([int]) Network architecture for extracting features
|
||||
when using SDE. If None, the latent features from the policy will be used.
|
||||
Pass an empty list to use the states as features.
|
||||
"""
|
||||
def __init__(self, obs_dim, action_dim, net_arch, activation_fn=nn.ReLU,
|
||||
use_sde=False, log_std_init=-3, full_std=True):
|
||||
use_sde=False, log_std_init=-3, full_std=True, sde_net_arch=None):
|
||||
super(Actor, self).__init__()
|
||||
|
||||
actor_net = create_mlp(obs_dim, -1, net_arch, activation_fn)
|
||||
self.actor_net = nn.Sequential(*actor_net)
|
||||
latent_pi_net = create_mlp(obs_dim, -1, net_arch, activation_fn)
|
||||
self.latent_pi = nn.Sequential(*latent_pi_net)
|
||||
self.use_sde = use_sde
|
||||
self.sde_feature_extractor = None
|
||||
|
||||
if self.use_sde:
|
||||
latent_sde_dim = net_arch[-1]
|
||||
# Separate feature extractor for SDE
|
||||
if sde_net_arch is not None:
|
||||
self.sde_feature_extractor, latent_sde_dim = create_sde_feature_extractor(obs_dim, sde_net_arch, activation_fn)
|
||||
|
||||
# TODO: check for the learn_features
|
||||
self.action_dist = StateDependentNoiseDistribution(action_dim, full_std=full_std, use_expln=False,
|
||||
learn_features=True, squash_output=True)
|
||||
self.mu, self.log_std = self.action_dist.proba_distribution_net(latent_dim=net_arch[-1],
|
||||
latent_sde_dim=latent_sde_dim,
|
||||
log_std_init=log_std_init)
|
||||
else:
|
||||
self.action_dist = SquashedDiagGaussianDistribution(action_dim)
|
||||
|
|
@ -61,38 +71,55 @@ class Actor(BaseNetwork):
|
|||
"""
|
||||
self.action_dist.sample_weights(self.log_std, batch_size=batch_size)
|
||||
|
||||
def _get_latent(self, obs):
|
||||
latent_pi = self.latent_pi(obs)
|
||||
|
||||
if self.sde_feature_extractor is not None:
|
||||
latent_sde = self.sde_feature_extractor(obs)
|
||||
else:
|
||||
latent_sde = latent_pi
|
||||
return latent_pi, latent_sde
|
||||
|
||||
def get_action_dist_params(self, obs):
|
||||
latent = self.actor_net(obs)
|
||||
latent_pi, latent_sde = self._get_latent(obs)
|
||||
|
||||
if self.use_sde:
|
||||
mean_actions, log_std = self.mu(latent), self.log_std
|
||||
mean_actions, log_std = self.mu(latent_pi), self.log_std
|
||||
else:
|
||||
mean_actions, log_std = self.mu(latent), self.log_std(latent)
|
||||
mean_actions, log_std = self.mu(latent_pi), self.log_std(latent_pi)
|
||||
# Original Implementation to cap the standard deviation
|
||||
log_std = th.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX)
|
||||
return mean_actions, log_std, latent
|
||||
return mean_actions, log_std, latent_sde
|
||||
|
||||
def forward(self, obs, deterministic=False):
|
||||
mean_actions, log_std, latent = self.get_action_dist_params(obs)
|
||||
mean_actions, log_std, latent_sde = self.get_action_dist_params(obs)
|
||||
if self.use_sde:
|
||||
# Note the action is squashed
|
||||
action, _ = self.action_dist.proba_distribution(mean_actions, log_std, latent, deterministic=deterministic)
|
||||
action, _ = self.action_dist.proba_distribution(mean_actions, log_std, latent_sde, deterministic=deterministic)
|
||||
else:
|
||||
# Note the action is squashed
|
||||
action, _ = self.action_dist.proba_distribution(mean_actions, log_std, deterministic=deterministic)
|
||||
return action
|
||||
|
||||
def action_log_prob(self, obs):
|
||||
mean_actions, log_std, latent = self.get_action_dist_params(obs)
|
||||
mean_actions, log_std, latent_sde = self.get_action_dist_params(obs)
|
||||
|
||||
if self.use_sde:
|
||||
action, log_prob = self.action_dist.log_prob_from_params(mean_actions, self.log_std, latent)
|
||||
action, log_prob = self.action_dist.log_prob_from_params(mean_actions, self.log_std, latent_sde)
|
||||
else:
|
||||
action, log_prob = self.action_dist.log_prob_from_params(mean_actions, log_std)
|
||||
return action, log_prob
|
||||
|
||||
|
||||
class Critic(BaseNetwork):
|
||||
"""
|
||||
Critic network (q-value function) for SAC.
|
||||
|
||||
:param obs_dim: (int) Dimension of the observation
|
||||
:param action_dim: (int) Dimension of the action space
|
||||
:param net_arch: ([int]) Network architecture
|
||||
:param activation_fn: (nn.Module) Activation function
|
||||
"""
|
||||
def __init__(self, obs_dim, action_dim,
|
||||
net_arch, activation_fn=nn.ReLU):
|
||||
super(Critic, self).__init__()
|
||||
|
|
@ -114,9 +141,25 @@ class Critic(BaseNetwork):
|
|||
|
||||
|
||||
class SACPolicy(BasePolicy):
|
||||
"""
|
||||
Policy class (with both actor and critic) for SAC.
|
||||
|
||||
:param observation_space: (gym.spaces.Space) Observation space
|
||||
:param action_space: (gym.spaces.Space) Action space
|
||||
:param learning_rate: (callable) Learning rate schedule (could be constant)
|
||||
:param net_arch: ([int or dict]) The specification of the policy and value networks.
|
||||
:param device: (str or th.device) Device on which the code should run.
|
||||
:param activation_fn: (nn.Module) Activation function
|
||||
:param use_sde: (bool) Whether to use State Dependent Exploration or not
|
||||
:param log_std_init: (float) Initial value for the log standard deviation
|
||||
:param sde_net_arch: ([int]) Network architecture for extracting features
|
||||
when using SDE. If None, the latent features from the policy will be used.
|
||||
Pass an empty list to use the states as features.
|
||||
"""
|
||||
def __init__(self, observation_space, action_space,
|
||||
learning_rate, net_arch=None, device='cpu',
|
||||
activation_fn=nn.ReLU, use_sde=False, log_std_init=-3):
|
||||
activation_fn=nn.ReLU, use_sde=False,
|
||||
log_std_init=-3, sde_net_arch=None):
|
||||
super(SACPolicy, self).__init__(observation_space, action_space, device)
|
||||
|
||||
if net_arch is None:
|
||||
|
|
@ -133,8 +176,12 @@ class SACPolicy(BasePolicy):
|
|||
'activation_fn': self.activation_fn
|
||||
}
|
||||
self.actor_kwargs = self.net_args.copy()
|
||||
self.actor_kwargs['use_sde'] = use_sde
|
||||
self.actor_kwargs['log_std_init'] = log_std_init
|
||||
sde_kwargs = {
|
||||
'use_sde': use_sde,
|
||||
'log_std_init': log_std_init,
|
||||
'sde_net_arch': sde_net_arch
|
||||
}
|
||||
self.actor_kwargs.update(sde_kwargs)
|
||||
self.actor, self.actor_target = None, None
|
||||
self.critic, self.critic_target = None, None
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import torch as th
|
||||
import torch.nn as nn
|
||||
|
||||
from torchy_baselines.common.policies import BasePolicy, register_policy, create_mlp, BaseNetwork
|
||||
from torchy_baselines.common.policies import BasePolicy, register_policy, create_mlp, BaseNetwork, create_sde_feature_extractor
|
||||
from torchy_baselines.common.distributions import StateDependentNoiseDistribution
|
||||
|
||||
|
||||
|
|
@ -19,10 +19,13 @@ class Actor(BaseNetwork):
|
|||
:param lr_sde: (float) Learning rate for the standard deviation of the noise
|
||||
:param full_std: (bool) Whether to use (n_features x n_actions) parameters
|
||||
for the std instead of only (n_features,) when using SDE.
|
||||
:param sde_net_arch: ([int]) Network architecture for extracting features
|
||||
when using SDE. If None, the latent features from the policy will be used.
|
||||
Pass an empty list to use the states as features.
|
||||
"""
|
||||
def __init__(self, obs_dim, action_dim, net_arch, activation_fn=nn.ReLU,
|
||||
use_sde=False, log_std_init=-3, clip_noise=None,
|
||||
lr_sde=3e-4, full_std=False):
|
||||
lr_sde=3e-4, full_std=False, sde_net_arch=None):
|
||||
super(Actor, self).__init__()
|
||||
|
||||
self.latent_pi, self.log_std = None, None
|
||||
|
|
@ -30,23 +33,32 @@ class Actor(BaseNetwork):
|
|||
self.use_sde, self.sde_optimizer = use_sde, None
|
||||
self.action_dim = action_dim
|
||||
self.full_std = full_std
|
||||
self.sde_feature_extractor = None
|
||||
|
||||
if use_sde:
|
||||
latent_pi = create_mlp(obs_dim, -1, net_arch, activation_fn, squash_out=False)
|
||||
self.latent_pi = nn.Sequential(*latent_pi)
|
||||
latent_pi_net = create_mlp(obs_dim, -1, net_arch, activation_fn, squash_out=False)
|
||||
self.latent_pi = nn.Sequential(*latent_pi_net)
|
||||
latent_sde_dim = net_arch[-1]
|
||||
learn_features = sde_net_arch is not None
|
||||
|
||||
# Separate feature extractor for SDE
|
||||
if sde_net_arch is not None:
|
||||
self.sde_feature_extractor, latent_sde_dim = create_sde_feature_extractor(obs_dim, sde_net_arch, activation_fn)
|
||||
|
||||
# Create state dependent noise matrix (SDE)
|
||||
self.action_dist = StateDependentNoiseDistribution(action_dim, full_std=full_std, use_expln=False,
|
||||
squash_output=False)
|
||||
squash_output=False, learn_features=learn_features)
|
||||
action_net, self.log_std = self.action_dist.proba_distribution_net(latent_dim=net_arch[-1],
|
||||
latent_sde_dim=latent_sde_dim,
|
||||
log_std_init=log_std_init)
|
||||
# Squash output
|
||||
self.actor_net = nn.Sequential(action_net, nn.Tanh())
|
||||
self.mu = nn.Sequential(action_net, nn.Tanh())
|
||||
self.clip_noise = clip_noise
|
||||
self.sde_optimizer = th.optim.Adam([self.log_std], lr=lr_sde)
|
||||
self.reset_noise()
|
||||
else:
|
||||
actor_net = create_mlp(obs_dim, action_dim, net_arch, activation_fn, squash_out=True)
|
||||
self.actor_net = nn.Sequential(*actor_net)
|
||||
self.mu = nn.Sequential(*actor_net)
|
||||
|
||||
def get_std(self):
|
||||
"""
|
||||
|
|
@ -60,9 +72,18 @@ class Actor(BaseNetwork):
|
|||
"""
|
||||
return self.action_dist.get_std(self.log_std)
|
||||
|
||||
def _get_action_dist_from_latent(self, latent_pi):
|
||||
mean_actions = self.actor_net(latent_pi)
|
||||
return self.action_dist.proba_distribution(mean_actions, self.log_std, latent_pi)
|
||||
def _get_action_dist_from_latent(self, latent_pi, latent_sde):
|
||||
mean_actions = self.mu(latent_pi)
|
||||
return self.action_dist.proba_distribution(mean_actions, self.log_std, latent_sde)
|
||||
|
||||
def _get_latent(self, obs):
|
||||
latent_pi = self.latent_pi(obs)
|
||||
|
||||
if self.sde_feature_extractor is not None:
|
||||
latent_sde = self.sde_feature_extractor(obs)
|
||||
else:
|
||||
latent_sde = latent_pi
|
||||
return latent_pi, latent_sde
|
||||
|
||||
def evaluate_actions(self, obs, action):
|
||||
"""
|
||||
|
|
@ -75,9 +96,8 @@ class Actor(BaseNetwork):
|
|||
:return: (th.Tensor, th.Tensor) log likelihood of taking those actions
|
||||
and entropy of the action distribution.
|
||||
"""
|
||||
with th.no_grad():
|
||||
latent_pi = self.latent_pi(obs)
|
||||
_, distribution = self._get_action_dist_from_latent(latent_pi)
|
||||
latent_pi, latent_sde = self._get_latent(obs)
|
||||
_, distribution = self._get_action_dist_from_latent(latent_pi, latent_sde)
|
||||
log_prob = distribution.log_prob(action)
|
||||
# value = self.value_net(latent_vf)
|
||||
return log_prob, distribution.entropy()
|
||||
|
|
@ -90,20 +110,21 @@ class Actor(BaseNetwork):
|
|||
|
||||
def forward(self, obs, deterministic=True):
|
||||
if self.use_sde:
|
||||
latent_pi = self.latent_pi(obs)
|
||||
latent_pi, latent_sde = self._get_latent(obs)
|
||||
if deterministic:
|
||||
return self.actor_net(latent_pi)
|
||||
noise = self.action_dist.get_noise(latent_pi)
|
||||
return self.mu(latent_pi)
|
||||
|
||||
noise = self.action_dist.get_noise(latent_sde)
|
||||
if self.clip_noise is not None:
|
||||
noise = th.clamp(noise, -self.clip_noise, self.clip_noise)
|
||||
# TODO: Replace with squashing -> need to account for that in the sde update
|
||||
# -> set squash_out=True in the action_dist?
|
||||
# NOTE: the clipping is done in the rollout for now
|
||||
return self.actor_net(latent_pi) + noise
|
||||
return self.mu(latent_pi) + noise
|
||||
# action, _ = self._get_action_dist_from_latent(latent_pi)
|
||||
# return action
|
||||
else:
|
||||
return self.actor_net(obs)
|
||||
return self.mu(obs)
|
||||
|
||||
|
||||
class Critic(BaseNetwork):
|
||||
|
|
@ -148,11 +169,14 @@ class TD3Policy(BasePolicy):
|
|||
:param activation_fn: (nn.Module) Activation function
|
||||
:param use_sde: (bool) Whether to use State Dependent Exploration or not
|
||||
:param log_std_init: (float) Initial value for the log standard deviation
|
||||
:param sde_net_arch: ([int]) Network architecture for extracting features
|
||||
when using SDE. If None, the latent features from the policy will be used.
|
||||
Pass an empty list to use the states as features.
|
||||
"""
|
||||
def __init__(self, observation_space, action_space,
|
||||
learning_rate, net_arch=None, device='cpu',
|
||||
activation_fn=nn.ReLU, use_sde=False, log_std_init=-3,
|
||||
clip_noise=None, lr_sde=3e-4):
|
||||
clip_noise=None, lr_sde=3e-4, sde_net_arch=None):
|
||||
super(TD3Policy, self).__init__(observation_space, action_space, device)
|
||||
|
||||
# Default network architecture, from the original paper
|
||||
|
|
@ -170,10 +194,14 @@ class TD3Policy(BasePolicy):
|
|||
'activation_fn': self.activation_fn
|
||||
}
|
||||
self.actor_kwargs = self.net_args.copy()
|
||||
self.actor_kwargs['use_sde'] = use_sde
|
||||
self.actor_kwargs['log_std_init'] = log_std_init
|
||||
self.actor_kwargs['clip_noise'] = clip_noise
|
||||
self.actor_kwargs['lr_sde'] = lr_sde
|
||||
sde_kwargs = {
|
||||
'use_sde': use_sde,
|
||||
'log_std_init': log_std_init,
|
||||
'clip_noise': clip_noise,
|
||||
'lr_sde': lr_sde,
|
||||
'sde_net_arch': sde_net_arch
|
||||
}
|
||||
self.actor_kwargs.update(sde_kwargs)
|
||||
|
||||
self.actor, self.actor_target = None, None
|
||||
self.critic, self.critic_target = None, None
|
||||
|
|
|
|||
Loading…
Reference in a new issue