Fix entropy computation

This commit is contained in:
Antonin RAFFIN 2020-03-19 10:19:48 +01:00
parent 9485b90a41
commit fd9e73cfb8
6 changed files with 60 additions and 40 deletions

View file

@ -21,6 +21,7 @@ Bug Fixes:
^^^^^^^^^^
- Synced callbacks with Stable-Baselines
- Fixed colors in `results_plotter`
- Fix entropy computation (now summed over action dim)
Deprecations:
^^^^^^^^^^^^^
@ -34,6 +35,7 @@ Others:
- Add test for ``expln``
- Renamed ``learning_rate`` to ``lr_schedule``
- Add ``version.txt``
- Add more tests for distribution
Documentation:
^^^^^^^^^^^^^^

View file

@ -2,18 +2,22 @@ import pytest
import torch as th
from torchy_baselines import A2C, PPO
from torchy_baselines.common.distributions import DiagGaussianDistribution, TanhBijector, \
StateDependentNoiseDistribution
from torchy_baselines.common.distributions import (DiagGaussianDistribution, TanhBijector,
StateDependentNoiseDistribution,
CategoricalDistribution)
from torchy_baselines.common.utils import set_random_seed
# TODO: more tests for the other distributions
N_ACTIONS = 2
N_FEATURES = 3
N_SAMPLES = int(5e6)
def test_bijector():
"""
Test TanhBijector
"""
actions = th.ones(5) * 2.0
bijector = TanhBijector()
squashed_actions = bijector.forward(actions)
@ -33,16 +37,14 @@ def test_squashed_gaussian(model_class):
def test_sde_distribution():
n_samples = int(5e6)
n_features = 2
n_actions = 1
deterministic_actions = th.ones(n_samples, n_actions) * 0.1
state = th.ones(n_samples, n_features) * 0.3
deterministic_actions = th.ones(N_SAMPLES, n_actions) * 0.1
state = th.ones(N_SAMPLES, N_FEATURES) * 0.3
dist = StateDependentNoiseDistribution(n_actions, full_std=True, squash_output=False)
set_random_seed(1)
_, log_std = dist.proba_distribution_net(n_features)
dist.sample_weights(log_std, batch_size=n_samples)
_, log_std = dist.proba_distribution_net(N_FEATURES)
dist.sample_weights(log_std, batch_size=N_SAMPLES)
actions, _ = dist.proba_distribution(deterministic_actions, log_std, state)
@ -50,10 +52,6 @@ def test_sde_distribution():
assert th.allclose(actions.std(), dist.distribution.scale.mean(), rtol=1e-3)
N_ACTIONS = 1
# TODO: fix for num action > 1
# TODO: analytical form for squashed Gaussian?
@pytest.mark.parametrize("dist", [
DiagGaussianDistribution(N_ACTIONS),
@ -62,19 +60,31 @@ N_ACTIONS = 1
def test_entropy(dist):
# The entropy can be approximated by averaging the negative log likelihood
# mean negative log likelihood == differential entropy
n_samples = int(5e6)
n_features = 3
set_random_seed(1)
state = th.rand(n_samples, n_features)
deterministic_actions = th.rand(n_samples, N_ACTIONS)
_, log_std = dist.proba_distribution_net(n_features, log_std_init=th.log(th.tensor(0.2)))
state = th.rand(N_SAMPLES, N_FEATURES)
deterministic_actions = th.rand(N_SAMPLES, N_ACTIONS)
_, log_std = dist.proba_distribution_net(N_FEATURES, log_std_init=th.log(th.tensor(0.2)))
if isinstance(dist, DiagGaussianDistribution):
actions, dist = dist.proba_distribution(deterministic_actions, log_std)
else:
dist.sample_weights(log_std, batch_size=n_samples)
dist.sample_weights(log_std, batch_size=N_SAMPLES)
actions, dist = dist.proba_distribution(deterministic_actions, log_std, state)
entropy = dist.entropy()
log_prob = dist.log_prob(actions)
assert th.allclose(entropy.mean(), -log_prob.mean(), rtol=5e-3)
def test_categorical():
# The entropy can be approximated by averaging the negative log likelihood
# mean negative log likelihood == entropy
dist = CategoricalDistribution(N_ACTIONS)
set_random_seed(1)
state = th.rand(N_SAMPLES, N_FEATURES)
action_logits = th.rand(N_SAMPLES, N_ACTIONS)
actions, dist = dist.proba_distribution(action_logits)
entropy = dist.entropy()
log_prob = dist.log_prob(actions)
assert th.allclose(entropy.mean(), -log_prob.mean(), rtol=1e-4)

View file

@ -22,7 +22,6 @@ def test_state_dependent_exploration_grad():
state = th.rand(n_states, state_dim)
mu = th.ones(action_dim)
# print(weights.shape, state.shape)
noise = th.mm(state, weights)
action = mu + noise

View file

@ -38,6 +38,22 @@ class Distribution(object):
raise NotImplementedError
def sum_independent_dims(tensor: th.Tensor) -> th.Tensor:
"""
Continuous actions are usually considered to be independent,
so we can sum the components for the ``log_prob``
or the entropy.
:param tensor: (th.Tensor) shape: (n_batch, n_actions) or (n_batch,)
:return: (th.Tensor) shape: (n_batch,)
"""
if len(tensor.shape) > 1:
tensor = tensor.sum(axis=1)
else:
tensor = tensor.sum()
return tensor
class DiagGaussianDistribution(Distribution):
"""
Gaussian distribution with diagonal covariance matrix,
@ -95,7 +111,7 @@ class DiagGaussianDistribution(Distribution):
return self.distribution.rsample()
def entropy(self) -> th.Tensor:
return self.distribution.entropy()
return sum_independent_dims(self.distribution.entropy())
def log_prob_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
"""
@ -113,18 +129,14 @@ class DiagGaussianDistribution(Distribution):
def log_prob(self, action: th.Tensor) -> th.Tensor:
"""
Get the log probabilty of an action given a distribution.
Note that you must call `proba_distribution()` method
Note that you must call ``proba_distribution()`` method
before.
:param action: (th.Tensor)
:return: (th.Tensor)
"""
log_prob = self.distribution.log_prob(action)
if len(log_prob.shape) > 1:
log_prob = log_prob.sum(axis=1)
else:
log_prob = log_prob.sum()
return log_prob
return sum_independent_dims(log_prob)
class SquashedDiagGaussianDistribution(DiagGaussianDistribution):
@ -243,14 +255,14 @@ class StateDependentNoiseDistribution(Distribution):
:param action_dim: (int) Number of continuous actions
:param full_std: (bool) Whether to use (n_features x n_actions) parameters
for the std instead of only (n_features,)
:param use_expln: (bool) Use `expln()` function instead of `exp()` to ensure
:param use_expln: (bool) Use ``expln()`` function instead of ``exp()`` to ensure
a positive standard deviation (cf paper). It allows to keep variance
above zero and prevent it from growing too fast. In practice, `exp()` is usually enough.
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
:param squash_output: (bool) Whether to squash the output using a tanh function,
this allows to ensure boundaries.
:param learn_features: (bool) Whether to learn features for SDE or not.
This will enable gradients to be backpropagated through the features
`latent_sde` in the code.
``latent_sde`` in the code.
:param epsilon: (float) small value to avoid NaN due to numerical imprecision.
"""
@ -396,7 +408,7 @@ class StateDependentNoiseDistribution(Distribution):
# entropy needs to be estimated using -log_prob.mean()
if self.bijector is not None:
return None
return self.distribution.entropy()
return sum_independent_dims(self.distribution.entropy())
def log_prob_from_params(self, mean_actions: th.Tensor,
log_std: th.Tensor,
@ -412,11 +424,8 @@ class StateDependentNoiseDistribution(Distribution):
gaussian_action = action
# log likelihood for a gaussian
log_prob = self.distribution.log_prob(gaussian_action)
if len(log_prob.shape) > 1:
log_prob = log_prob.sum(axis=1)
else:
log_prob = log_prob.sum()
# Sum along action dim
log_prob = sum_independent_dims(log_prob)
if self.bijector is not None:
# Squash correction (from original SAC implementation)

View file

@ -3,8 +3,8 @@ import typing
from typing import Optional
from copy import deepcopy
from torchy_baselines.common.vec_env.base_vec_env import AlreadySteppingError, NotSteppingError,\
VecEnv, VecEnvWrapper, CloudpickleWrapper
from torchy_baselines.common.vec_env.base_vec_env import (AlreadySteppingError, NotSteppingError,
VecEnv, VecEnvWrapper, CloudpickleWrapper)
from torchy_baselines.common.vec_env.dummy_vec_env import DummyVecEnv
from torchy_baselines.common.vec_env.subproc_vec_env import SubprocVecEnv
from torchy_baselines.common.vec_env.vec_frame_stack import VecFrameStack

View file

@ -1 +1 @@
0.2.5
0.2.6