From fd9e73cfb87f8983b40ddb02f1f5031a2420d35c Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Thu, 19 Mar 2020 10:19:48 +0100 Subject: [PATCH] Fix entropy computation --- docs/misc/changelog.rst | 2 + tests/test_distributions.py | 50 ++++++++++++--------- tests/test_sde.py | 1 - torchy_baselines/common/distributions.py | 41 ++++++++++------- torchy_baselines/common/vec_env/__init__.py | 4 +- torchy_baselines/version.txt | 2 +- 6 files changed, 60 insertions(+), 40 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 9332168..5a69f4e 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -21,6 +21,7 @@ Bug Fixes: ^^^^^^^^^^ - Synced callbacks with Stable-Baselines - Fixed colors in `results_plotter` +- Fix entropy computation (now summed over action dim) Deprecations: ^^^^^^^^^^^^^ @@ -34,6 +35,7 @@ Others: - Add test for ``expln`` - Renamed ``learning_rate`` to ``lr_schedule`` - Add ``version.txt`` +- Add more tests for distribution Documentation: ^^^^^^^^^^^^^^ diff --git a/tests/test_distributions.py b/tests/test_distributions.py index 7d28ad7..c63900d 100644 --- a/tests/test_distributions.py +++ b/tests/test_distributions.py @@ -2,18 +2,22 @@ import pytest import torch as th from torchy_baselines import A2C, PPO -from torchy_baselines.common.distributions import DiagGaussianDistribution, TanhBijector, \ - StateDependentNoiseDistribution +from torchy_baselines.common.distributions import (DiagGaussianDistribution, TanhBijector, + StateDependentNoiseDistribution, + CategoricalDistribution) from torchy_baselines.common.utils import set_random_seed -# TODO: more tests for the other distributions +N_ACTIONS = 2 +N_FEATURES = 3 +N_SAMPLES = int(5e6) + + def test_bijector(): """ Test TanhBijector """ actions = th.ones(5) * 2.0 - bijector = TanhBijector() squashed_actions = bijector.forward(actions) @@ -33,16 +37,14 @@ def test_squashed_gaussian(model_class): def test_sde_distribution(): - n_samples = int(5e6) - n_features = 2 n_actions = 1 - deterministic_actions = th.ones(n_samples, n_actions) * 0.1 - state = th.ones(n_samples, n_features) * 0.3 + deterministic_actions = th.ones(N_SAMPLES, n_actions) * 0.1 + state = th.ones(N_SAMPLES, N_FEATURES) * 0.3 dist = StateDependentNoiseDistribution(n_actions, full_std=True, squash_output=False) set_random_seed(1) - _, log_std = dist.proba_distribution_net(n_features) - dist.sample_weights(log_std, batch_size=n_samples) + _, log_std = dist.proba_distribution_net(N_FEATURES) + dist.sample_weights(log_std, batch_size=N_SAMPLES) actions, _ = dist.proba_distribution(deterministic_actions, log_std, state) @@ -50,10 +52,6 @@ def test_sde_distribution(): assert th.allclose(actions.std(), dist.distribution.scale.mean(), rtol=1e-3) -N_ACTIONS = 1 - - -# TODO: fix for num action > 1 # TODO: analytical form for squashed Gaussian? @pytest.mark.parametrize("dist", [ DiagGaussianDistribution(N_ACTIONS), @@ -62,19 +60,31 @@ N_ACTIONS = 1 def test_entropy(dist): # The entropy can be approximated by averaging the negative log likelihood # mean negative log likelihood == differential entropy - n_samples = int(5e6) - n_features = 3 set_random_seed(1) - state = th.rand(n_samples, n_features) - deterministic_actions = th.rand(n_samples, N_ACTIONS) - _, log_std = dist.proba_distribution_net(n_features, log_std_init=th.log(th.tensor(0.2))) + state = th.rand(N_SAMPLES, N_FEATURES) + deterministic_actions = th.rand(N_SAMPLES, N_ACTIONS) + _, log_std = dist.proba_distribution_net(N_FEATURES, log_std_init=th.log(th.tensor(0.2))) if isinstance(dist, DiagGaussianDistribution): actions, dist = dist.proba_distribution(deterministic_actions, log_std) else: - dist.sample_weights(log_std, batch_size=n_samples) + dist.sample_weights(log_std, batch_size=N_SAMPLES) actions, dist = dist.proba_distribution(deterministic_actions, log_std, state) entropy = dist.entropy() log_prob = dist.log_prob(actions) assert th.allclose(entropy.mean(), -log_prob.mean(), rtol=5e-3) + + +def test_categorical(): + # The entropy can be approximated by averaging the negative log likelihood + # mean negative log likelihood == entropy + dist = CategoricalDistribution(N_ACTIONS) + set_random_seed(1) + state = th.rand(N_SAMPLES, N_FEATURES) + action_logits = th.rand(N_SAMPLES, N_ACTIONS) + actions, dist = dist.proba_distribution(action_logits) + + entropy = dist.entropy() + log_prob = dist.log_prob(actions) + assert th.allclose(entropy.mean(), -log_prob.mean(), rtol=1e-4) diff --git a/tests/test_sde.py b/tests/test_sde.py index 96f6c8f..cf3bbdb 100644 --- a/tests/test_sde.py +++ b/tests/test_sde.py @@ -22,7 +22,6 @@ def test_state_dependent_exploration_grad(): state = th.rand(n_states, state_dim) mu = th.ones(action_dim) - # print(weights.shape, state.shape) noise = th.mm(state, weights) action = mu + noise diff --git a/torchy_baselines/common/distributions.py b/torchy_baselines/common/distributions.py index 6a0ab2f..5d04f83 100644 --- a/torchy_baselines/common/distributions.py +++ b/torchy_baselines/common/distributions.py @@ -38,6 +38,22 @@ class Distribution(object): raise NotImplementedError +def sum_independent_dims(tensor: th.Tensor) -> th.Tensor: + """ + Continuous actions are usually considered to be independent, + so we can sum the components for the ``log_prob`` + or the entropy. + + :param tensor: (th.Tensor) shape: (n_batch, n_actions) or (n_batch,) + :return: (th.Tensor) shape: (n_batch,) + """ + if len(tensor.shape) > 1: + tensor = tensor.sum(axis=1) + else: + tensor = tensor.sum() + return tensor + + class DiagGaussianDistribution(Distribution): """ Gaussian distribution with diagonal covariance matrix, @@ -95,7 +111,7 @@ class DiagGaussianDistribution(Distribution): return self.distribution.rsample() def entropy(self) -> th.Tensor: - return self.distribution.entropy() + return sum_independent_dims(self.distribution.entropy()) def log_prob_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor) -> Tuple[th.Tensor, th.Tensor]: """ @@ -113,18 +129,14 @@ class DiagGaussianDistribution(Distribution): def log_prob(self, action: th.Tensor) -> th.Tensor: """ Get the log probabilty of an action given a distribution. - Note that you must call `proba_distribution()` method + Note that you must call ``proba_distribution()`` method before. :param action: (th.Tensor) :return: (th.Tensor) """ log_prob = self.distribution.log_prob(action) - if len(log_prob.shape) > 1: - log_prob = log_prob.sum(axis=1) - else: - log_prob = log_prob.sum() - return log_prob + return sum_independent_dims(log_prob) class SquashedDiagGaussianDistribution(DiagGaussianDistribution): @@ -243,14 +255,14 @@ class StateDependentNoiseDistribution(Distribution): :param action_dim: (int) Number of continuous actions :param full_std: (bool) Whether to use (n_features x n_actions) parameters for the std instead of only (n_features,) - :param use_expln: (bool) Use `expln()` function instead of `exp()` to ensure + :param use_expln: (bool) Use ``expln()`` function instead of ``exp()`` to ensure a positive standard deviation (cf paper). It allows to keep variance - above zero and prevent it from growing too fast. In practice, `exp()` is usually enough. + above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough. :param squash_output: (bool) Whether to squash the output using a tanh function, this allows to ensure boundaries. :param learn_features: (bool) Whether to learn features for SDE or not. This will enable gradients to be backpropagated through the features - `latent_sde` in the code. + ``latent_sde`` in the code. :param epsilon: (float) small value to avoid NaN due to numerical imprecision. """ @@ -396,7 +408,7 @@ class StateDependentNoiseDistribution(Distribution): # entropy needs to be estimated using -log_prob.mean() if self.bijector is not None: return None - return self.distribution.entropy() + return sum_independent_dims(self.distribution.entropy()) def log_prob_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor, @@ -412,11 +424,8 @@ class StateDependentNoiseDistribution(Distribution): gaussian_action = action # log likelihood for a gaussian log_prob = self.distribution.log_prob(gaussian_action) - - if len(log_prob.shape) > 1: - log_prob = log_prob.sum(axis=1) - else: - log_prob = log_prob.sum() + # Sum along action dim + log_prob = sum_independent_dims(log_prob) if self.bijector is not None: # Squash correction (from original SAC implementation) diff --git a/torchy_baselines/common/vec_env/__init__.py b/torchy_baselines/common/vec_env/__init__.py index 2cbb349..10f9288 100644 --- a/torchy_baselines/common/vec_env/__init__.py +++ b/torchy_baselines/common/vec_env/__init__.py @@ -3,8 +3,8 @@ import typing from typing import Optional from copy import deepcopy -from torchy_baselines.common.vec_env.base_vec_env import AlreadySteppingError, NotSteppingError,\ - VecEnv, VecEnvWrapper, CloudpickleWrapper +from torchy_baselines.common.vec_env.base_vec_env import (AlreadySteppingError, NotSteppingError, + VecEnv, VecEnvWrapper, CloudpickleWrapper) from torchy_baselines.common.vec_env.dummy_vec_env import DummyVecEnv from torchy_baselines.common.vec_env.subproc_vec_env import SubprocVecEnv from torchy_baselines.common.vec_env.vec_frame_stack import VecFrameStack diff --git a/torchy_baselines/version.txt b/torchy_baselines/version.txt index 3a4036f..53a75d6 100644 --- a/torchy_baselines/version.txt +++ b/torchy_baselines/version.txt @@ -1 +1 @@ -0.2.5 +0.2.6