mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-22 22:10:16 +00:00
Refactor: add distributions
This commit is contained in:
parent
70e1d673a9
commit
ddaafcbc36
3 changed files with 79 additions and 34 deletions
|
|
@ -1,7 +1,8 @@
|
|||
import numpy as np
|
||||
import torch as th
|
||||
import torch.nn as nn
|
||||
from torch.distributions import Normal
|
||||
|
||||
|
||||
class Distribution(object):
|
||||
def __init__(self):
|
||||
super(Distribution, self).__init__()
|
||||
|
|
@ -19,7 +20,7 @@ class Distribution(object):
|
|||
"""
|
||||
Calculates the Kullback-Leibler divergence from the given probabilty distribution
|
||||
|
||||
:param other: ([float]) the distibution to compare with
|
||||
:param other: ([float]) the distribution to compare with
|
||||
:return: (float) the KL divergence of the two distributions
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
|
@ -41,15 +42,72 @@ class Distribution(object):
|
|||
raise NotImplementedError
|
||||
|
||||
|
||||
class DiagGaussianDistribution(object):
|
||||
"""docstring for DiagGaussianDistribution."""
|
||||
|
||||
def __init__(self):
|
||||
class DiagGaussianDistribution(Distribution):
|
||||
def __init__(self, action_dim):
|
||||
super(DiagGaussianDistribution, self).__init__()
|
||||
self.distribution = None
|
||||
self.action_dim = action_dim
|
||||
self.mean_actions = None
|
||||
self.log_std = None
|
||||
|
||||
def proba_distribution_from_latent(self, latent, init_scale=1.0, init_bias=0.0):
|
||||
self.distribution = Normal()
|
||||
def proba_distribution_net(self, latent_dim, log_std_init=0.0):
|
||||
mean_actions = nn.Linear(latent_dim, self.action_dim)
|
||||
log_std = nn.Parameter(th.ones(self.action_dim) * log_std_init)
|
||||
return mean_actions, log_std
|
||||
|
||||
def proba_distribution(self, mean_actions, log_std, deterministic=False):
|
||||
action_std = th.ones_like(mean_actions) * log_std.exp()
|
||||
self.distribution = Normal(mean_actions, action_std)
|
||||
if deterministic:
|
||||
action = self.mode()
|
||||
else:
|
||||
action = self.sample()
|
||||
return action, self
|
||||
|
||||
def mode(self):
|
||||
return self.distribution.mean
|
||||
|
||||
def sample(self):
|
||||
return self.distribution.rsample()
|
||||
|
||||
def entropy(self):
|
||||
return self.distribution.entropy()
|
||||
|
||||
def log_prob(self, action):
|
||||
log_prob = self.distribution.log_prob(action)
|
||||
if len(log_prob.shape) > 1:
|
||||
log_prob = log_prob.sum(axis=1)
|
||||
else:
|
||||
log_prob = log_prob.sum()
|
||||
return log_prob
|
||||
|
||||
|
||||
class SquashedDiagGaussianDistribution(DiagGaussianDistribution):
|
||||
def __init__(self, action_dim, epsilon=1e-6):
|
||||
super(SquashedDiagGaussianDistribution, self).__init__(action_dim)
|
||||
# Avoid NaN (prevents division by zero or log of zero)
|
||||
self.epsilon = epsilon
|
||||
|
||||
def proba_distribution(self, mean_actions, log_std, deterministic=False):
|
||||
action, _ = super(SquashedDiagGaussianDistribution, self).proba_distribution(mean_actions, log_std, deterministic)
|
||||
return action, self
|
||||
|
||||
def mode(self):
|
||||
# Squash the output
|
||||
return th.tanh(self.distribution.mean)
|
||||
|
||||
def sample(self):
|
||||
return th.tanh(self.distribution.rsample())
|
||||
|
||||
def log_prob(self, action, gaussian_action=None):
|
||||
# Inverse tanh
|
||||
# Naive implementation (not stable): 0.5 * torch.log((1 + x ) / (1 - x))
|
||||
# We use numpy to avoid numerical instability
|
||||
if gaussian_action is None:
|
||||
gaussian_action = th.from_numpy(np.arctanh(action.cpu().numpy())).to(action.device)
|
||||
|
||||
# Log likelihood for a gaussian distribution
|
||||
log_prob = super(SquashedDiagGaussianDistribution, self).log_prob(gaussian_action)
|
||||
# Squash correction (from original implementation)
|
||||
log_prob -= th.sum(th.log(1 - action ** 2 + self.epsilon), dim=1)
|
||||
return log_prob
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from torch.distributions import Normal
|
|||
import numpy as np
|
||||
|
||||
from torchy_baselines.common.policies import BasePolicy, register_policy, create_mlp
|
||||
|
||||
from torchy_baselines.common.distributions import DiagGaussianDistribution, SquashedDiagGaussianDistribution
|
||||
|
||||
class PPOPolicy(BasePolicy):
|
||||
def __init__(self, observation_space, action_space,
|
||||
|
|
@ -28,6 +28,9 @@ class PPOPolicy(BasePolicy):
|
|||
}
|
||||
self.shared_net = None
|
||||
self.pi_net, self.vf_net = None, None
|
||||
# Action distribution
|
||||
# self.action_dist = DiagGaussianDistribution(self.action_dim)
|
||||
self.action_dist = SquashedDiagGaussianDistribution(self.action_dim)
|
||||
self._build(learning_rate)
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -46,17 +49,18 @@ class PPOPolicy(BasePolicy):
|
|||
vf_net = create_mlp(self.state_dim, output_dim=-1, net_arch=self.net_arch, activation_fn=self.activation_fn)
|
||||
self.vf_net = nn.Sequential(*vf_net).to(self.device)
|
||||
|
||||
self.actor_net = nn.Linear(self.net_arch[-1], self.action_dim)
|
||||
# self.action_net = nn.Linear(self.net_arch[-1], self.action_dim)
|
||||
# self.log_std = nn.Parameter(th.zeros(self.action_dim))
|
||||
self.action_net, self.log_std = self.action_dist.proba_distribution_net(latent_dim=self.net_arch[-1])
|
||||
self.value_net = nn.Linear(self.net_arch[-1], 1)
|
||||
self.log_std = nn.Parameter(th.zeros(self.action_dim))
|
||||
# Init weights: use orthogonal initialization
|
||||
for module in [self.pi_net, self.vf_net, self.actor_net, self.value_net]:
|
||||
for module in [self.pi_net, self.vf_net, self.action_net, self.value_net]:
|
||||
# Values from stable-baselines check why
|
||||
gain = {
|
||||
self.pi_net: np.sqrt(2),
|
||||
self.vf_net: np.sqrt(2),
|
||||
self.shared_net: np.sqrt(2),
|
||||
self.actor_net: 0.01,
|
||||
self.action_net: 0.01,
|
||||
self.value_net: 1
|
||||
}[module]
|
||||
module.apply(partial(self.init_weights, gain=gain))
|
||||
|
|
@ -68,7 +72,7 @@ class PPOPolicy(BasePolicy):
|
|||
latent_pi, latent_vf = self._get_latent(state)
|
||||
value = self.value_net(latent_vf)
|
||||
action, action_distribution = self._get_action_dist_from_latent(latent_pi, deterministic=deterministic)
|
||||
log_prob = self._get_log_prob(action_distribution, action)
|
||||
log_prob = action_distribution.log_prob(action)
|
||||
return action, value, log_prob
|
||||
|
||||
def _get_latent(self, state):
|
||||
|
|
@ -79,24 +83,8 @@ class PPOPolicy(BasePolicy):
|
|||
return self.pi_net(state), self.vf_net(state)
|
||||
|
||||
def _get_action_dist_from_latent(self, latent, deterministic=False):
|
||||
mean_actions = self.actor_net(latent)
|
||||
action_std = th.ones_like(mean_actions) * self.log_std.exp()
|
||||
action_distribution = Normal(mean_actions, action_std)
|
||||
# Sample from the gaussian
|
||||
if deterministic:
|
||||
action = mean_actions
|
||||
else:
|
||||
action = action_distribution.rsample()
|
||||
return action, action_distribution
|
||||
|
||||
@staticmethod
|
||||
def _get_log_prob(action_distribution, action):
|
||||
log_prob = action_distribution.log_prob(action)
|
||||
if len(log_prob.shape) > 1:
|
||||
log_prob = log_prob.sum(axis=1)
|
||||
else:
|
||||
log_prob = log_prob.sum()
|
||||
return log_prob
|
||||
mean_actions = self.action_net(latent)
|
||||
return self.action_dist.proba_distribution(mean_actions, self.log_std, deterministic=deterministic)
|
||||
|
||||
def actor_forward(self, state, deterministic=False):
|
||||
latent_pi, _ = self._get_latent(state)
|
||||
|
|
@ -106,7 +94,7 @@ class PPOPolicy(BasePolicy):
|
|||
def get_policy_stats(self, state, action):
|
||||
latent_pi, latent_vf = self._get_latent(state)
|
||||
_, action_distribution = self._get_action_dist_from_latent(latent_pi)
|
||||
log_prob = self._get_log_prob(action_distribution, action)
|
||||
log_prob = action_distribution.log_prob(action)
|
||||
value = self.value_net(latent_vf)
|
||||
return value, log_prob, action_distribution.entropy()
|
||||
|
||||
|
|
|
|||
|
|
@ -126,7 +126,6 @@ class PPO(BaseRLModel):
|
|||
for replay_data in self.rollout_buffer.get(batch_size):
|
||||
# Unpack
|
||||
state, action, old_values, old_log_prob, advantage, return_batch = replay_data
|
||||
|
||||
values, log_prob, entropy = self.policy.get_policy_stats(state, action)
|
||||
values = values.flatten()
|
||||
# Normalize advantage
|
||||
|
|
|
|||
Loading…
Reference in a new issue