class DiagGaussianDistribution(Distribution):
    """
    Gaussian distribution with a diagonal covariance matrix.

    The distribution itself is only created once ``proba_distribution`` is
    called; until then ``self.distribution`` is ``None``.

    :param action_dim: (int) Number of action dimensions
    """

    def __init__(self, action_dim):
        super(DiagGaussianDistribution, self).__init__()
        self.distribution = None
        self.action_dim = action_dim
        self.mean_actions = None
        self.log_std = None

    def proba_distribution_net(self, latent_dim, log_std_init=0.0):
        """
        Create the layer and parameter that represent the distribution:
        a linear layer mapping the latent code to the mean actions,
        and a learnable vector holding the log standard deviation.

        :param latent_dim: (int) Dimension of the last layer of the policy network
        :param log_std_init: (float) Initial value for the log standard deviation
        :return: (nn.Linear, nn.Parameter) the mean layer and the log std parameter
        """
        mean_actions = nn.Linear(latent_dim, self.action_dim)
        log_std = nn.Parameter(th.ones(self.action_dim) * log_std_init)
        return mean_actions, log_std

    def proba_distribution(self, mean_actions, log_std, deterministic=False):
        """
        Build the underlying ``Normal`` distribution and draw an action from it.

        :param mean_actions: (th.Tensor) Mean of the gaussian
        :param log_std: (th.Tensor) Log standard deviation, broadcast against ``mean_actions``
        :param deterministic: (bool) If True return the mode, otherwise a sample
        :return: (th.Tensor, DiagGaussianDistribution) the action and this object
        """
        action_std = th.ones_like(mean_actions) * log_std.exp()
        self.distribution = Normal(mean_actions, action_std)
        # mode()/sample() are looked up on self so that subclasses can
        # post-process the action (e.g. squashing)
        if deterministic:
            action = self.mode()
        else:
            action = self.sample()
        return action, self

    def mode(self):
        """
        :return: (th.Tensor) the mean of the gaussian
        """
        return self.distribution.mean

    def sample(self):
        """
        :return: (th.Tensor) a sample drawn with the reparametrization trick (``rsample``)
        """
        return self.distribution.rsample()

    def entropy(self):
        """
        :return: (th.Tensor) entropy of the underlying ``Normal``, as returned by
            ``Normal.entropy()`` (it is not summed over action dimensions here)
        """
        return self.distribution.entropy()

    def log_prob(self, action):
        """
        Log likelihood of ``action``, summed over the action dimensions.

        :param action: (th.Tensor) Either a batch of actions (2 dimensions or more,
            reduced over dimension 1) or a single action (reduced to a scalar)
        :return: (th.Tensor) the log likelihood
        """
        log_prob = self.distribution.log_prob(action)
        if len(log_prob.shape) > 1:
            log_prob = log_prob.sum(dim=1)
        else:
            log_prob = log_prob.sum()
        return log_prob


