Refactor: enable sde net arch for TD3 and SAC

This commit is contained in:
Antonin Raffin 2019-12-02 14:06:17 +01:00
parent a2a8bbdf11
commit 4e39a0627c
4 changed files with 141 additions and 48 deletions

View file

@ -93,9 +93,30 @@ def create_mlp(input_dim, output_dim, net_arch,
return modules
class BaseNetwork(nn.Module):
"""docstring for BaseNetwork."""
def create_sde_feature_extractor(features_dim, sde_net_arch, activation_fn):
"""
Create the neural network that will be used to extract features
for the SDE.
:param features_dim: (int)
:param sde_net_arch: ([int])
:param activation_fn: (nn.Module)
:return: (nn.Sequential, int)
"""
# Special case: when using states as features (i.e. sde_net_arch is an empty list)
# don't use any activation function
sde_activation = activation_fn if len(sde_net_arch) > 0 else None
latent_sde_net = create_mlp(features_dim, -1, sde_net_arch, activation_fn=sde_activation, squash_out=False)
latent_sde_dim = sde_net_arch[-1] if len(sde_net_arch) > 0 else features_dim
sde_feature_extractor = nn.Sequential(*latent_sde_net)
return sde_feature_extractor, latent_sde_dim
class BaseNetwork(nn.Module):
"""
Abstract class for the different networks (actor/critic)
that implements two helpers for using CEM with their weights.
"""
def __init__(self):
super(BaseNetwork, self).__init__()

View file

@ -25,6 +25,9 @@ class PPOPolicy(BasePolicy):
:param log_std_init: (float) Initial value for the log standard deviation
:param full_std: (bool) Whether to use (n_features x n_actions) parameters
for the std instead of only (n_features,) when using SDE
:param sde_net_arch: ([int]) Network architecture for extracting features
when using SDE. If None, the latent features from the policy will be used.
Pass an empty list to use the states as features.
"""
def __init__(self, observation_space, action_space,
learning_rate, net_arch=None, device='cpu',
@ -34,7 +37,6 @@ class PPOPolicy(BasePolicy):
super(PPOPolicy, self).__init__(observation_space, action_space, device)
self.obs_dim = self.observation_space.shape[0]
# Default network architecture, from stable-baselines
if net_arch is None:
net_arch = [dict(pi=[64, 64], vf=[64, 64])]
@ -87,13 +89,8 @@ class PPOPolicy(BasePolicy):
# Separate feature extractor for SDE
if self.sde_net_arch is not None:
# Special case: when using states as features (i.e. sde_net_arch is an empty list)
# don't use any activation function
sde_activation = self.activation_fn if len(self.sde_net_arch) > 0 else None
latent_sde = create_mlp(self.features_dim, -1, self.sde_net_arch,
activation_fn=sde_activation, squash_out=False)
self.sde_feature_extractor = nn.Sequential(*latent_sde)
latent_sde_dim = self.sde_net_arch[-1] if len(self.sde_net_arch) > 0 else self.features_dim
self.sde_feature_extractor, latent_sde_dim = create_sde_feature_extractor(self.features_dim, self.sde_net_arch,
self.activation_fn)
if isinstance(self.action_dist, DiagGaussianDistribution):
self.action_net, self.log_std = self.action_dist.proba_distribution_net(latent_dim=self.mlp_extractor.latent_dim_pi,

View file

@ -1,7 +1,7 @@
import torch as th
import torch.nn as nn
from torchy_baselines.common.policies import BasePolicy, register_policy, create_mlp, BaseNetwork
from torchy_baselines.common.policies import BasePolicy, register_policy, create_mlp, BaseNetwork, create_sde_feature_extractor
from torchy_baselines.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution
# CAP the standard deviation of the actor
@ -21,20 +21,30 @@ class Actor(BaseNetwork):
:param log_std_init: (float) Initial value for the log standard deviation
:param full_std: (bool) Whether to use (n_features x n_actions) parameters
for the std instead of only (n_features,) when using SDE.
:param sde_net_arch: ([int]) Network architecture for extracting features
when using SDE. If None, the latent features from the policy will be used.
Pass an empty list to use the states as features.
"""
def __init__(self, obs_dim, action_dim, net_arch, activation_fn=nn.ReLU,
use_sde=False, log_std_init=-3, full_std=True):
use_sde=False, log_std_init=-3, full_std=True, sde_net_arch=None):
super(Actor, self).__init__()
actor_net = create_mlp(obs_dim, -1, net_arch, activation_fn)
self.actor_net = nn.Sequential(*actor_net)
latent_pi_net = create_mlp(obs_dim, -1, net_arch, activation_fn)
self.latent_pi = nn.Sequential(*latent_pi_net)
self.use_sde = use_sde
self.sde_feature_extractor = None
if self.use_sde:
latent_sde_dim = net_arch[-1]
# Separate feature extractor for SDE
if sde_net_arch is not None:
self.sde_feature_extractor, latent_sde_dim = create_sde_feature_extractor(obs_dim, sde_net_arch, activation_fn)
# TODO: check for the learn_features
self.action_dist = StateDependentNoiseDistribution(action_dim, full_std=full_std, use_expln=False,
learn_features=True, squash_output=True)
self.mu, self.log_std = self.action_dist.proba_distribution_net(latent_dim=net_arch[-1],
latent_sde_dim=latent_sde_dim,
log_std_init=log_std_init)
else:
self.action_dist = SquashedDiagGaussianDistribution(action_dim)
@ -61,38 +71,55 @@ class Actor(BaseNetwork):
"""
self.action_dist.sample_weights(self.log_std, batch_size=batch_size)
def _get_latent(self, obs):
latent_pi = self.latent_pi(obs)
if self.sde_feature_extractor is not None:
latent_sde = self.sde_feature_extractor(obs)
else:
latent_sde = latent_pi
return latent_pi, latent_sde
def get_action_dist_params(self, obs):
latent = self.actor_net(obs)
latent_pi, latent_sde = self._get_latent(obs)
if self.use_sde:
mean_actions, log_std = self.mu(latent), self.log_std
mean_actions, log_std = self.mu(latent_pi), self.log_std
else:
mean_actions, log_std = self.mu(latent), self.log_std(latent)
mean_actions, log_std = self.mu(latent_pi), self.log_std(latent_pi)
# Original Implementation to cap the standard deviation
log_std = th.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX)
return mean_actions, log_std, latent
return mean_actions, log_std, latent_sde
def forward(self, obs, deterministic=False):
mean_actions, log_std, latent = self.get_action_dist_params(obs)
mean_actions, log_std, latent_sde = self.get_action_dist_params(obs)
if self.use_sde:
# Note the action is squashed
action, _ = self.action_dist.proba_distribution(mean_actions, log_std, latent, deterministic=deterministic)
action, _ = self.action_dist.proba_distribution(mean_actions, log_std, latent_sde, deterministic=deterministic)
else:
# Note the action is squashed
action, _ = self.action_dist.proba_distribution(mean_actions, log_std, deterministic=deterministic)
return action
def action_log_prob(self, obs):
mean_actions, log_std, latent = self.get_action_dist_params(obs)
mean_actions, log_std, latent_sde = self.get_action_dist_params(obs)
if self.use_sde:
action, log_prob = self.action_dist.log_prob_from_params(mean_actions, self.log_std, latent)
action, log_prob = self.action_dist.log_prob_from_params(mean_actions, self.log_std, latent_sde)
else:
action, log_prob = self.action_dist.log_prob_from_params(mean_actions, log_std)
return action, log_prob
class Critic(BaseNetwork):
"""
Critic network (q-value function) for SAC.
:param obs_dim: (int) Dimension of the observation
:param action_dim: (int) Dimension of the action space
:param net_arch: ([int]) Network architecture
:param activation_fn: (nn.Module) Activation function
"""
def __init__(self, obs_dim, action_dim,
net_arch, activation_fn=nn.ReLU):
super(Critic, self).__init__()
@ -114,9 +141,25 @@ class Critic(BaseNetwork):
class SACPolicy(BasePolicy):
"""
Policy class (with both actor and critic) for SAC.
:param observation_space: (gym.spaces.Space) Observation space
:param action_dim: (gym.spaces.Space) Action space
:param learning_rate: (callable) Learning rate schedule (could be constant)
:param net_arch: ([int or dict]) The specification of the policy and value networks.
:param device: (str or th.device) Device on which the code should run.
:param activation_fn: (nn.Module) Activation function
:param use_sde: (bool) Whether to use State Dependent Exploration or not
:param log_std_init: (float) Initial value for the log standard deviation
:param sde_net_arch: ([int]) Network architecture for extracting features
when using SDE. If None, the latent features from the policy will be used.
Pass an empty list to use the states as features.
"""
def __init__(self, observation_space, action_space,
learning_rate, net_arch=None, device='cpu',
activation_fn=nn.ReLU, use_sde=False, log_std_init=-3):
activation_fn=nn.ReLU, use_sde=False,
log_std_init=-3, sde_net_arch=None):
super(SACPolicy, self).__init__(observation_space, action_space, device)
if net_arch is None:
@ -133,8 +176,12 @@ class SACPolicy(BasePolicy):
'activation_fn': self.activation_fn
}
self.actor_kwargs = self.net_args.copy()
self.actor_kwargs['use_sde'] = use_sde
self.actor_kwargs['log_std_init'] = log_std_init
sde_kwargs = {
'use_sde': use_sde,
'log_std_init': log_std_init,
'sde_net_arch': sde_net_arch
}
self.actor_kwargs.update(sde_kwargs)
self.actor, self.actor_target = None, None
self.critic, self.critic_target = None, None

View file

@ -1,7 +1,7 @@
import torch as th
import torch.nn as nn
from torchy_baselines.common.policies import BasePolicy, register_policy, create_mlp, BaseNetwork
from torchy_baselines.common.policies import BasePolicy, register_policy, create_mlp, BaseNetwork, create_sde_feature_extractor
from torchy_baselines.common.distributions import StateDependentNoiseDistribution
@ -19,10 +19,13 @@ class Actor(BaseNetwork):
:param lr_sde: (float) Learning rate for the standard deviation of the noise
:param full_std: (bool) Whether to use (n_features x n_actions) parameters
for the std instead of only (n_features,) when using SDE.
:param sde_net_arch: ([int]) Network architecture for extracting features
when using SDE. If None, the latent features from the policy will be used.
Pass an empty list to use the states as features.
"""
def __init__(self, obs_dim, action_dim, net_arch, activation_fn=nn.ReLU,
use_sde=False, log_std_init=-3, clip_noise=None,
lr_sde=3e-4, full_std=False):
lr_sde=3e-4, full_std=False, sde_net_arch=None):
super(Actor, self).__init__()
self.latent_pi, self.log_std = None, None
@ -30,23 +33,32 @@ class Actor(BaseNetwork):
self.use_sde, self.sde_optimizer = use_sde, None
self.action_dim = action_dim
self.full_std = full_std
self.sde_feature_extractor = None
if use_sde:
latent_pi = create_mlp(obs_dim, -1, net_arch, activation_fn, squash_out=False)
self.latent_pi = nn.Sequential(*latent_pi)
latent_pi_net = create_mlp(obs_dim, -1, net_arch, activation_fn, squash_out=False)
self.latent_pi = nn.Sequential(*latent_pi_net)
latent_sde_dim = net_arch[-1]
learn_features = sde_net_arch is not None
# Separate feature extractor for SDE
if sde_net_arch is not None:
self.sde_feature_extractor, latent_sde_dim = create_sde_feature_extractor(obs_dim, sde_net_arch, activation_fn)
# Create state dependent noise matrix (SDE)
self.action_dist = StateDependentNoiseDistribution(action_dim, full_std=full_std, use_expln=False,
squash_output=False)
squash_output=False, learn_features=learn_features)
action_net, self.log_std = self.action_dist.proba_distribution_net(latent_dim=net_arch[-1],
latent_sde_dim=latent_sde_dim,
log_std_init=log_std_init)
# Squash output
self.actor_net = nn.Sequential(action_net, nn.Tanh())
self.mu = nn.Sequential(action_net, nn.Tanh())
self.clip_noise = clip_noise
self.sde_optimizer = th.optim.Adam([self.log_std], lr=lr_sde)
self.reset_noise()
else:
actor_net = create_mlp(obs_dim, action_dim, net_arch, activation_fn, squash_out=True)
self.actor_net = nn.Sequential(*actor_net)
self.mu = nn.Sequential(*actor_net)
def get_std(self):
"""
@ -60,9 +72,18 @@ class Actor(BaseNetwork):
"""
return self.action_dist.get_std(self.log_std)
def _get_action_dist_from_latent(self, latent_pi):
mean_actions = self.actor_net(latent_pi)
return self.action_dist.proba_distribution(mean_actions, self.log_std, latent_pi)
def _get_action_dist_from_latent(self, latent_pi, latent_sde):
mean_actions = self.mu(latent_pi)
return self.action_dist.proba_distribution(mean_actions, self.log_std, latent_sde)
def _get_latent(self, obs):
latent_pi = self.latent_pi(obs)
if self.sde_feature_extractor is not None:
latent_sde = self.sde_feature_extractor(obs)
else:
latent_sde = latent_pi
return latent_pi, latent_sde
def evaluate_actions(self, obs, action):
"""
@ -75,9 +96,8 @@ class Actor(BaseNetwork):
:return: (th.Tensor, th.Tensor) log likelihood of taking those actions
and entropy of the action distribution.
"""
with th.no_grad():
latent_pi = self.latent_pi(obs)
_, distribution = self._get_action_dist_from_latent(latent_pi)
latent_pi, latent_sde = self._get_latent(obs)
_, distribution = self._get_action_dist_from_latent(latent_pi, latent_sde)
log_prob = distribution.log_prob(action)
# value = self.value_net(latent_vf)
return log_prob, distribution.entropy()
@ -90,20 +110,21 @@ class Actor(BaseNetwork):
def forward(self, obs, deterministic=True):
if self.use_sde:
latent_pi = self.latent_pi(obs)
latent_pi, latent_sde = self._get_latent(obs)
if deterministic:
return self.actor_net(latent_pi)
noise = self.action_dist.get_noise(latent_pi)
return self.mu(latent_pi)
noise = self.action_dist.get_noise(latent_sde)
if self.clip_noise is not None:
noise = th.clamp(noise, -self.clip_noise, self.clip_noise)
# TODO: Replace with squashing -> need to account for that in the sde update
# -> set squash_out=True in the action_dist?
# NOTE: the clipping is done in the rollout for now
return self.actor_net(latent_pi) + noise
return self.mu(latent_pi) + noise
# action, _ = self._get_action_dist_from_latent(latent_pi)
# return action
else:
return self.actor_net(obs)
return self.mu(obs)
class Critic(BaseNetwork):
@ -148,11 +169,14 @@ class TD3Policy(BasePolicy):
:param activation_fn: (nn.Module) Activation function
:param use_sde: (bool) Whether to use State Dependent Exploration or not
:param log_std_init: (float) Initial value for the log standard deviation
:param sde_net_arch: ([int]) Network architecture for extracting features
when using SDE. If None, the latent features from the policy will be used.
Pass an empty list to use the states as features.
"""
def __init__(self, observation_space, action_space,
learning_rate, net_arch=None, device='cpu',
activation_fn=nn.ReLU, use_sde=False, log_std_init=-3,
clip_noise=None, lr_sde=3e-4):
clip_noise=None, lr_sde=3e-4, sde_net_arch=None):
super(TD3Policy, self).__init__(observation_space, action_space, device)
# Default network architecture, from the original paper
@ -170,10 +194,14 @@ class TD3Policy(BasePolicy):
'activation_fn': self.activation_fn
}
self.actor_kwargs = self.net_args.copy()
self.actor_kwargs['use_sde'] = use_sde
self.actor_kwargs['log_std_init'] = log_std_init
self.actor_kwargs['clip_noise'] = clip_noise
self.actor_kwargs['lr_sde'] = lr_sde
sde_kwargs = {
'use_sde': use_sde,
'log_std_init': log_std_init,
'clip_noise': clip_noise,
'lr_sde': lr_sde,
'sde_net_arch': sde_net_arch
}
self.actor_kwargs.update(sde_kwargs)
self.actor, self.actor_target = None, None
self.critic, self.critic_target = None, None