mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-07-03 03:59:13 +00:00
SDE on latent_pi
This commit is contained in:
parent
862ae666b5
commit
925afe784c
3 changed files with 73 additions and 41 deletions
|
|
@ -112,7 +112,6 @@ class A2C(PPO):
|
|||
# Optimization step
|
||||
self.policy.optimizer.zero_grad()
|
||||
loss.backward()
|
||||
|
||||
# Clip grad norm
|
||||
th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
|
||||
self.policy.optimizer.step()
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import numpy as np
|
|||
import torch as th
|
||||
import torch.nn as nn
|
||||
from torch.distributions import Normal, Categorical
|
||||
import torch.nn.functional as F
|
||||
from gym import spaces
|
||||
|
||||
class Distribution(object):
|
||||
|
|
@ -168,20 +169,20 @@ class CategoricalDistribution(Distribution):
|
|||
|
||||
|
||||
class StateDependentNoiseDistribution(Distribution):
|
||||
def __init__(self, features_dim, action_dim, use_expln=False,
|
||||
squash_output=True, epsilon=1e-6):
|
||||
def __init__(self, action_dim, use_expln=False,
|
||||
squash_output=False, epsilon=1e-6):
|
||||
super(StateDependentNoiseDistribution, self).__init__()
|
||||
self.distribution = None
|
||||
self.action_dim = action_dim
|
||||
self.features_dim = features_dim
|
||||
self.mean_actions = None
|
||||
self.log_std = None
|
||||
self.weights_dist = None
|
||||
self.noise_weights = None
|
||||
self.gaussian_action = None
|
||||
self.exploration_mat = None
|
||||
self.use_expln = use_expln
|
||||
self.squash_output = squash_output
|
||||
self.epsilon = epsilon
|
||||
if squash_output:
|
||||
self.bijector = TanhBijector(epsilon)
|
||||
else:
|
||||
self.bijector = None
|
||||
|
||||
def get_std(self, log_std):
|
||||
if self.use_expln:
|
||||
|
|
@ -196,71 +197,103 @@ class StateDependentNoiseDistribution(Distribution):
|
|||
|
||||
def sample_weights(self, log_std):
|
||||
self.weights_dist = Normal(th.zeros_like(log_std), self.get_std(log_std))
|
||||
self.noise_weights = self.weights_dist.rsample()
|
||||
self.exploration_mat = self.weights_dist.rsample()
|
||||
|
||||
def proba_distribution_net(self, latent_dim, log_std_init=-1):
|
||||
mean_actions = nn.Linear(latent_dim, self.action_dim)
|
||||
# TODO: log_std_init depending on the number of layers?
|
||||
log_std = nn.Parameter(th.ones(self.features_dim, self.action_dim) * log_std_init)
|
||||
log_std = nn.Parameter(th.ones(latent_dim, self.action_dim) * log_std_init)
|
||||
self.sample_weights(log_std)
|
||||
return mean_actions, log_std
|
||||
|
||||
def proba_distribution(self, mean_actions, log_std, observations, deterministic=False):
|
||||
variance = th.mm(observations ** 2, self.get_std(log_std) ** 2)
|
||||
def proba_distribution(self, mean_actions, log_std, latent_pi, deterministic=False):
|
||||
# TODO: try without detach
|
||||
variance = th.mm(latent_pi.detach() ** 2, self.get_std(log_std) ** 2)
|
||||
self.distribution = Normal(mean_actions, th.sqrt(variance))
|
||||
|
||||
if deterministic:
|
||||
action = self.mode()
|
||||
else:
|
||||
action = self.sample(observations)
|
||||
action = self.sample(latent_pi)
|
||||
return action, self
|
||||
|
||||
def mode(self):
|
||||
self.gaussian_action = self.distribution.mean
|
||||
if self.squash_output:
|
||||
return th.tanh(self.gaussian_action)
|
||||
return self.gaussian_action
|
||||
action = self.distribution.mean
|
||||
if self.bijector is not None:
|
||||
return self.bijector.forward(action)
|
||||
return action
|
||||
|
||||
def sample(self, observations):
|
||||
noise = th.mm(observations, self.noise_weights)
|
||||
self.gaussian_action = self.distribution.mean + noise
|
||||
if self.squash_output:
|
||||
return th.tanh(self.gaussian_action)
|
||||
return self.gaussian_action
|
||||
def sample(self, latent_pi):
|
||||
noise = th.mm(latent_pi.detach(), self.exploration_mat)
|
||||
action = self.distribution.mean + noise
|
||||
if self.bijector is not None:
|
||||
return self.bijector.forward(action)
|
||||
return action
|
||||
|
||||
def entropy(self):
|
||||
# TODO: account for the squashing?
|
||||
return self.distribution.entropy()
|
||||
|
||||
def log_prob_from_params(self, mean_actions, log_std, observations):
|
||||
action, _ = self.proba_distribution(mean_actions, log_std, observations)
|
||||
def log_prob_from_params(self, mean_actions, log_std, latent_pi):
|
||||
action, _ = self.proba_distribution(mean_actions, log_std, latent_pi)
|
||||
log_prob = self.log_prob(action)
|
||||
return action, log_prob
|
||||
|
||||
def log_prob(self, action):
|
||||
if self.squash_output:
|
||||
gaussian_action = self.gaussian_action
|
||||
if self.bijector is not None:
|
||||
gaussian_action = self.bijector.inverse(action)
|
||||
else:
|
||||
gaussian_action = action
|
||||
# log likelihood for a gaussian
|
||||
log_prob = self.distribution.log_prob(gaussian_action)
|
||||
# log_prob = self.distribution.log_prob(action)
|
||||
|
||||
if len(log_prob.shape) > 1:
|
||||
log_prob = log_prob.sum(axis=1)
|
||||
else:
|
||||
log_prob = log_prob.sum()
|
||||
if self.squash_output:
|
||||
|
||||
if self.bijector is not None:
|
||||
# Squash correction (from original SAC implementation)
|
||||
log_prob -= th.sum(th.log(1 - action ** 2 + self.epsilon), dim=1)
|
||||
log_prob -= th.sum(self.bijector.log_prob_correction(gaussian_action), dim=1)
|
||||
return log_prob
|
||||
|
||||
|
||||
def make_proba_distribution(action_space, features_dim=None, use_sde=False):
|
||||
class TanhBijector(object):
    """
    Tanh squashing transformation applied to the actions of a distribution.

    Provides the forward squashing, its inverse, and the term used to
    correct a Gaussian log-likelihood once the sample has been squashed.

    :param epsilon: (float) small constant added inside the log of
        ``log_prob_correction`` to keep its argument strictly positive
    """

    # From https://github.com/tensorflow/agents:
    # 0.99999997 is the maximum value such that atanh(x) is valid for both
    # float32 and float64
    MAX_ABS_ACTION = 0.99999997

    def __init__(self, epsilon=1e-6):
        super(TanhBijector, self).__init__()
        self.epsilon = epsilon

    def forward(self, x):
        """
        Squash the input with tanh.

        :param x: (th.Tensor)
        :return: (th.Tensor)
        """
        return th.tanh(x)

    def inverse(self, action):
        """
        Inverse tanh.

        The input is clipped to ``[-MAX_ABS_ACTION, MAX_ABS_ACTION]`` before
        the inversion so that saturated actions (exactly -1 or 1) map to a
        finite value.

        The computation stays in torch: going through numpy raises a
        RuntimeError for tensors that require grad, discards the gradient
        otherwise, and forces a copy to the CPU.

        :param action: (th.Tensor)
        :return: (th.Tensor) same dtype and device as ``action``
        """
        bound = self.MAX_ABS_ACTION
        clipped_action = th.clamp(action, -bound, bound)
        # atanh(x) = 0.5 * (log(1 + x) - log(1 - x)); log1p is used instead of
        # the naive 0.5 * log((1 + x) / (1 - x)), which is not stable
        return 0.5 * (th.log1p(clipped_action) - th.log1p(-clipped_action))

    def log_prob_correction(self, x):
        """
        Element-wise ``log(1 - tanh(x)^2 + epsilon)``.

        Squash correction (from original SAC implementation).

        :param x: (th.Tensor) the unsquashed (Gaussian) action
        :return: (th.Tensor)
        """
        return th.log(1 - th.tanh(x) ** 2 + self.epsilon)