class SquashedDiagGaussianDistribution(DiagGaussianDistribution):
    """
    Diagonal gaussian distribution whose actions are squashed by a tanh,
    so that they lie in [-1, 1].

    Note: ``entropy()`` is inherited unchanged and therefore refers to the
    gaussian *before* squashing.

    :param action_dim: (int) Number of action dimensions
    :param epsilon: (float) Small value added inside the log of the squash
        correction to avoid taking the log of zero
    """

    def __init__(self, action_dim, epsilon=1e-6):
        super(SquashedDiagGaussianDistribution, self).__init__(action_dim)
        # Avoid NaN (prevents division by zero or log of zero)
        self.epsilon = epsilon

    def proba_distribution(self, mean_actions, log_std, deterministic=False):
        """
        Same as the parent method; the returned action is already squashed
        because the parent calls the overridden ``mode()`` / ``sample()``.

        :param mean_actions: (th.Tensor) Mean of the gaussian (before squashing)
        :param log_std: (th.Tensor) Log standard deviation
        :param deterministic: (bool) If True return the mode, otherwise a sample
        :return: (th.Tensor, SquashedDiagGaussianDistribution) the action and this object
        """
        action, _ = super(SquashedDiagGaussianDistribution, self).proba_distribution(
            mean_actions, log_std, deterministic)
        return action, self

    def mode(self):
        """
        :return: (th.Tensor) the squashed mean of the gaussian
        """
        # Squash the output
        return th.tanh(self.distribution.mean)

    def sample(self):
        """
        :return: (th.Tensor) a squashed reparametrized sample
        """
        return th.tanh(self.distribution.rsample())

    @staticmethod
    def _atanh(squashed):
        """
        Numerically safe inverse of tanh, computed with torch only.

        The input is clamped strictly inside (-1, 1) first: a saturated tanh
        returns exactly +/-1 in floating point, whose inverse is infinite.
        Staying in torch keeps the result on the input device, keeps it in the
        autograd graph, and works on tensors that require grad.

        :param squashed: (th.Tensor) Floating point values in [-1, 1]
        :return: (th.Tensor) the pre-tanh values
        """
        margin = th.finfo(squashed.dtype).eps
        clipped = squashed.clamp(min=-1.0 + margin, max=1.0 - margin)
        # atanh(x) = 0.5 * (log(1 + x) - log(1 - x)); log1p is accurate near zero
        return 0.5 * (th.log1p(clipped) - th.log1p(-clipped))

    def log_prob(self, action, gaussian_action=None):
        """
        Log likelihood of a squashed action:
        gaussian log likelihood of the pre-tanh action minus the
        change-of-variable term ``sum(log(1 - action^2 + epsilon))``.

        :param action: (th.Tensor) Squashed action(s), batched or a single action
        :param gaussian_action: (th.Tensor) Optional pre-tanh action; when given
            it is used as is, otherwise it is recovered by inverting the tanh
        :return: (th.Tensor) the log likelihood
        """
        if gaussian_action is None:
            gaussian_action = self._atanh(action)

        # Log likelihood for a gaussian distribution
        log_prob = super(SquashedDiagGaussianDistribution, self).log_prob(gaussian_action)
        # Squash correction (from original implementation),
        # reduced the same way as the parent log_prob so that
        # un-batched (1D) actions are supported too
        correction = th.log(1 - action ** 2 + self.epsilon)
        if len(correction.shape) > 1:
            correction = correction.sum(dim=1)
        else:
            correction = correction.sum()
        return log_prob - correction
def _get_action_dist_from_latent(self, latent, deterministic=False):
    """
    Turn the policy latent code into an action and its distribution.

    The latent is passed through ``self.action_net`` to obtain the mean
    actions; building the distribution and picking the action is delegated
    to ``self.action_dist.proba_distribution`` with ``self.log_std``.

    :param latent: Output of the policy network, fed to ``self.action_net``
    :param deterministic: (bool) Forwarded to ``proba_distribution``
    :return: whatever ``proba_distribution`` returns; callers in this class
        unpack it as an (action, distribution) pair
    """
    return self.action_dist.proba_distribution(
        self.action_net(latent),
        self.log_std,
        deterministic=deterministic,
    )

def get_policy_stats(self, state, action):
    """
    Evaluate given actions under the current policy.

    :param state: Input passed to ``self._get_latent``
    :param action: Actions whose log likelihood is requested
    :return: (value, log_prob, entropy) the output of ``self.value_net``,
        the log likelihood of ``action`` and the entropy reported by the
        action distribution
    """
    pi_features, vf_features = self._get_latent(state)
    # Only the distribution is needed here, the drawn action is discarded
    distribution = self._get_action_dist_from_latent(pi_features)[1]
    action_log_prob = distribution.log_prob(action)
    state_value = self.value_net(vf_features)
    return state_value, action_log_prob, distribution.entropy()