Refactor: add distributions

This commit is contained in:
Antonin RAFFIN 2019-09-22 12:52:49 +02:00
parent 70e1d673a9
commit ddaafcbc36
3 changed files with 79 additions and 34 deletions

View file

@ -1,7 +1,8 @@
import numpy as np
import torch as th
import torch.nn as nn
from torch.distributions import Normal
class Distribution(object):
def __init__(self):
super(Distribution, self).__init__()
@ -19,7 +20,7 @@ class Distribution(object):
"""
Calculates the Kullback-Leibler divergence from the given probabilty distribution
:param other: ([float]) the distibution to compare with
:param other: ([float]) the distribution to compare with
:return: (float) the KL divergence of the two distributions
"""
raise NotImplementedError
@ -41,15 +42,72 @@ class Distribution(object):
raise NotImplementedError
class DiagGaussianDistribution(object):
"""docstring for DiagGaussianDistribution."""
def __init__(self):
class DiagGaussianDistribution(Distribution):
def __init__(self, action_dim):
super(DiagGaussianDistribution, self).__init__()
self.distribution = None
self.action_dim = action_dim
self.mean_actions = None
self.log_std = None
def proba_distribution_from_latent(self, latent, init_scale=1.0, init_bias=0.0):
self.distribution = Normal()
def proba_distribution_net(self, latent_dim, log_std_init=0.0):
mean_actions = nn.Linear(latent_dim, self.action_dim)
log_std = nn.Parameter(th.ones(self.action_dim) * log_std_init)
return mean_actions, log_std
def proba_distribution(self, mean_actions, log_std, deterministic=False):
action_std = th.ones_like(mean_actions) * log_std.exp()
self.distribution = Normal(mean_actions, action_std)
if deterministic:
action = self.mode()
else:
action = self.sample()
return action, self
def mode(self):
return self.distribution.mean
def sample(self):
return self.distribution.rsample()
def entropy(self):
return self.distribution.entropy()
def log_prob(self, action):
log_prob = self.distribution.log_prob(action)
if len(log_prob.shape) > 1:
log_prob = log_prob.sum(axis=1)
else:
log_prob = log_prob.sum()
return log_prob
class SquashedDiagGaussianDistribution(DiagGaussianDistribution):
def __init__(self, action_dim, epsilon=1e-6):
super(SquashedDiagGaussianDistribution, self).__init__(action_dim)
# Avoid NaN (prevents division by zero or log of zero)
self.epsilon = epsilon
def proba_distribution(self, mean_actions, log_std, deterministic=False):
action, _ = super(SquashedDiagGaussianDistribution, self).proba_distribution(mean_actions, log_std, deterministic)
return action, self
def mode(self):
# Squash the output
return th.tanh(self.distribution.mean)
def sample(self):
return th.tanh(self.distribution.rsample())
def log_prob(self, action, gaussian_action=None):
# Inverse tanh
# Naive implementation (not stable): 0.5 * torch.log((1 + x ) / (1 - x))
# We use numpy to avoid numerical instability
if gaussian_action is None:
gaussian_action = th.from_numpy(np.arctanh(action.cpu().numpy())).to(action.device)
# Log likelihood for a gaussian distribution
log_prob = super(SquashedDiagGaussianDistribution, self).log_prob(gaussian_action)
# Squash correction (from original implementation)
log_prob -= th.sum(th.log(1 - action ** 2 + self.epsilon), dim=1)
return log_prob

View file

@ -6,7 +6,7 @@ from torch.distributions import Normal
import numpy as np
from torchy_baselines.common.policies import BasePolicy, register_policy, create_mlp
from torchy_baselines.common.distributions import DiagGaussianDistribution, SquashedDiagGaussianDistribution
class PPOPolicy(BasePolicy):
def __init__(self, observation_space, action_space,
@ -28,6 +28,9 @@ class PPOPolicy(BasePolicy):
}
self.shared_net = None
self.pi_net, self.vf_net = None, None
# Action distribution
# self.action_dist = DiagGaussianDistribution(self.action_dim)
self.action_dist = SquashedDiagGaussianDistribution(self.action_dim)
self._build(learning_rate)
@staticmethod
@ -46,17 +49,18 @@ class PPOPolicy(BasePolicy):
vf_net = create_mlp(self.state_dim, output_dim=-1, net_arch=self.net_arch, activation_fn=self.activation_fn)
self.vf_net = nn.Sequential(*vf_net).to(self.device)
self.actor_net = nn.Linear(self.net_arch[-1], self.action_dim)
# self.action_net = nn.Linear(self.net_arch[-1], self.action_dim)
# self.log_std = nn.Parameter(th.zeros(self.action_dim))
self.action_net, self.log_std = self.action_dist.proba_distribution_net(latent_dim=self.net_arch[-1])
self.value_net = nn.Linear(self.net_arch[-1], 1)
self.log_std = nn.Parameter(th.zeros(self.action_dim))
# Init weights: use orthogonal initialization
for module in [self.pi_net, self.vf_net, self.actor_net, self.value_net]:
for module in [self.pi_net, self.vf_net, self.action_net, self.value_net]:
# Values from stable-baselines check why
gain = {
self.pi_net: np.sqrt(2),
self.vf_net: np.sqrt(2),
self.shared_net: np.sqrt(2),
self.actor_net: 0.01,
self.action_net: 0.01,
self.value_net: 1
}[module]
module.apply(partial(self.init_weights, gain=gain))
@ -68,7 +72,7 @@ class PPOPolicy(BasePolicy):
latent_pi, latent_vf = self._get_latent(state)
value = self.value_net(latent_vf)
action, action_distribution = self._get_action_dist_from_latent(latent_pi, deterministic=deterministic)
log_prob = self._get_log_prob(action_distribution, action)
log_prob = action_distribution.log_prob(action)
return action, value, log_prob
def _get_latent(self, state):
@ -79,24 +83,8 @@ class PPOPolicy(BasePolicy):
return self.pi_net(state), self.vf_net(state)
def _get_action_dist_from_latent(self, latent, deterministic=False):
mean_actions = self.actor_net(latent)
action_std = th.ones_like(mean_actions) * self.log_std.exp()
action_distribution = Normal(mean_actions, action_std)
# Sample from the gaussian
if deterministic:
action = mean_actions
else:
action = action_distribution.rsample()
return action, action_distribution
@staticmethod
def _get_log_prob(action_distribution, action):
log_prob = action_distribution.log_prob(action)
if len(log_prob.shape) > 1:
log_prob = log_prob.sum(axis=1)
else:
log_prob = log_prob.sum()
return log_prob
mean_actions = self.action_net(latent)
return self.action_dist.proba_distribution(mean_actions, self.log_std, deterministic=deterministic)
def actor_forward(self, state, deterministic=False):
latent_pi, _ = self._get_latent(state)
@ -106,7 +94,7 @@ class PPOPolicy(BasePolicy):
def get_policy_stats(self, state, action):
latent_pi, latent_vf = self._get_latent(state)
_, action_distribution = self._get_action_dist_from_latent(latent_pi)
log_prob = self._get_log_prob(action_distribution, action)
log_prob = action_distribution.log_prob(action)
value = self.value_net(latent_vf)
return value, log_prob, action_distribution.entropy()

View file

@ -126,7 +126,6 @@ class PPO(BaseRLModel):
for replay_data in self.rollout_buffer.get(batch_size):
# Unpack
state, action, old_values, old_log_prob, advantage, return_batch = replay_data
values, log_prob, entropy = self.policy.get_policy_stats(state, action)
values = values.flatten()
# Normalize advantage