KL Divergence Helper Function (#431)

* add kl divergence wrapper

* add test

* update changelog

* black lint

* remove unused import

* Fix ent coef loading for SAC (#429)

* Fix ent coef loading for SAC

* Better fix and add comment

* add 'distribution' to base Distribution class

* add sample test

* revert to plain pytorch implementation

* black reformat

* Update docs/misc/changelog.rst

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>

* Doc update (custom policy + fix her example) (#436)

* isort and black reformat

* float -> bool tensor

* add sanity test

* more concise kl code

* remove outdated comment

* all -> allclose assertion

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>

* Fix PyTorch warning

* Update gSDE entropy test

* Update entropy test

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
Co-authored-by: Antonin Raffin <antonin.raffin@dlr.de>
This commit is contained in:
Rohan Tangri 2021-05-20 18:01:07 +01:00 committed by GitHub
parent 378d197b00
commit df6f9de8f4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 114 additions and 15 deletions

View file

@ -4,7 +4,7 @@ Changelog
==========
Release 1.1.0a7 (WIP)
Release 1.1.0a8 (WIP)
---------------------------
**Dict observation support, timeout handling and refactored HER**
@ -40,9 +40,10 @@ New Features:
to handle gym3-style vectorized environments (@vwxyzjn)
- Ignored the terminal observation if the it is not provided by the environment
such as the gym3-style vectorized environments. (@vwxyzjn)
- Add policy_base as input to the OnPolicyAlgorithm for more flexibility (@09tangriro)
- Added policy_base as input to the OnPolicyAlgorithm for more flexibility (@09tangriro)
- Added support for image observation when using ``HER``
- Added ``replay_buffer_class`` and ``replay_buffer_kwargs`` arguments to off-policy algorithms
- Added ``kl_divergence`` helper for ``Distribution`` classes (@09tangriro)
Bug Fixes:
^^^^^^^^^^
@ -59,6 +60,7 @@ Others:
- Updated ``env_checker`` to reflect support of dict observation spaces
- Added Code of Conduct
- Added tests for GAE and lambda return computation
- Updated distribution entropy test (thanks @09tangriro)
Documentation:
^^^^^^^^^^^^^^

View file

@ -17,6 +17,7 @@ class Distribution(ABC):
def __init__(self):
super(Distribution, self).__init__()
self.distribution = None
@abstractmethod
def proba_distribution_net(self, *args, **kwargs) -> Union[nn.Module, Tuple[nn.Module, nn.Parameter]]:
@ -120,7 +121,6 @@ class DiagGaussianDistribution(Distribution):
def __init__(self, action_dim: int):
super(DiagGaussianDistribution, self).__init__()
self.distribution = None
self.action_dim = action_dim
self.mean_actions = None
self.log_std = None
@ -255,7 +255,6 @@ class CategoricalDistribution(Distribution):
def __init__(self, action_dim: int):
super(CategoricalDistribution, self).__init__()
self.distribution = None
self.action_dim = action_dim
def proba_distribution_net(self, latent_dim: int) -> nn.Module:
@ -308,7 +307,6 @@ class MultiCategoricalDistribution(Distribution):
def __init__(self, action_dims: List[int]):
super(MultiCategoricalDistribution, self).__init__()
self.action_dims = action_dims
self.distributions = None
def proba_distribution_net(self, latent_dim: int) -> nn.Module:
"""
@ -325,23 +323,23 @@ class MultiCategoricalDistribution(Distribution):
return action_logits
def proba_distribution(self, action_logits: th.Tensor) -> "MultiCategoricalDistribution":
self.distributions = [Categorical(logits=split) for split in th.split(action_logits, tuple(self.action_dims), dim=1)]
self.distribution = [Categorical(logits=split) for split in th.split(action_logits, tuple(self.action_dims), dim=1)]
return self
def log_prob(self, actions: th.Tensor) -> th.Tensor:
# Extract each discrete action and compute log prob for their respective distributions
return th.stack(
[dist.log_prob(action) for dist, action in zip(self.distributions, th.unbind(actions, dim=1))], dim=1
[dist.log_prob(action) for dist, action in zip(self.distribution, th.unbind(actions, dim=1))], dim=1
).sum(dim=1)
def entropy(self) -> th.Tensor:
return th.stack([dist.entropy() for dist in self.distributions], dim=1).sum(dim=1)
return th.stack([dist.entropy() for dist in self.distribution], dim=1).sum(dim=1)
def sample(self) -> th.Tensor:
return th.stack([dist.sample() for dist in self.distributions], dim=1)
return th.stack([dist.sample() for dist in self.distribution], dim=1)
def mode(self) -> th.Tensor:
return th.stack([th.argmax(dist.probs, dim=1) for dist in self.distributions], dim=1)
return th.stack([th.argmax(dist.probs, dim=1) for dist in self.distribution], dim=1)
def actions_from_params(self, action_logits: th.Tensor, deterministic: bool = False) -> th.Tensor:
# Update the proba distribution
@ -363,7 +361,6 @@ class BernoulliDistribution(Distribution):
def __init__(self, action_dims: int):
super(BernoulliDistribution, self).__init__()
self.distribution = None
self.action_dims = action_dims
def proba_distribution_net(self, latent_dim: int) -> nn.Module:
@ -437,7 +434,6 @@ class StateDependentNoiseDistribution(Distribution):
epsilon: float = 1e-6,
):
super(StateDependentNoiseDistribution, self).__init__()
self.distribution = None
self.action_dim = action_dim
self.latent_sde_dim = None
self.mean_actions = None
@ -676,3 +672,28 @@ def make_proba_distribution(
f"of type {type(action_space)}."
" Must be of type Gym Spaces: Box, Discrete, MultiDiscrete or MultiBinary."
)
def kl_divergence(dist_true: Distribution, dist_pred: Distribution) -> th.Tensor:
"""
Wrapper for the PyTorch implementation of the full form KL Divergence
:param dist_true: the p distribution
:param dist_pred: the q distribution
:return: KL(dist_true||dist_pred)
"""
# KL Divergence for different distribution types is out of scope
assert dist_true.__class__ == dist_pred.__class__, "Error: input distributions should be the same type"
# MultiCategoricalDistribution is not a PyTorch Distribution subclass
# so we need to implement it ourselves!
if isinstance(dist_pred, MultiCategoricalDistribution):
assert dist_pred.action_dims == dist_true.action_dims, "Error: distributions must have the same input space"
return th.stack(
[th.distributions.kl_divergence(p, q) for p, q in zip(dist_true.distribution, dist_pred.distribution)],
dim=1,
).sum(dim=1)
# Use the PyTorch kl_divergence implementation
else:
return th.distributions.kl_divergence(dist_true.distribution, dist_pred.distribution)

View file

@ -1 +1 @@
1.1.0a7
1.1.0a8

View file

@ -1,3 +1,5 @@
from copy import deepcopy
import pytest
import torch as th
@ -10,6 +12,7 @@ from stable_baselines3.common.distributions import (
SquashedDiagGaussianDistribution,
StateDependentNoiseDistribution,
TanhBijector,
kl_divergence,
)
from stable_baselines3.common.utils import set_random_seed
@ -77,13 +80,13 @@ def test_entropy(dist):
# The entropy can be approximated by averaging the negative log likelihood
# mean negative log likelihood == differential entropy
set_random_seed(1)
state = th.rand(N_SAMPLES, N_FEATURES)
deterministic_actions = th.rand(N_SAMPLES, N_ACTIONS)
deterministic_actions = th.rand(1, N_ACTIONS).repeat(N_SAMPLES, 1)
_, log_std = dist.proba_distribution_net(N_FEATURES, log_std_init=th.log(th.tensor(0.2)))
if isinstance(dist, DiagGaussianDistribution):
dist = dist.proba_distribution(deterministic_actions, log_std)
else:
state = th.rand(1, N_FEATURES).repeat(N_SAMPLES, 1)
dist.sample_weights(log_std, batch_size=N_SAMPLES)
dist = dist.proba_distribution(deterministic_actions, log_std, state)
@ -111,3 +114,76 @@ def test_categorical(dist, CAT_ACTIONS):
entropy = dist.entropy()
log_prob = dist.log_prob(actions)
assert th.allclose(entropy.mean(), -log_prob.mean(), rtol=5e-3)
@pytest.mark.parametrize(
"dist_type",
[
BernoulliDistribution(N_ACTIONS).proba_distribution(th.rand(N_ACTIONS)),
CategoricalDistribution(N_ACTIONS).proba_distribution(th.rand(N_ACTIONS)),
DiagGaussianDistribution(N_ACTIONS).proba_distribution(th.rand(N_ACTIONS), th.rand(N_ACTIONS)),
MultiCategoricalDistribution([N_ACTIONS, N_ACTIONS]).proba_distribution(th.rand(1, sum([N_ACTIONS, N_ACTIONS]))),
SquashedDiagGaussianDistribution(N_ACTIONS).proba_distribution(th.rand(N_ACTIONS), th.rand(N_ACTIONS)),
StateDependentNoiseDistribution(N_ACTIONS).proba_distribution(
th.rand(N_ACTIONS), th.rand([N_ACTIONS, N_ACTIONS]), th.rand([N_ACTIONS, N_ACTIONS])
),
],
)
def test_kl_divergence(dist_type):
set_random_seed(8)
# Test 1: same distribution should have KL Div = 0
dist1 = dist_type
dist2 = dist_type
# PyTorch implementation of kl_divergence doesn't sum across dimensions
assert th.allclose(kl_divergence(dist1, dist2).sum(), th.tensor(0.0))
# Test 2: KL Div = E(Unbiased approx KL Div)
if isinstance(dist_type, CategoricalDistribution):
dist1 = dist_type.proba_distribution(th.rand(N_ACTIONS).repeat(N_SAMPLES, 1))
# deepcopy needed to assign new memory to new distribution instance
dist2 = deepcopy(dist_type).proba_distribution(th.rand(N_ACTIONS).repeat(N_SAMPLES, 1))
elif isinstance(dist_type, DiagGaussianDistribution) or isinstance(dist_type, SquashedDiagGaussianDistribution):
mean_actions1 = th.rand(1).repeat(N_SAMPLES, 1)
log_std1 = th.rand(1).repeat(N_SAMPLES, 1)
mean_actions2 = th.rand(1).repeat(N_SAMPLES, 1)
log_std2 = th.rand(1).repeat(N_SAMPLES, 1)
dist1 = dist_type.proba_distribution(mean_actions1, log_std1)
dist2 = deepcopy(dist_type).proba_distribution(mean_actions2, log_std2)
elif isinstance(dist_type, BernoulliDistribution):
dist1 = dist_type.proba_distribution(th.rand(1).repeat(N_SAMPLES, 1))
dist2 = deepcopy(dist_type).proba_distribution(th.rand(1).repeat(N_SAMPLES, 1))
elif isinstance(dist_type, MultiCategoricalDistribution):
dist1 = dist_type.proba_distribution(th.rand(1, sum([N_ACTIONS, N_ACTIONS])).repeat(N_SAMPLES, 1))
dist2 = deepcopy(dist_type).proba_distribution(th.rand(1, sum([N_ACTIONS, N_ACTIONS])).repeat(N_SAMPLES, 1))
elif isinstance(dist_type, StateDependentNoiseDistribution):
dist1 = StateDependentNoiseDistribution(1)
dist2 = deepcopy(dist1)
state = th.rand(1, N_FEATURES).repeat(N_SAMPLES, 1)
mean_actions1 = th.rand(1).repeat(N_SAMPLES, 1)
mean_actions2 = th.rand(1).repeat(N_SAMPLES, 1)
_, log_std = dist1.proba_distribution_net(N_FEATURES, log_std_init=th.log(th.tensor(0.2)))
dist1.sample_weights(log_std, batch_size=N_SAMPLES)
dist2.sample_weights(log_std, batch_size=N_SAMPLES)
dist1 = dist1.proba_distribution(mean_actions1, log_std, state)
dist2 = dist2.proba_distribution(mean_actions2, log_std, state)
full_kl_div = kl_divergence(dist1, dist2).mean(dim=0)
actions = dist1.get_actions()
approx_kl_div = (dist1.log_prob(actions) - dist2.log_prob(actions)).mean(dim=0)
assert th.allclose(full_kl_div, approx_kl_div, rtol=5e-2)
# Test 3 Sanity test with easy Bernoulli distribution
if isinstance(dist_type, BernoulliDistribution):
dist1 = BernoulliDistribution(1).proba_distribution(th.tensor([0.3]))
dist2 = BernoulliDistribution(1).proba_distribution(th.tensor([0.65]))
full_kl_div = kl_divergence(dist1, dist2)
actions = th.tensor([0.0, 1.0])
ad_hoc_kl = th.sum(
th.exp(dist1.distribution.log_prob(actions))
* (dist1.distribution.log_prob(actions) - dist2.distribution.log_prob(actions))
)
assert th.allclose(full_kl_div, ad_hoc_kl)