mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-21 22:00:21 +00:00
Merge branch 'feat/sde' into feat/offpolicy-sde
This commit is contained in:
commit
d8a7556d84
4 changed files with 170 additions and 27 deletions
20
tests/test_distributions.py
Normal file
20
tests/test_distributions.py
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
import numpy as np
|
||||
import torch as th
|
||||
|
||||
from torchy_baselines.common.distributions import DiagGaussianDistribution, SquashedDiagGaussianDistribution,\
|
||||
CategoricalDistribution, TanhBijector
|
||||
|
||||
# TODO: more tests for the other distributions
|
||||
def test_bijector():
|
||||
"""
|
||||
Test TanhBijector
|
||||
"""
|
||||
actions = th.ones(5) * 2.0
|
||||
|
||||
bijector = TanhBijector()
|
||||
|
||||
squashed_actions = bijector.forward(actions)
|
||||
# Check that the boundaries are not violated
|
||||
assert th.max(th.abs(squashed_actions)) <= 1.0
|
||||
# Check the inverse method
|
||||
assert th.isclose(TanhBijector.inverse(squashed_actions), actions).all()
|
||||
|
|
@ -1,17 +1,25 @@
|
|||
import pytest
|
||||
|
||||
import gym
|
||||
import torch as th
|
||||
from torch.distributions import Normal
|
||||
|
||||
from torchy_baselines import A2C, TD3
|
||||
from torchy_baselines.common.vec_env import DummyVecEnv, VecNormalize
|
||||
from torchy_baselines.common.monitor import Monitor
|
||||
|
||||
|
||||
def test_state_dependent_exploration():
|
||||
"""
|
||||
Check that the gradient correspond to the expected one
|
||||
"""
|
||||
n_states = 2
|
||||
state_dim = 3
|
||||
# TODO: fix for action_dim > 1
|
||||
action_dim = 1
|
||||
sigma = th.ones(state_dim, action_dim, requires_grad=True)
|
||||
sigma = th.ones(state_dim, 1, requires_grad=True)
|
||||
# Reduce the number of parameters
|
||||
# sigma_ = th.ones(state_dim, action_dim) * sigma_
|
||||
|
||||
# weights_dist = Normal(th.zeros_like(log_sigma), th.exp(log_sigma))
|
||||
th.manual_seed(2)
|
||||
|
|
@ -42,19 +50,13 @@ def test_state_dependent_exploration():
|
|||
|
||||
@pytest.mark.parametrize("model_class", [A2C])
|
||||
def test_state_dependent_noise(model_class):
|
||||
import gym
|
||||
from torchy_baselines.common.vec_env import DummyVecEnv, VecNormalize
|
||||
from torchy_baselines.common.monitor import Monitor
|
||||
|
||||
# env_id = 'Pendulum-v0'
|
||||
env_id = 'MountainCarContinuous-v0'
|
||||
# env_id = 'LunarLanderContinuous-v2'
|
||||
|
||||
env = VecNormalize(DummyVecEnv([lambda: Monitor(gym.make(env_id))]), norm_reward=True)
|
||||
eval_env = VecNormalize(DummyVecEnv([lambda: Monitor(gym.make(env_id))]), training=False, norm_reward=False)
|
||||
model = model_class('MlpPolicy', env, n_steps=200, max_grad_norm=1, use_rms_prop=False,
|
||||
use_sde=True, ent_coef=0.00, verbose=1, create_eval_env=True, learning_rate=3e-4,
|
||||
policy_kwargs=dict(log_std_init=0.0, ortho_init=False, net_arch=[256, dict(pi=[256], vf=[256])]),
|
||||
seed=None)
|
||||
|
||||
model = model_class('MlpPolicy', env, n_steps=200, use_sde=True, ent_coef=0.00, verbose=1, learning_rate=3e-4,
|
||||
policy_kwargs=dict(log_std_init=0.0, ortho_init=False), seed=None)
|
||||
model.learn(total_timesteps=int(1000), log_interval=5, eval_freq=500, eval_env=eval_env)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -289,6 +289,12 @@ class BaseRLModel(object):
|
|||
raise NotImplementedError()
|
||||
|
||||
def set_random_seed(self, seed=None):
|
||||
"""
|
||||
Set the seed of the pseudo-random generators
|
||||
(python, numpy, pytorch, gym, action_space)
|
||||
|
||||
:param seed: (int)
|
||||
"""
|
||||
if seed is None:
|
||||
return
|
||||
set_random_seed(seed, using_cuda=self.device == th.device('cuda'))
|
||||
|
|
|
|||
|
|
@ -45,6 +45,12 @@ class Distribution(object):
|
|||
|
||||
|
||||
class DiagGaussianDistribution(Distribution):
|
||||
"""
|
||||
Gaussian distribution with diagonal covariance matrix,
|
||||
for continuous actions.
|
||||
|
||||
:param action_dim: (int) Number of continuous actions
|
||||
"""
|
||||
def __init__(self, action_dim):
|
||||
super(DiagGaussianDistribution, self).__init__()
|
||||
self.distribution = None
|
||||
|
|
@ -53,12 +59,29 @@ class DiagGaussianDistribution(Distribution):
|
|||
self.log_std = None
|
||||
|
||||
def proba_distribution_net(self, latent_dim, log_std_init=0.0):
|
||||
"""
|
||||
Create the layers and parameter that represent the distribution:
|
||||
one output will be the mean of the gaussian, the other parameter will be the
|
||||
standard deviation (log std in fact to allow negative values)
|
||||
|
||||
:param latent_dim: (int) Dimension og the last layer of the policy (before the action layer)
|
||||
:param log_std_init: (float) Initial value for the log standard deviation
|
||||
:return: (nn.Linear, nn.Parameter)
|
||||
"""
|
||||
mean_actions = nn.Linear(latent_dim, self.action_dim)
|
||||
# TODO: allow action dependent std
|
||||
log_std = nn.Parameter(th.ones(self.action_dim) * log_std_init)
|
||||
return mean_actions, log_std
|
||||
|
||||
def proba_distribution(self, mean_actions, log_std, deterministic=False):
|
||||
"""
|
||||
Create and sample for the distribution given its parameters (mean, std)
|
||||
|
||||
:param mean_actions: (th.Tensor)
|
||||
:param log_std: (th.Tensor)
|
||||
:param deterministic: (bool)
|
||||
:return: (th.Tensor)
|
||||
"""
|
||||
action_std = th.ones_like(mean_actions) * log_std.exp()
|
||||
self.distribution = Normal(mean_actions, action_std)
|
||||
if deterministic:
|
||||
|
|
@ -77,11 +100,27 @@ class DiagGaussianDistribution(Distribution):
|
|||
return self.distribution.entropy()
|
||||
|
||||
def log_prob_from_params(self, mean_actions, log_std):
|
||||
"""
|
||||
Compute the log probabilty of taking an action
|
||||
given the distribution parameters.
|
||||
|
||||
:param mean_actions: (th.Tensor)
|
||||
:param log_std: (th.Tensor)
|
||||
:return: (th.Tensor, th.Tensor)
|
||||
"""
|
||||
action, _ = self.proba_distribution(mean_actions, log_std)
|
||||
log_prob = self.log_prob(action)
|
||||
return action, log_prob
|
||||
|
||||
def log_prob(self, action):
|
||||
"""
|
||||
Get the log probabilty of an action given a distribution.
|
||||
Note that you must call `proba_distribution()` method
|
||||
before.
|
||||
|
||||
:param action: (th.Tensor)
|
||||
:return: (th.Tensor)
|
||||
"""
|
||||
log_prob = self.distribution.log_prob(action)
|
||||
if len(log_prob.shape) > 1:
|
||||
log_prob = log_prob.sum(axis=1)
|
||||
|
|
@ -91,6 +130,13 @@ class DiagGaussianDistribution(Distribution):
|
|||
|
||||
|
||||
class SquashedDiagGaussianDistribution(DiagGaussianDistribution):
|
||||
"""
|
||||
Gaussian distribution with diagonal covariance matrix,
|
||||
followed by a squashing function (tanh) to ensure bounds.
|
||||
|
||||
:param action_dim: (int) Number of continuous actions
|
||||
:param epsilon: (float) small value to avoid NaN due to numerical imprecision.
|
||||
"""
|
||||
def __init__(self, action_dim, epsilon=1e-6):
|
||||
super(SquashedDiagGaussianDistribution, self).__init__(action_dim)
|
||||
# Avoid NaN (prevents division by zero or log of zero)
|
||||
|
|
@ -117,27 +163,40 @@ class SquashedDiagGaussianDistribution(DiagGaussianDistribution):
|
|||
|
||||
def log_prob(self, action, gaussian_action=None):
|
||||
# Inverse tanh
|
||||
# Naive implementation (not stable): 0.5 * torch.log((1 + x ) / (1 - x))
|
||||
# Naive implementation (not stable): 0.5 * torch.log((1 + x) / (1 - x))
|
||||
# We use numpy to avoid numerical instability
|
||||
if gaussian_action is None:
|
||||
# Clip to avoid NaN
|
||||
clipped_action = np.clip(action.cpu().numpy(), -1.0 + self.epsilon, 1.0 + self.epsilon)
|
||||
gaussian_action = th.from_numpy(np.arctanh(clipped_action)).to(action.device)
|
||||
# It will be clipped to avoid NaN when inversing tanh
|
||||
gaussian_action = TanhBijector.inverse(action)
|
||||
|
||||
# Log likelihood for a gaussian distribution
|
||||
log_prob = super(SquashedDiagGaussianDistribution, self).log_prob(gaussian_action)
|
||||
# Squash correction (from original SAC implementation)
|
||||
# this comes from the fact that tanh is bijective and differentiable
|
||||
log_prob -= th.sum(th.log(1 - action ** 2 + self.epsilon), dim=1)
|
||||
return log_prob
|
||||
|
||||
|
||||
class CategoricalDistribution(Distribution):
|
||||
"""
|
||||
Categorical distribution for discrete actions.
|
||||
|
||||
:param action_dim: (int) Number of discrete actions
|
||||
"""
|
||||
def __init__(self, action_dim):
|
||||
super(CategoricalDistribution, self).__init__()
|
||||
self.distribution = None
|
||||
self.action_dim = action_dim
|
||||
|
||||
def proba_distribution_net(self, latent_dim):
|
||||
"""
|
||||
Create the layer that represents the distribution:
|
||||
it will be the logits of the Categorical distribution.
|
||||
You can then get probabilties using a softmax.
|
||||
|
||||
:param latent_dim: (int) Dimension og the last layer of the policy (before the action layer)
|
||||
:return: (nn.Linear)
|
||||
"""
|
||||
action_logits = nn.Linear(latent_dim, self.action_dim)
|
||||
return action_logits
|
||||
|
||||
|
|
@ -169,6 +228,19 @@ class CategoricalDistribution(Distribution):
|
|||
|
||||
|
||||
class StateDependentNoiseDistribution(Distribution):
|
||||
"""
|
||||
Distribution class for using State Dependent Exploration (SDE).
|
||||
It is used to create the noise exploration matrix and
|
||||
compute the log probabilty of an action with that noise.
|
||||
|
||||
:param action_dim: (int) Number of continuous actions
|
||||
:param use_expln: (bool) Use `expln()` function instead of `exp()` to ensure
|
||||
a positive standard deviation (cf paper). It allows to keep variance
|
||||
above zero and prevent it from growing too fast. In practice, `exp()` is usually enough.
|
||||
:param squash_output: (bool) Whether to squash the output using a tanh function,
|
||||
this allows to ensure boundaries.
|
||||
:param epsilon: (float) small value to avoid NaN due to numerical imprecision.
|
||||
"""
|
||||
def __init__(self, action_dim, use_expln=False,
|
||||
squash_output=False, epsilon=1e-6):
|
||||
super(StateDependentNoiseDistribution, self).__init__()
|
||||
|
|
@ -186,6 +258,13 @@ class StateDependentNoiseDistribution(Distribution):
|
|||
self.bijector = None
|
||||
|
||||
def get_std(self, log_std):
|
||||
"""
|
||||
Get the standard deviation from the learned parameter
|
||||
(log of it by default). This ensures that the std is positive.
|
||||
|
||||
:param log_std: (th.Tensor)
|
||||
:return: (th.Tensor)
|
||||
"""
|
||||
if self.use_expln:
|
||||
# From SDE paper, it allows to keep variance
|
||||
# above zero and prevent it from growing too fast
|
||||
|
|
@ -194,19 +273,44 @@ class StateDependentNoiseDistribution(Distribution):
|
|||
else:
|
||||
return th.log(log_std + 1.0) + 1.0
|
||||
else:
|
||||
# Use normal exponential
|
||||
return th.exp(log_std)
|
||||
|
||||
def sample_weights(self, log_std):
|
||||
"""
|
||||
Sample weights for the noise exploration matrix,
|
||||
using a centered gaussian distribution.
|
||||
|
||||
:param log_std: (th.Tensor)
|
||||
"""
|
||||
# TODO: reduce the number of learned dimensions (cf TD3)
|
||||
self.weights_dist = Normal(th.zeros_like(log_std), self.get_std(log_std))
|
||||
self.exploration_mat = self.weights_dist.rsample()
|
||||
|
||||
def proba_distribution_net(self, latent_dim, log_std_init=0.0):
|
||||
"""
|
||||
Create the layers and parameter that represent the distribution:
|
||||
one output will be the deterministic action, the other parameter will be the
|
||||
standard deviation of the distribution that control the weights of the noise matrix.
|
||||
|
||||
:param latent_dim: (int) Dimension og the last layer of the policy (before the action layer)
|
||||
:param log_std_init: (float) Initial value for the log standard deviation
|
||||
:return: (nn.Linear, nn.Parameter)
|
||||
"""
|
||||
mean_actions = nn.Linear(latent_dim, self.action_dim)
|
||||
log_std = nn.Parameter(th.ones(latent_dim, self.action_dim) * log_std_init)
|
||||
self.sample_weights(log_std)
|
||||
return mean_actions, log_std
|
||||
|
||||
def proba_distribution(self, mean_actions, log_std, latent_pi, deterministic=False):
|
||||
"""
|
||||
Create and sample for the distribution given its parameters (mean, std)
|
||||
|
||||
:param mean_actions: (th.Tensor)
|
||||
:param log_std: (th.Tensor)
|
||||
:param deterministic: (bool)
|
||||
:return: (th.Tensor)
|
||||
"""
|
||||
variance = th.mm(latent_pi.detach() ** 2, self.get_std(log_std) ** 2)
|
||||
self.distribution = Normal(mean_actions, th.sqrt(variance))
|
||||
|
||||
|
|
@ -258,6 +362,13 @@ class StateDependentNoiseDistribution(Distribution):
|
|||
|
||||
|
||||
class TanhBijector(object):
|
||||
"""
|
||||
Bijective transformation of a probabilty distribution
|
||||
using a squashing function (tanh)
|
||||
TODO: use Pyro instead (https://pyro.ai/)
|
||||
|
||||
:param epsilon: (float) small value to avoid NaN due to numerical imprecision.
|
||||
"""
|
||||
def __init__(self, epsilon=1e-6):
|
||||
super(TanhBijector, self).__init__()
|
||||
self.epsilon = epsilon
|
||||
|
|
@ -265,23 +376,27 @@ class TanhBijector(object):
|
|||
def forward(self, x):
|
||||
return th.tanh(x)
|
||||
|
||||
def inverse(self, action):
|
||||
@staticmethod
|
||||
def atanh(x):
|
||||
"""
|
||||
Inverse of Tanh
|
||||
|
||||
Taken from pyro: https://github.com/pyro-ppl/pyro
|
||||
0.5 * torch.log((1 + x ) / (1 - x))
|
||||
"""
|
||||
return 0.5 * (x.log1p() - (-x).log1p())
|
||||
|
||||
@staticmethod
|
||||
def inverse(y):
|
||||
"""
|
||||
Inverse tanh.
|
||||
|
||||
From https://github.com/tensorflow/agents:
|
||||
0.99999997 is the maximum value such that atanh(x) is valid for both
|
||||
float32 and float64
|
||||
|
||||
:param action: (th.Tensor)
|
||||
:param y: (th.Tensor)
|
||||
:return: (th.Tensor)
|
||||
"""
|
||||
# Inverse tanh
|
||||
# Naive implementation (not stable): 0.5 * torch.log((1 + x ) / (1 - x))
|
||||
# We use numpy to avoid numerical instability
|
||||
# Note: Using numpy, we do not keep the gradient
|
||||
clipped_action = np.clip(action.cpu().numpy(), -0.99999997, 0.99999997)
|
||||
return th.from_numpy(np.arctanh(clipped_action)).to(action.device)
|
||||
eps = th.finfo(y.dtype).eps
|
||||
# Clip the action to avoid NaN
|
||||
return TanhBijector.atanh(y.clamp(min=-1. + eps, max=1. - eps))
|
||||
|
||||
def log_prob_correction(self, x):
|
||||
# Squash correction (from original SAC implementation)
|
||||
|
|
|
|||
Loading…
Reference in a new issue