mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-31 23:28:05 +00:00
Fix entropy computation
This commit is contained in:
parent
9485b90a41
commit
fd9e73cfb8
6 changed files with 60 additions and 40 deletions
|
|
@ -21,6 +21,7 @@ Bug Fixes:
|
|||
^^^^^^^^^^
|
||||
- Synced callbacks with Stable-Baselines
|
||||
- Fixed colors in `results_plotter`
|
||||
- Fix entropy computation (now summed over action dim)
|
||||
|
||||
Deprecations:
|
||||
^^^^^^^^^^^^^
|
||||
|
|
@ -34,6 +35,7 @@ Others:
|
|||
- Add test for ``expln``
|
||||
- Renamed ``learning_rate`` to ``lr_schedule``
|
||||
- Add ``version.txt``
|
||||
- Add more tests for distribution
|
||||
|
||||
Documentation:
|
||||
^^^^^^^^^^^^^^
|
||||
|
|
|
|||
|
|
@ -2,18 +2,22 @@ import pytest
|
|||
import torch as th
|
||||
|
||||
from torchy_baselines import A2C, PPO
|
||||
from torchy_baselines.common.distributions import DiagGaussianDistribution, TanhBijector, \
|
||||
StateDependentNoiseDistribution
|
||||
from torchy_baselines.common.distributions import (DiagGaussianDistribution, TanhBijector,
|
||||
StateDependentNoiseDistribution,
|
||||
CategoricalDistribution)
|
||||
from torchy_baselines.common.utils import set_random_seed
|
||||
|
||||
|
||||
# TODO: more tests for the other distributions
|
||||
N_ACTIONS = 2
|
||||
N_FEATURES = 3
|
||||
N_SAMPLES = int(5e6)
|
||||
|
||||
|
||||
def test_bijector():
|
||||
"""
|
||||
Test TanhBijector
|
||||
"""
|
||||
actions = th.ones(5) * 2.0
|
||||
|
||||
bijector = TanhBijector()
|
||||
|
||||
squashed_actions = bijector.forward(actions)
|
||||
|
|
@ -33,16 +37,14 @@ def test_squashed_gaussian(model_class):
|
|||
|
||||
|
||||
def test_sde_distribution():
|
||||
n_samples = int(5e6)
|
||||
n_features = 2
|
||||
n_actions = 1
|
||||
deterministic_actions = th.ones(n_samples, n_actions) * 0.1
|
||||
state = th.ones(n_samples, n_features) * 0.3
|
||||
deterministic_actions = th.ones(N_SAMPLES, n_actions) * 0.1
|
||||
state = th.ones(N_SAMPLES, N_FEATURES) * 0.3
|
||||
dist = StateDependentNoiseDistribution(n_actions, full_std=True, squash_output=False)
|
||||
|
||||
set_random_seed(1)
|
||||
_, log_std = dist.proba_distribution_net(n_features)
|
||||
dist.sample_weights(log_std, batch_size=n_samples)
|
||||
_, log_std = dist.proba_distribution_net(N_FEATURES)
|
||||
dist.sample_weights(log_std, batch_size=N_SAMPLES)
|
||||
|
||||
actions, _ = dist.proba_distribution(deterministic_actions, log_std, state)
|
||||
|
||||
|
|
@ -50,10 +52,6 @@ def test_sde_distribution():
|
|||
assert th.allclose(actions.std(), dist.distribution.scale.mean(), rtol=1e-3)
|
||||
|
||||
|
||||
N_ACTIONS = 1
|
||||
|
||||
|
||||
# TODO: fix for num action > 1
|
||||
# TODO: analytical form for squashed Gaussian?
|
||||
@pytest.mark.parametrize("dist", [
|
||||
DiagGaussianDistribution(N_ACTIONS),
|
||||
|
|
@ -62,19 +60,31 @@ N_ACTIONS = 1
|
|||
def test_entropy(dist):
|
||||
# The entropy can be approximated by averaging the negative log likelihood
|
||||
# mean negative log likelihood == differential entropy
|
||||
n_samples = int(5e6)
|
||||
n_features = 3
|
||||
set_random_seed(1)
|
||||
state = th.rand(n_samples, n_features)
|
||||
deterministic_actions = th.rand(n_samples, N_ACTIONS)
|
||||
_, log_std = dist.proba_distribution_net(n_features, log_std_init=th.log(th.tensor(0.2)))
|
||||
state = th.rand(N_SAMPLES, N_FEATURES)
|
||||
deterministic_actions = th.rand(N_SAMPLES, N_ACTIONS)
|
||||
_, log_std = dist.proba_distribution_net(N_FEATURES, log_std_init=th.log(th.tensor(0.2)))
|
||||
|
||||
if isinstance(dist, DiagGaussianDistribution):
|
||||
actions, dist = dist.proba_distribution(deterministic_actions, log_std)
|
||||
else:
|
||||
dist.sample_weights(log_std, batch_size=n_samples)
|
||||
dist.sample_weights(log_std, batch_size=N_SAMPLES)
|
||||
actions, dist = dist.proba_distribution(deterministic_actions, log_std, state)
|
||||
|
||||
entropy = dist.entropy()
|
||||
log_prob = dist.log_prob(actions)
|
||||
assert th.allclose(entropy.mean(), -log_prob.mean(), rtol=5e-3)
|
||||
|
||||
|
||||
def test_categorical():
|
||||
# The entropy can be approximated by averaging the negative log likelihood
|
||||
# mean negative log likelihood == entropy
|
||||
dist = CategoricalDistribution(N_ACTIONS)
|
||||
set_random_seed(1)
|
||||
state = th.rand(N_SAMPLES, N_FEATURES)
|
||||
action_logits = th.rand(N_SAMPLES, N_ACTIONS)
|
||||
actions, dist = dist.proba_distribution(action_logits)
|
||||
|
||||
entropy = dist.entropy()
|
||||
log_prob = dist.log_prob(actions)
|
||||
assert th.allclose(entropy.mean(), -log_prob.mean(), rtol=1e-4)
|
||||
|
|
|
|||
|
|
@ -22,7 +22,6 @@ def test_state_dependent_exploration_grad():
|
|||
|
||||
state = th.rand(n_states, state_dim)
|
||||
mu = th.ones(action_dim)
|
||||
# print(weights.shape, state.shape)
|
||||
noise = th.mm(state, weights)
|
||||
|
||||
action = mu + noise
|
||||
|
|
|
|||
|
|
@ -38,6 +38,22 @@ class Distribution(object):
|
|||
raise NotImplementedError
|
||||
|
||||
|
||||
def sum_independent_dims(tensor: th.Tensor) -> th.Tensor:
|
||||
"""
|
||||
Continuous actions are usually considered to be independent,
|
||||
so we can sum the components for the ``log_prob``
|
||||
or the entropy.
|
||||
|
||||
:param tensor: (th.Tensor) shape: (n_batch, n_actions) or (n_batch,)
|
||||
:return: (th.Tensor) shape: (n_batch,)
|
||||
"""
|
||||
if len(tensor.shape) > 1:
|
||||
tensor = tensor.sum(axis=1)
|
||||
else:
|
||||
tensor = tensor.sum()
|
||||
return tensor
|
||||
|
||||
|
||||
class DiagGaussianDistribution(Distribution):
|
||||
"""
|
||||
Gaussian distribution with diagonal covariance matrix,
|
||||
|
|
@ -95,7 +111,7 @@ class DiagGaussianDistribution(Distribution):
|
|||
return self.distribution.rsample()
|
||||
|
||||
def entropy(self) -> th.Tensor:
|
||||
return self.distribution.entropy()
|
||||
return sum_independent_dims(self.distribution.entropy())
|
||||
|
||||
def log_prob_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
|
||||
"""
|
||||
|
|
@ -113,18 +129,14 @@ class DiagGaussianDistribution(Distribution):
|
|||
def log_prob(self, action: th.Tensor) -> th.Tensor:
|
||||
"""
|
||||
Get the log probabilty of an action given a distribution.
|
||||
Note that you must call `proba_distribution()` method
|
||||
Note that you must call ``proba_distribution()`` method
|
||||
before.
|
||||
|
||||
:param action: (th.Tensor)
|
||||
:return: (th.Tensor)
|
||||
"""
|
||||
log_prob = self.distribution.log_prob(action)
|
||||
if len(log_prob.shape) > 1:
|
||||
log_prob = log_prob.sum(axis=1)
|
||||
else:
|
||||
log_prob = log_prob.sum()
|
||||
return log_prob
|
||||
return sum_independent_dims(log_prob)
|
||||
|
||||
|
||||
class SquashedDiagGaussianDistribution(DiagGaussianDistribution):
|
||||
|
|
@ -243,14 +255,14 @@ class StateDependentNoiseDistribution(Distribution):
|
|||
:param action_dim: (int) Number of continuous actions
|
||||
:param full_std: (bool) Whether to use (n_features x n_actions) parameters
|
||||
for the std instead of only (n_features,)
|
||||
:param use_expln: (bool) Use `expln()` function instead of `exp()` to ensure
|
||||
:param use_expln: (bool) Use ``expln()`` function instead of ``exp()`` to ensure
|
||||
a positive standard deviation (cf paper). It allows to keep variance
|
||||
above zero and prevent it from growing too fast. In practice, `exp()` is usually enough.
|
||||
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
|
||||
:param squash_output: (bool) Whether to squash the output using a tanh function,
|
||||
this allows to ensure boundaries.
|
||||
:param learn_features: (bool) Whether to learn features for SDE or not.
|
||||
This will enable gradients to be backpropagated through the features
|
||||
`latent_sde` in the code.
|
||||
``latent_sde`` in the code.
|
||||
:param epsilon: (float) small value to avoid NaN due to numerical imprecision.
|
||||
"""
|
||||
|
||||
|
|
@ -396,7 +408,7 @@ class StateDependentNoiseDistribution(Distribution):
|
|||
# entropy needs to be estimated using -log_prob.mean()
|
||||
if self.bijector is not None:
|
||||
return None
|
||||
return self.distribution.entropy()
|
||||
return sum_independent_dims(self.distribution.entropy())
|
||||
|
||||
def log_prob_from_params(self, mean_actions: th.Tensor,
|
||||
log_std: th.Tensor,
|
||||
|
|
@ -412,11 +424,8 @@ class StateDependentNoiseDistribution(Distribution):
|
|||
gaussian_action = action
|
||||
# log likelihood for a gaussian
|
||||
log_prob = self.distribution.log_prob(gaussian_action)
|
||||
|
||||
if len(log_prob.shape) > 1:
|
||||
log_prob = log_prob.sum(axis=1)
|
||||
else:
|
||||
log_prob = log_prob.sum()
|
||||
# Sum along action dim
|
||||
log_prob = sum_independent_dims(log_prob)
|
||||
|
||||
if self.bijector is not None:
|
||||
# Squash correction (from original SAC implementation)
|
||||
|
|
|
|||
|
|
@ -3,8 +3,8 @@ import typing
|
|||
from typing import Optional
|
||||
from copy import deepcopy
|
||||
|
||||
from torchy_baselines.common.vec_env.base_vec_env import AlreadySteppingError, NotSteppingError,\
|
||||
VecEnv, VecEnvWrapper, CloudpickleWrapper
|
||||
from torchy_baselines.common.vec_env.base_vec_env import (AlreadySteppingError, NotSteppingError,
|
||||
VecEnv, VecEnvWrapper, CloudpickleWrapper)
|
||||
from torchy_baselines.common.vec_env.dummy_vec_env import DummyVecEnv
|
||||
from torchy_baselines.common.vec_env.subproc_vec_env import SubprocVecEnv
|
||||
from torchy_baselines.common.vec_env.vec_frame_stack import VecFrameStack
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
0.2.5
|
||||
0.2.6
|
||||
|
|
|
|||
Loading…
Reference in a new issue