diff --git a/torchy_baselines/common/policies.py b/torchy_baselines/common/policies.py index b01345a..f536624 100644 --- a/torchy_baselines/common/policies.py +++ b/torchy_baselines/common/policies.py @@ -93,9 +93,30 @@ def create_mlp(input_dim, output_dim, net_arch, return modules -class BaseNetwork(nn.Module): - """docstring for BaseNetwork.""" +def create_sde_feature_extractor(features_dim, sde_net_arch, activation_fn): + """ + Create the neural network that will be used to extract features + for the SDE. + :param features_dim: (int) Dimension of the input features + :param sde_net_arch: ([int]) Layer sizes of the extractor (an empty list uses the input features directly) + :param activation_fn: (nn.Module) Activation function (not used when sde_net_arch is empty) + :return: (nn.Sequential, int) The feature extractor and the dimension of its output + """ + # Special case: when using states as features (i.e. sde_net_arch is an empty list) + # don't use any activation function + sde_activation = activation_fn if len(sde_net_arch) > 0 else None + latent_sde_net = create_mlp(features_dim, -1, sde_net_arch, activation_fn=sde_activation, squash_out=False) + latent_sde_dim = sde_net_arch[-1] if len(sde_net_arch) > 0 else features_dim + sde_feature_extractor = nn.Sequential(*latent_sde_net) + return sde_feature_extractor, latent_sde_dim + + +class BaseNetwork(nn.Module): + """ + Abstract class for the different networks (actor/critic) + that implements two helpers for using CEM with their weights. + """ def __init__(self): super(BaseNetwork, self).__init__() diff --git a/torchy_baselines/ppo/policies.py b/torchy_baselines/ppo/policies.py index 50dbeea..9b38317 100644 --- a/torchy_baselines/ppo/policies.py +++ b/torchy_baselines/ppo/policies.py @@ -25,6 +25,9 @@ class PPOPolicy(BasePolicy): :param log_std_init: (float) Initial value for the log standard deviation :param full_std: (bool) Whether to use (n_features x n_actions) parameters for the std instead of only (n_features,) when using SDE + :param sde_net_arch: ([int]) Network architecture for extracting features + when using SDE. If None, the latent features from the policy will be used. 
+ Pass an empty list to use the states as features. """ def __init__(self, observation_space, action_space, learning_rate, net_arch=None, device='cpu', @@ -34,7 +37,6 @@ class PPOPolicy(BasePolicy): super(PPOPolicy, self).__init__(observation_space, action_space, device) self.obs_dim = self.observation_space.shape[0] - # Default network architecture, from stable-baselines if net_arch is None: net_arch = [dict(pi=[64, 64], vf=[64, 64])] @@ -87,13 +89,8 @@ class PPOPolicy(BasePolicy): # Separate feature extractor for SDE if self.sde_net_arch is not None: - # Special case: when using states as features (i.e. sde_net_arch is an empty list) - # don't use any activation function - sde_activation = self.activation_fn if len(self.sde_net_arch) > 0 else None - latent_sde = create_mlp(self.features_dim, -1, self.sde_net_arch, - activation_fn=sde_activation, squash_out=False) - self.sde_feature_extractor = nn.Sequential(*latent_sde) - latent_sde_dim = self.sde_net_arch[-1] if len(self.sde_net_arch) > 0 else self.features_dim + self.sde_feature_extractor, latent_sde_dim = create_sde_feature_extractor(self.features_dim, self.sde_net_arch, + self.activation_fn) if isinstance(self.action_dist, DiagGaussianDistribution): self.action_net, self.log_std = self.action_dist.proba_distribution_net(latent_dim=self.mlp_extractor.latent_dim_pi, diff --git a/torchy_baselines/sac/policies.py b/torchy_baselines/sac/policies.py index 0a08335..be32660 100644 --- a/torchy_baselines/sac/policies.py +++ b/torchy_baselines/sac/policies.py @@ -1,7 +1,7 @@ import torch as th import torch.nn as nn -from torchy_baselines.common.policies import BasePolicy, register_policy, create_mlp, BaseNetwork +from torchy_baselines.common.policies import BasePolicy, register_policy, create_mlp, BaseNetwork, create_sde_feature_extractor from torchy_baselines.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution # CAP the standard deviation of the actor @@ -21,20 +21,30 @@ class 
Actor(BaseNetwork): :param log_std_init: (float) Initial value for the log standard deviation :param full_std: (bool) Whether to use (n_features x n_actions) parameters for the std instead of only (n_features,) when using SDE. + :param sde_net_arch: ([int]) Network architecture for extracting features + when using SDE. If None, the latent features from the policy will be used. + Pass an empty list to use the states as features. """ def __init__(self, obs_dim, action_dim, net_arch, activation_fn=nn.ReLU, - use_sde=False, log_std_init=-3, full_std=True): + use_sde=False, log_std_init=-3, full_std=True, sde_net_arch=None): super(Actor, self).__init__() - actor_net = create_mlp(obs_dim, -1, net_arch, activation_fn) - self.actor_net = nn.Sequential(*actor_net) + latent_pi_net = create_mlp(obs_dim, -1, net_arch, activation_fn) + self.latent_pi = nn.Sequential(*latent_pi_net) self.use_sde = use_sde + self.sde_feature_extractor = None if self.use_sde: + latent_sde_dim = net_arch[-1] + # Separate feature extractor for SDE + if sde_net_arch is not None: + self.sde_feature_extractor, latent_sde_dim = create_sde_feature_extractor(obs_dim, sde_net_arch, activation_fn) + # TODO: check for the learn_features self.action_dist = StateDependentNoiseDistribution(action_dim, full_std=full_std, use_expln=False, learn_features=True, squash_output=True) self.mu, self.log_std = self.action_dist.proba_distribution_net(latent_dim=net_arch[-1], + latent_sde_dim=latent_sde_dim, log_std_init=log_std_init) else: self.action_dist = SquashedDiagGaussianDistribution(action_dim) @@ -61,38 +71,55 @@ class Actor(BaseNetwork): """ self.action_dist.sample_weights(self.log_std, batch_size=batch_size) + def _get_latent(self, obs): + latent_pi = self.latent_pi(obs) + + if self.sde_feature_extractor is not None: + latent_sde = self.sde_feature_extractor(obs) + else: + latent_sde = latent_pi + return latent_pi, latent_sde + def get_action_dist_params(self, obs): - latent = self.actor_net(obs) + latent_pi, 
latent_sde = self._get_latent(obs) if self.use_sde: - mean_actions, log_std = self.mu(latent), self.log_std + mean_actions, log_std = self.mu(latent_pi), self.log_std else: - mean_actions, log_std = self.mu(latent), self.log_std(latent) + mean_actions, log_std = self.mu(latent_pi), self.log_std(latent_pi) # Original Implementation to cap the standard deviation log_std = th.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX) - return mean_actions, log_std, latent + return mean_actions, log_std, latent_sde def forward(self, obs, deterministic=False): - mean_actions, log_std, latent = self.get_action_dist_params(obs) + mean_actions, log_std, latent_sde = self.get_action_dist_params(obs) if self.use_sde: # Note the action is squashed - action, _ = self.action_dist.proba_distribution(mean_actions, log_std, latent, deterministic=deterministic) + action, _ = self.action_dist.proba_distribution(mean_actions, log_std, latent_sde, deterministic=deterministic) else: # Note the action is squashed action, _ = self.action_dist.proba_distribution(mean_actions, log_std, deterministic=deterministic) return action def action_log_prob(self, obs): - mean_actions, log_std, latent = self.get_action_dist_params(obs) + mean_actions, log_std, latent_sde = self.get_action_dist_params(obs) if self.use_sde: - action, log_prob = self.action_dist.log_prob_from_params(mean_actions, self.log_std, latent) + action, log_prob = self.action_dist.log_prob_from_params(mean_actions, self.log_std, latent_sde) else: action, log_prob = self.action_dist.log_prob_from_params(mean_actions, log_std) return action, log_prob class Critic(BaseNetwork): + """ + Critic network (q-value function) for SAC. 
+ + :param obs_dim: (int) Dimension of the observation + :param action_dim: (int) Dimension of the action space + :param net_arch: ([int]) Network architecture + :param activation_fn: (nn.Module) Activation function + """ def __init__(self, obs_dim, action_dim, net_arch, activation_fn=nn.ReLU): super(Critic, self).__init__() @@ -114,9 +141,25 @@ class Critic(BaseNetwork): class SACPolicy(BasePolicy): + """ + Policy class (with both actor and critic) for SAC. + + :param observation_space: (gym.spaces.Space) Observation space + :param action_space: (gym.spaces.Space) Action space + :param learning_rate: (callable) Learning rate schedule (could be constant) + :param net_arch: ([int or dict]) The specification of the policy and value networks. + :param device: (str or th.device) Device on which the code should run. + :param activation_fn: (nn.Module) Activation function + :param use_sde: (bool) Whether to use State Dependent Exploration or not + :param log_std_init: (float) Initial value for the log standard deviation + :param sde_net_arch: ([int]) Network architecture for extracting features + when using SDE. If None, the latent features from the policy will be used. + Pass an empty list to use the states as features. 
+ """ def __init__(self, observation_space, action_space, learning_rate, net_arch=None, device='cpu', - activation_fn=nn.ReLU, use_sde=False, log_std_init=-3): + activation_fn=nn.ReLU, use_sde=False, + log_std_init=-3, sde_net_arch=None): super(SACPolicy, self).__init__(observation_space, action_space, device) if net_arch is None: @@ -133,8 +176,12 @@ class SACPolicy(BasePolicy): 'activation_fn': self.activation_fn } self.actor_kwargs = self.net_args.copy() - self.actor_kwargs['use_sde'] = use_sde - self.actor_kwargs['log_std_init'] = log_std_init + sde_kwargs = { + 'use_sde': use_sde, + 'log_std_init': log_std_init, + 'sde_net_arch': sde_net_arch + } + self.actor_kwargs.update(sde_kwargs) self.actor, self.actor_target = None, None self.critic, self.critic_target = None, None diff --git a/torchy_baselines/td3/policies.py b/torchy_baselines/td3/policies.py index 6872213..65f2171 100644 --- a/torchy_baselines/td3/policies.py +++ b/torchy_baselines/td3/policies.py @@ -1,7 +1,7 @@ import torch as th import torch.nn as nn -from torchy_baselines.common.policies import BasePolicy, register_policy, create_mlp, BaseNetwork +from torchy_baselines.common.policies import BasePolicy, register_policy, create_mlp, BaseNetwork, create_sde_feature_extractor from torchy_baselines.common.distributions import StateDependentNoiseDistribution @@ -19,10 +19,13 @@ class Actor(BaseNetwork): :param lr_sde: (float) Learning rate for the standard deviation of the noise :param full_std: (bool) Whether to use (n_features x n_actions) parameters for the std instead of only (n_features,) when using SDE. + :param sde_net_arch: ([int]) Network architecture for extracting features + when using SDE. If None, the latent features from the policy will be used. + Pass an empty list to use the states as features. 
""" def __init__(self, obs_dim, action_dim, net_arch, activation_fn=nn.ReLU, use_sde=False, log_std_init=-3, clip_noise=None, - lr_sde=3e-4, full_std=False): + lr_sde=3e-4, full_std=False, sde_net_arch=None): super(Actor, self).__init__() self.latent_pi, self.log_std = None, None @@ -30,23 +33,32 @@ class Actor(BaseNetwork): self.use_sde, self.sde_optimizer = use_sde, None self.action_dim = action_dim self.full_std = full_std + self.sde_feature_extractor = None if use_sde: - latent_pi = create_mlp(obs_dim, -1, net_arch, activation_fn, squash_out=False) - self.latent_pi = nn.Sequential(*latent_pi) + latent_pi_net = create_mlp(obs_dim, -1, net_arch, activation_fn, squash_out=False) + self.latent_pi = nn.Sequential(*latent_pi_net) + latent_sde_dim = net_arch[-1] + learn_features = sde_net_arch is not None + + # Separate feature extractor for SDE + if sde_net_arch is not None: + self.sde_feature_extractor, latent_sde_dim = create_sde_feature_extractor(obs_dim, sde_net_arch, activation_fn) + # Create state dependent noise matrix (SDE) self.action_dist = StateDependentNoiseDistribution(action_dim, full_std=full_std, use_expln=False, - squash_output=False) + squash_output=False, learn_features=learn_features) action_net, self.log_std = self.action_dist.proba_distribution_net(latent_dim=net_arch[-1], + latent_sde_dim=latent_sde_dim, log_std_init=log_std_init) # Squash output - self.actor_net = nn.Sequential(action_net, nn.Tanh()) + self.mu = nn.Sequential(action_net, nn.Tanh()) self.clip_noise = clip_noise self.sde_optimizer = th.optim.Adam([self.log_std], lr=lr_sde) self.reset_noise() else: actor_net = create_mlp(obs_dim, action_dim, net_arch, activation_fn, squash_out=True) - self.actor_net = nn.Sequential(*actor_net) + self.mu = nn.Sequential(*actor_net) def get_std(self): """ @@ -60,9 +72,18 @@ class Actor(BaseNetwork): """ return self.action_dist.get_std(self.log_std) - def _get_action_dist_from_latent(self, latent_pi): - mean_actions = self.actor_net(latent_pi) - 
return self.action_dist.proba_distribution(mean_actions, self.log_std, latent_pi) + def _get_action_dist_from_latent(self, latent_pi, latent_sde): + mean_actions = self.mu(latent_pi) + return self.action_dist.proba_distribution(mean_actions, self.log_std, latent_sde) + + def _get_latent(self, obs): + latent_pi = self.latent_pi(obs) + + if self.sde_feature_extractor is not None: + latent_sde = self.sde_feature_extractor(obs) + else: + latent_sde = latent_pi + return latent_pi, latent_sde def evaluate_actions(self, obs, action): """ @@ -75,9 +96,8 @@ class Actor(BaseNetwork): :return: (th.Tensor, th.Tensor) log likelihood of taking those actions and entropy of the action distribution. """ - with th.no_grad(): - latent_pi = self.latent_pi(obs) - _, distribution = self._get_action_dist_from_latent(latent_pi) + latent_pi, latent_sde = self._get_latent(obs) + _, distribution = self._get_action_dist_from_latent(latent_pi, latent_sde) log_prob = distribution.log_prob(action) # value = self.value_net(latent_vf) return log_prob, distribution.entropy() @@ -90,20 +110,21 @@ class Actor(BaseNetwork): def forward(self, obs, deterministic=True): if self.use_sde: - latent_pi = self.latent_pi(obs) + latent_pi, latent_sde = self._get_latent(obs) if deterministic: - return self.actor_net(latent_pi) - noise = self.action_dist.get_noise(latent_pi) + return self.mu(latent_pi) + + noise = self.action_dist.get_noise(latent_sde) if self.clip_noise is not None: noise = th.clamp(noise, -self.clip_noise, self.clip_noise) # TODO: Replace with squashing -> need to account for that in the sde update # -> set squash_out=True in the action_dist? 
# NOTE: the clipping is done in the rollout for now - return self.actor_net(latent_pi) + noise + return self.mu(latent_pi) + noise # action, _ = self._get_action_dist_from_latent(latent_pi) # return action else: - return self.actor_net(obs) + return self.mu(obs) class Critic(BaseNetwork): @@ -148,11 +169,14 @@ class TD3Policy(BasePolicy): :param activation_fn: (nn.Module) Activation function :param use_sde: (bool) Whether to use State Dependent Exploration or not :param log_std_init: (float) Initial value for the log standard deviation + :param sde_net_arch: ([int]) Network architecture for extracting features + when using SDE. If None, the latent features from the policy will be used. + Pass an empty list to use the states as features. """ def __init__(self, observation_space, action_space, learning_rate, net_arch=None, device='cpu', activation_fn=nn.ReLU, use_sde=False, log_std_init=-3, - clip_noise=None, lr_sde=3e-4): + clip_noise=None, lr_sde=3e-4, sde_net_arch=None): super(TD3Policy, self).__init__(observation_space, action_space, device) # Default network architecture, from the original paper @@ -170,10 +194,14 @@ class TD3Policy(BasePolicy): 'activation_fn': self.activation_fn } self.actor_kwargs = self.net_args.copy() - self.actor_kwargs['use_sde'] = use_sde - self.actor_kwargs['log_std_init'] = log_std_init - self.actor_kwargs['clip_noise'] = clip_noise - self.actor_kwargs['lr_sde'] = lr_sde + sde_kwargs = { + 'use_sde': use_sde, + 'log_std_init': log_std_init, + 'clip_noise': clip_noise, + 'lr_sde': lr_sde, + 'sde_net_arch': sde_net_arch + } + self.actor_kwargs.update(sde_kwargs) self.actor, self.actor_target = None, None self.critic, self.critic_target = None, None