|
||||
|
||||
|
||||
def make_proba_distribution(action_space, use_sde=False):
|
||||
"""
|
||||
Return an instance of Distribution for the correct type of action space
|
||||
|
||||
:param action_space: (Gym Space) the input action space
|
||||
:param feature_dim: (int) Dimension of the feature vector
|
||||
:param use_sde: (bool) Force the use of StateDependentNoiseDistribution
|
||||
instead of DiagGaussianDistribution
|
||||
:return: (Distribution) the appropriate Distribution object
|
||||
|
|
@ -268,7 +301,7 @@ def make_proba_distribution(action_space, features_dim=None, use_sde=False):
|
|||
if isinstance(action_space, spaces.Box):
|
||||
assert len(action_space.shape) == 1, "Error: the action space must be a vector"
|
||||
if use_sde:
|
||||
return StateDependentNoiseDistribution(features_dim, action_space.shape[0])
|
||||
return StateDependentNoiseDistribution(action_space.shape[0])
|
||||
return DiagGaussianDistribution(action_space.shape[0])
|
||||
elif isinstance(action_space, spaces.Discrete):
|
||||
return CategoricalDistribution(action_space.n)
|
||||
|
|
|
|||
|
|
@ -125,7 +125,7 @@ class PPOPolicy(BasePolicy):
|
|||
self.features_dim = self.obs_dim
|
||||
self.log_std_init = log_std_init
|
||||
# Action distribution
|
||||
self.action_dist = make_proba_distribution(action_space, self.features_dim, use_sde=use_sde)
|
||||
self.action_dist = make_proba_distribution(action_space, use_sde=use_sde)
|
||||
|
||||
self._build(learning_rate)
|
||||
|
||||
|
|
@ -161,15 +161,15 @@ class PPOPolicy(BasePolicy):
|
|||
obs = th.FloatTensor(obs).to(self.device)
|
||||
latent_pi, latent_vf = self._get_latent(obs)
|
||||
value = self.value_net(latent_vf)
|
||||
action, action_distribution = self._get_action_dist_from_latent(latent_pi, obs, deterministic=deterministic)
|
||||
action, action_distribution = self._get_action_dist_from_latent(latent_pi, deterministic=deterministic)
|
||||
log_prob = action_distribution.log_prob(action)
|
||||
return action, value, log_prob
|
||||
|
||||
def _get_latent(self, obs):
|
||||
return self.mlp_extractor(self.features_extractor(obs))
|
||||
|
||||
def _get_action_dist_from_latent(self, latent, obs, deterministic=False):
|
||||
mean_actions = self.action_net(latent)
|
||||
def _get_action_dist_from_latent(self, latent_pi, deterministic=False):
|
||||
mean_actions = self.action_net(latent_pi)
|
||||
|
||||
if isinstance(self.action_dist, DiagGaussianDistribution):
|
||||
return self.action_dist.proba_distribution(mean_actions, self.log_std, deterministic=deterministic)
|
||||
|
|
@ -178,16 +178,16 @@ class PPOPolicy(BasePolicy):
|
|||
return self.action_dist.proba_distribution(mean_actions, deterministic=deterministic)
|
||||
|
||||
elif isinstance(self.action_dist, StateDependentNoiseDistribution):
|
||||
return self.action_dist.proba_distribution(mean_actions, self.log_std, obs, deterministic=deterministic)
|
||||
return self.action_dist.proba_distribution(mean_actions, self.log_std, latent_pi, deterministic=deterministic)
|
||||
|
||||
def actor_forward(self, obs, deterministic=False):
|
||||
latent_pi, _ = self._get_latent(obs)
|
||||
action, _ = self._get_action_dist_from_latent(latent_pi, obs, deterministic=deterministic)
|
||||
action, _ = self._get_action_dist_from_latent(latent_pi, deterministic=deterministic)
|
||||
return action.detach().cpu().numpy()
|
||||
|
||||
def get_policy_stats(self, obs, action):
|
||||
def get_policy_stats(self, obs, action, deterministic=False):
|
||||
latent_pi, latent_vf = self._get_latent(obs)
|
||||
_, action_distribution = self._get_action_dist_from_latent(latent_pi, obs)
|
||||
_, action_distribution = self._get_action_dist_from_latent(latent_pi, deterministic=deterministic)
|
||||
log_prob = action_distribution.log_prob(action)
|
||||
value = self.value_net(latent_vf)
|
||||
return value, log_prob, action_distribution.entropy()
|
||||
|
|
|
|||
Loading…
Reference in a new issue