SDE on latent_pi

This commit is contained in:
Antonin Raffin 2019-10-31 11:44:27 +01:00
parent 862ae666b5
commit 925afe784c
3 changed files with 73 additions and 41 deletions

View file

@ -112,7 +112,6 @@ class A2C(PPO):
# Optimization step
self.policy.optimizer.zero_grad()
loss.backward()
# Clip grad norm
th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
self.policy.optimizer.step()

View file

@ -2,6 +2,7 @@ import numpy as np
import torch as th
import torch.nn as nn
from torch.distributions import Normal, Categorical
import torch.nn.functional as F
from gym import spaces
class Distribution(object):
@ -168,20 +169,20 @@ class CategoricalDistribution(Distribution):
class StateDependentNoiseDistribution(Distribution):
def __init__(self, features_dim, action_dim, use_expln=False,
squash_output=True, epsilon=1e-6):
def __init__(self, action_dim, use_expln=False,
squash_output=False, epsilon=1e-6):
super(StateDependentNoiseDistribution, self).__init__()
self.distribution = None
self.action_dim = action_dim
self.features_dim = features_dim
self.mean_actions = None
self.log_std = None
self.weights_dist = None
self.noise_weights = None
self.gaussian_action = None
self.exploration_mat = None
self.use_expln = use_expln
self.squash_output = squash_output
self.epsilon = epsilon
if squash_output:
self.bijector = TanhBijector(epsilon)
else:
self.bijector = None
def get_std(self, log_std):
if self.use_expln:
@ -196,71 +197,103 @@ class StateDependentNoiseDistribution(Distribution):
def sample_weights(self, log_std):
self.weights_dist = Normal(th.zeros_like(log_std), self.get_std(log_std))
self.noise_weights = self.weights_dist.rsample()
self.exploration_mat = self.weights_dist.rsample()
def proba_distribution_net(self, latent_dim, log_std_init=-1):
mean_actions = nn.Linear(latent_dim, self.action_dim)
# TODO: log_std_init depending on the number of layers?
log_std = nn.Parameter(th.ones(self.features_dim, self.action_dim) * log_std_init)
log_std = nn.Parameter(th.ones(latent_dim, self.action_dim) * log_std_init)
self.sample_weights(log_std)
return mean_actions, log_std
def proba_distribution(self, mean_actions, log_std, observations, deterministic=False):
variance = th.mm(observations ** 2, self.get_std(log_std) ** 2)
def proba_distribution(self, mean_actions, log_std, latent_pi, deterministic=False):
# TODO: try without detach
variance = th.mm(latent_pi.detach() ** 2, self.get_std(log_std) ** 2)
self.distribution = Normal(mean_actions, th.sqrt(variance))
if deterministic:
action = self.mode()
else:
action = self.sample(observations)
action = self.sample(latent_pi)
return action, self
def mode(self):
self.gaussian_action = self.distribution.mean
if self.squash_output:
return th.tanh(self.gaussian_action)
return self.gaussian_action
action = self.distribution.mean
if self.bijector is not None:
return self.bijector.forward(action)
return action
def sample(self, observations):
noise = th.mm(observations, self.noise_weights)
self.gaussian_action = self.distribution.mean + noise
if self.squash_output:
return th.tanh(self.gaussian_action)
return self.gaussian_action
def sample(self, latent_pi):
noise = th.mm(latent_pi.detach(), self.exploration_mat)
action = self.distribution.mean + noise
if self.bijector is not None:
return self.bijector.forward(action)
return action
def entropy(self):
# TODO: account for the squashing?
return self.distribution.entropy()
def log_prob_from_params(self, mean_actions, log_std, observations):
action, _ = self.proba_distribution(mean_actions, log_std, observations)
def log_prob_from_params(self, mean_actions, log_std, latent_pi):
action, _ = self.proba_distribution(mean_actions, log_std, latent_pi)
log_prob = self.log_prob(action)
return action, log_prob
def log_prob(self, action):
if self.squash_output:
gaussian_action = self.gaussian_action
if self.bijector is not None:
gaussian_action = self.bijector.inverse(action)
else:
gaussian_action = action
# log likelihood for a gaussian
log_prob = self.distribution.log_prob(gaussian_action)
# log_prob = self.distribution.log_prob(action)
if len(log_prob.shape) > 1:
log_prob = log_prob.sum(axis=1)
else:
log_prob = log_prob.sum()
if self.squash_output:
if self.bijector is not None:
# Squash correction (from original SAC implementation)
log_prob -= th.sum(th.log(1 - action ** 2 + self.epsilon), dim=1)
log_prob -= th.sum(self.bijector.log_prob_correction(gaussian_action), dim=1)
return log_prob
def make_proba_distribution(action_space, features_dim=None, use_sde=False):
class TanhBijector(object):
def __init__(self, epsilon=1e-6):
super(TanhBijector, self).__init__()
self.epsilon = epsilon
def forward(self, x):
return th.tanh(x)
def inverse(self, action):
"""
Inverse tanh.
From https://github.com/tensorflow/agents:
0.99999997 is the maximum value such that atanh(x) is valid for both
float32 and float64
:param action: (th.Tensor)
:return: (th.Tensor)
"""
# Inverse tanh
# Naive implementation (not stable): 0.5 * torch.log((1 + x ) / (1 - x))
# We use numpy to avoid numerical instability
# Note: Using numpy, we do not keep the gradient
clipped_action = np.clip(action.cpu().numpy(), -0.99999997, 0.99999997)
return th.from_numpy(np.arctanh(clipped_action)).to(action.device)
def log_prob_correction(self, x):
# Squash correction (from original SAC implementation)
return th.log(1 - th.tanh(x) ** 2 + self.epsilon)
def make_proba_distribution(action_space, use_sde=False):
"""
Return an instance of Distribution for the correct type of action space
:param action_space: (Gym Space) the input action space
:param feature_dim: (int) Dimension of the feature vector
:param use_sde: (bool) Force the use of StateDependentNoiseDistribution
instead of DiagGaussianDistribution
:return: (Distribution) the approriate Distribution object
@ -268,7 +301,7 @@ def make_proba_distribution(action_space, features_dim=None, use_sde=False):
if isinstance(action_space, spaces.Box):
assert len(action_space.shape) == 1, "Error: the action space must be a vector"
if use_sde:
return StateDependentNoiseDistribution(features_dim, action_space.shape[0])
return StateDependentNoiseDistribution(action_space.shape[0])
return DiagGaussianDistribution(action_space.shape[0])
elif isinstance(action_space, spaces.Discrete):
return CategoricalDistribution(action_space.n)

View file

@ -125,7 +125,7 @@ class PPOPolicy(BasePolicy):
self.features_dim = self.obs_dim
self.log_std_init = log_std_init
# Action distribution
self.action_dist = make_proba_distribution(action_space, self.features_dim, use_sde=use_sde)
self.action_dist = make_proba_distribution(action_space, use_sde=use_sde)
self._build(learning_rate)
@ -161,15 +161,15 @@ class PPOPolicy(BasePolicy):
obs = th.FloatTensor(obs).to(self.device)
latent_pi, latent_vf = self._get_latent(obs)
value = self.value_net(latent_vf)
action, action_distribution = self._get_action_dist_from_latent(latent_pi, obs, deterministic=deterministic)
action, action_distribution = self._get_action_dist_from_latent(latent_pi, deterministic=deterministic)
log_prob = action_distribution.log_prob(action)
return action, value, log_prob
def _get_latent(self, obs):
return self.mlp_extractor(self.features_extractor(obs))
def _get_action_dist_from_latent(self, latent, obs, deterministic=False):
mean_actions = self.action_net(latent)
def _get_action_dist_from_latent(self, latent_pi, deterministic=False):
mean_actions = self.action_net(latent_pi)
if isinstance(self.action_dist, DiagGaussianDistribution):
return self.action_dist.proba_distribution(mean_actions, self.log_std, deterministic=deterministic)
@ -178,16 +178,16 @@ class PPOPolicy(BasePolicy):
return self.action_dist.proba_distribution(mean_actions, deterministic=deterministic)
elif isinstance(self.action_dist, StateDependentNoiseDistribution):
return self.action_dist.proba_distribution(mean_actions, self.log_std, obs, deterministic=deterministic)
return self.action_dist.proba_distribution(mean_actions, self.log_std, latent_pi, deterministic=deterministic)
def actor_forward(self, obs, deterministic=False):
latent_pi, _ = self._get_latent(obs)
action, _ = self._get_action_dist_from_latent(latent_pi, obs, deterministic=deterministic)
action, _ = self._get_action_dist_from_latent(latent_pi, deterministic=deterministic)
return action.detach().cpu().numpy()
def get_policy_stats(self, obs, action):
def get_policy_stats(self, obs, action, deterministic=False):
latent_pi, latent_vf = self._get_latent(obs)
_, action_distribution = self._get_action_dist_from_latent(latent_pi, obs)
_, action_distribution = self._get_action_dist_from_latent(latent_pi, deterministic=deterministic)
log_prob = action_distribution.log_prob(action)
value = self.value_net(latent_vf)
return value, log_prob, action_distribution.entropy()