diff --git a/torchy_baselines/a2c/a2c.py b/torchy_baselines/a2c/a2c.py index 51eb846..d355e09 100644 --- a/torchy_baselines/a2c/a2c.py +++ b/torchy_baselines/a2c/a2c.py @@ -112,7 +112,6 @@ class A2C(PPO): # Optimization step self.policy.optimizer.zero_grad() loss.backward() - # Clip grad norm th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm) self.policy.optimizer.step() diff --git a/torchy_baselines/common/distributions.py b/torchy_baselines/common/distributions.py index a62ede8..72bf7b2 100644 --- a/torchy_baselines/common/distributions.py +++ b/torchy_baselines/common/distributions.py @@ -2,6 +2,7 @@ import numpy as np import torch as th import torch.nn as nn from torch.distributions import Normal, Categorical +import torch.nn.functional as F from gym import spaces class Distribution(object): @@ -168,20 +169,20 @@ class CategoricalDistribution(Distribution): class StateDependentNoiseDistribution(Distribution): - def __init__(self, features_dim, action_dim, use_expln=False, - squash_output=True, epsilon=1e-6): + def __init__(self, action_dim, use_expln=False, + squash_output=False, epsilon=1e-6): super(StateDependentNoiseDistribution, self).__init__() self.distribution = None self.action_dim = action_dim - self.features_dim = features_dim self.mean_actions = None self.log_std = None self.weights_dist = None - self.noise_weights = None - self.gaussian_action = None + self.exploration_mat = None self.use_expln = use_expln - self.squash_output = squash_output - self.epsilon = epsilon + if squash_output: + self.bijector = TanhBijector(epsilon) + else: + self.bijector = None def get_std(self, log_std): if self.use_expln: @@ -196,71 +197,103 @@ class StateDependentNoiseDistribution(Distribution): def sample_weights(self, log_std): self.weights_dist = Normal(th.zeros_like(log_std), self.get_std(log_std)) - self.noise_weights = self.weights_dist.rsample() + self.exploration_mat = self.weights_dist.rsample() def proba_distribution_net(self, latent_dim, log_std_init=-1): mean_actions = nn.Linear(latent_dim, self.action_dim) # TODO: log_std_init depending on the number of layers? - log_std = nn.Parameter(th.ones(self.features_dim, self.action_dim) * log_std_init) + log_std = nn.Parameter(th.ones(latent_dim, self.action_dim) * log_std_init) self.sample_weights(log_std) return mean_actions, log_std - def proba_distribution(self, mean_actions, log_std, observations, deterministic=False): - variance = th.mm(observations ** 2, self.get_std(log_std) ** 2) + def proba_distribution(self, mean_actions, log_std, latent_pi, deterministic=False): + # TODO: try without detach + variance = th.mm(latent_pi.detach() ** 2, self.get_std(log_std) ** 2) self.distribution = Normal(mean_actions, th.sqrt(variance)) if deterministic: action = self.mode() else: - action = self.sample(observations) + action = self.sample(latent_pi) return action, self def mode(self): - self.gaussian_action = self.distribution.mean - if self.squash_output: - return th.tanh(self.gaussian_action) - return self.gaussian_action + action = self.distribution.mean + if self.bijector is not None: + return self.bijector.forward(action) + return action - def sample(self, observations): - noise = th.mm(observations, self.noise_weights) - self.gaussian_action = self.distribution.mean + noise - if self.squash_output: - return th.tanh(self.gaussian_action) - return self.gaussian_action + def sample(self, latent_pi): + noise = th.mm(latent_pi.detach(), self.exploration_mat) + action = self.distribution.mean + noise + if self.bijector is not None: + return self.bijector.forward(action) + return action def entropy(self): # TODO: account for the squashing? return self.distribution.entropy() - def log_prob_from_params(self, mean_actions, log_std, observations): - action, _ = self.proba_distribution(mean_actions, log_std, observations) + def log_prob_from_params(self, mean_actions, log_std, latent_pi): + action, _ = self.proba_distribution(mean_actions, log_std, latent_pi) log_prob = self.log_prob(action) return action, log_prob def log_prob(self, action): - if self.squash_output: - gaussian_action = self.gaussian_action + if self.bijector is not None: + gaussian_action = self.bijector.inverse(action) else: gaussian_action = action # log likelihood for a gaussian log_prob = self.distribution.log_prob(gaussian_action) - # log_prob = self.distribution.log_prob(action) + if len(log_prob.shape) > 1: log_prob = log_prob.sum(axis=1) else: log_prob = log_prob.sum() - if self.squash_output: + + if self.bijector is not None: # Squash correction (from original SAC implementation) - log_prob -= th.sum(th.log(1 - action ** 2 + self.epsilon), dim=1) + log_prob -= th.sum(self.bijector.log_prob_correction(gaussian_action), dim=1) return log_prob -def make_proba_distribution(action_space, features_dim=None, use_sde=False): +class TanhBijector(object): + def __init__(self, epsilon=1e-6): + super(TanhBijector, self).__init__() + self.epsilon = epsilon + + def forward(self, x): + return th.tanh(x) + + def inverse(self, action): + """ + Inverse tanh. + + From https://github.com/tensorflow/agents: + 0.99999997 is the maximum value such that atanh(x) is valid for both + float32 and float64 + + :param action: (th.Tensor) + :return: (th.Tensor) + """ + # Inverse tanh + # Naive implementation (not stable): 0.5 * torch.log((1 + x ) / (1 - x)) + # We use numpy to avoid numerical instability + # Note: Using numpy, we do not keep the gradient + clipped_action = np.clip(action.cpu().numpy(), -0.99999997, 0.99999997) + return th.from_numpy(np.arctanh(clipped_action)).to(action.device) + + def log_prob_correction(self, x): + # Squash correction (from original SAC implementation) + return th.log(1 - th.tanh(x) ** 2 + self.epsilon) + + +def make_proba_distribution(action_space, use_sde=False): """ Return an instance of Distribution for the correct type of action space :param action_space: (Gym Space) the input action space - :param feature_dim: (int) Dimension of the feature vector :param use_sde: (bool) Force the use of StateDependentNoiseDistribution instead of DiagGaussianDistribution :return: (Distribution) the approriate Distribution object @@ -268,7 +301,7 @@ def make_proba_distribution(action_space, features_dim=None, use_sde=False): if isinstance(action_space, spaces.Box): assert len(action_space.shape) == 1, "Error: the action space must be a vector" if use_sde: - return StateDependentNoiseDistribution(features_dim, action_space.shape[0]) + return StateDependentNoiseDistribution(action_space.shape[0]) return DiagGaussianDistribution(action_space.shape[0]) elif isinstance(action_space, spaces.Discrete): return CategoricalDistribution(action_space.n) diff --git a/torchy_baselines/ppo/policies.py b/torchy_baselines/ppo/policies.py index 25d0bf3..1e3b25c 100644 --- a/torchy_baselines/ppo/policies.py +++ b/torchy_baselines/ppo/policies.py @@ -125,7 +125,7 @@ class PPOPolicy(BasePolicy): self.features_dim = self.obs_dim self.log_std_init = log_std_init # Action distribution - self.action_dist = make_proba_distribution(action_space, self.features_dim, use_sde=use_sde) + self.action_dist = make_proba_distribution(action_space, use_sde=use_sde) self._build(learning_rate) @@ -161,15 +161,15 @@ class PPOPolicy(BasePolicy): obs = th.FloatTensor(obs).to(self.device) latent_pi, latent_vf = self._get_latent(obs) value = self.value_net(latent_vf) - action, action_distribution = self._get_action_dist_from_latent(latent_pi, obs, deterministic=deterministic) + action, action_distribution = self._get_action_dist_from_latent(latent_pi, deterministic=deterministic) log_prob = action_distribution.log_prob(action) return action, value, log_prob def _get_latent(self, obs): return self.mlp_extractor(self.features_extractor(obs)) - def _get_action_dist_from_latent(self, latent, obs, deterministic=False): - mean_actions = self.action_net(latent) + def _get_action_dist_from_latent(self, latent_pi, deterministic=False): + mean_actions = self.action_net(latent_pi) if isinstance(self.action_dist, DiagGaussianDistribution): return self.action_dist.proba_distribution(mean_actions, self.log_std, deterministic=deterministic) @@ -178,16 +178,16 @@ class PPOPolicy(BasePolicy): return self.action_dist.proba_distribution(mean_actions, deterministic=deterministic) elif isinstance(self.action_dist, StateDependentNoiseDistribution): - return self.action_dist.proba_distribution(mean_actions, self.log_std, obs, deterministic=deterministic) + return self.action_dist.proba_distribution(mean_actions, self.log_std, latent_pi, deterministic=deterministic) def actor_forward(self, obs, deterministic=False): latent_pi, _ = self._get_latent(obs) - action, _ = self._get_action_dist_from_latent(latent_pi, obs, deterministic=deterministic) + action, _ = self._get_action_dist_from_latent(latent_pi, deterministic=deterministic) return action.detach().cpu().numpy() - def get_policy_stats(self, obs, action): + def get_policy_stats(self, obs, action, deterministic=False): latent_pi, latent_vf = self._get_latent(obs) - _, action_distribution = self._get_action_dist_from_latent(latent_pi, obs) + _, action_distribution = self._get_action_dist_from_latent(latent_pi, deterministic=deterministic) log_prob = action_distribution.log_prob(action) value = self.value_net(latent_vf) return value, log_prob, action_distribution.entropy()