diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 02e832a..b8579fc 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -32,6 +32,8 @@ jobs: pip install .[extra,tests,docs] # Use headless version pip install opencv-python-headless + # Tmp fix: ROM missing in the newest atari-py version + pip install atari-py==0.2.5 - name: Build the doc run: | make doc diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 97dabf2..f91468e 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -4,7 +4,7 @@ Changelog ========== -Release 1.1.0a8 (WIP) +Release 1.1.0a9 (WIP) --------------------------- **Dict observation support, timeout handling and refactored HER** @@ -40,10 +40,12 @@ New Features: to handle gym3-style vectorized environments (@vwxyzjn) - Ignored the terminal observation if the it is not provided by the environment such as the gym3-style vectorized environments. (@vwxyzjn) -- Add policy_base as input to the OnPolicyAlgorithm for more flexibility (@09tangriro) +- Added policy_base as input to the OnPolicyAlgorithm for more flexibility (@09tangriro) - Added support for image observation when using ``HER`` - Added ``replay_buffer_class`` and ``replay_buffer_kwargs`` arguments to off-policy algorithms - Added experimental support to train off-policy algorithms with multiple envs (only SAC supported for now) +- Added ``kl_divergence`` helper for ``Distribution`` classes (@09tangriro) +- Added ``wrapper_kwargs`` argument to ``make_vec_env`` (@amy12xx) Bug Fixes: ^^^^^^^^^^ @@ -60,6 +62,7 @@ Others: - Updated ``env_checker`` to reflect support of dict observation spaces - Added Code of Conduct - Added tests for GAE and lambda return computation +- Updated distribution entropy test (thanks @09tangriro) Documentation: ^^^^^^^^^^^^^^ @@ -688,4 +691,4 @@ And all the contributors: @tirafesi @blurLake @koulakis @joeljosephjin @shwang @rk37 @andyshih12 @RaphaelWag @xicocaio @diditforlulz273 @liorcohen5 @ManifoldFR @mloo3 @SwamyDev @wmmc88 @megan-klaiber @thisray @tfederico @hn2 @LucasAlegre @AptX395 @zampanteymedio @JadenTravnik @decodyng @ardabbour @lorenz-h @mschweizer @lorepieri8 @vwxyzjn -@ShangqunYu @PierreExeter @JacopoPan @ltbd78 @tom-doerr @Atlis @liusida @09tangriro +@ShangqunYu @PierreExeter @JacopoPan @ltbd78 @tom-doerr @Atlis @liusida @09tangriro @amy12xx diff --git a/stable_baselines3/common/distributions.py b/stable_baselines3/common/distributions.py index ebe06fa..ca3f0b3 100644 --- a/stable_baselines3/common/distributions.py +++ b/stable_baselines3/common/distributions.py @@ -17,6 +17,7 @@ class Distribution(ABC): def __init__(self): super(Distribution, self).__init__() + self.distribution = None @abstractmethod def proba_distribution_net(self, *args, **kwargs) -> Union[nn.Module, Tuple[nn.Module, nn.Parameter]]: @@ -120,7 +121,6 @@ class DiagGaussianDistribution(Distribution): def __init__(self, action_dim: int): super(DiagGaussianDistribution, self).__init__() - self.distribution = None self.action_dim = action_dim self.mean_actions = None self.log_std = None @@ -255,7 +255,6 @@ class CategoricalDistribution(Distribution): def __init__(self, action_dim: int): super(CategoricalDistribution, self).__init__() - self.distribution = None self.action_dim = action_dim def proba_distribution_net(self, latent_dim: int) -> nn.Module: @@ -308,7 +307,6 @@ class MultiCategoricalDistribution(Distribution): def __init__(self, action_dims: List[int]): super(MultiCategoricalDistribution, self).__init__() self.action_dims = action_dims - self.distributions = None def proba_distribution_net(self, latent_dim: int) -> nn.Module: """ @@ -325,23 +323,23 @@ class MultiCategoricalDistribution(Distribution): return action_logits def proba_distribution(self, action_logits: th.Tensor) -> "MultiCategoricalDistribution": - self.distributions = [Categorical(logits=split) for split in th.split(action_logits, tuple(self.action_dims), dim=1)] + self.distribution = [Categorical(logits=split) for split in th.split(action_logits, tuple(self.action_dims), dim=1)] return self def log_prob(self, actions: th.Tensor) -> th.Tensor: # Extract each discrete action and compute log prob for their respective distributions return th.stack( - [dist.log_prob(action) for dist, action in zip(self.distributions, th.unbind(actions, dim=1))], dim=1 + [dist.log_prob(action) for dist, action in zip(self.distribution, th.unbind(actions, dim=1))], dim=1 ).sum(dim=1) def entropy(self) -> th.Tensor: - return th.stack([dist.entropy() for dist in self.distributions], dim=1).sum(dim=1) + return th.stack([dist.entropy() for dist in self.distribution], dim=1).sum(dim=1) def sample(self) -> th.Tensor: - return th.stack([dist.sample() for dist in self.distributions], dim=1) + return th.stack([dist.sample() for dist in self.distribution], dim=1) def mode(self) -> th.Tensor: - return th.stack([th.argmax(dist.probs, dim=1) for dist in self.distributions], dim=1) + return th.stack([th.argmax(dist.probs, dim=1) for dist in self.distribution], dim=1) def actions_from_params(self, action_logits: th.Tensor, deterministic: bool = False) -> th.Tensor: # Update the proba distribution @@ -363,7 +361,6 @@ class BernoulliDistribution(Distribution): def __init__(self, action_dims: int): super(BernoulliDistribution, self).__init__() - self.distribution = None self.action_dims = action_dims def proba_distribution_net(self, latent_dim: int) -> nn.Module: @@ -437,7 +434,6 @@ class StateDependentNoiseDistribution(Distribution): epsilon: float = 1e-6, ): super(StateDependentNoiseDistribution, self).__init__() - self.distribution = None self.action_dim = action_dim self.latent_sde_dim = None self.mean_actions = None @@ -676,3 +672,28 @@ def make_proba_distribution( f"of type {type(action_space)}." " Must be of type Gym Spaces: Box, Discrete, MultiDiscrete or MultiBinary." ) + + +def kl_divergence(dist_true: Distribution, dist_pred: Distribution) -> th.Tensor: + """ + Wrapper for the PyTorch implementation of the full form KL Divergence + + :param dist_true: the p distribution + :param dist_pred: the q distribution + :return: KL(dist_true||dist_pred) + """ + # KL Divergence for different distribution types is out of scope + assert dist_true.__class__ == dist_pred.__class__, "Error: input distributions should be the same type" + + # MultiCategoricalDistribution is not a PyTorch Distribution subclass + # so we need to implement it ourselves! + if isinstance(dist_pred, MultiCategoricalDistribution): + assert dist_pred.action_dims == dist_true.action_dims, "Error: distributions must have the same input space" + return th.stack( + [th.distributions.kl_divergence(p, q) for p, q in zip(dist_true.distribution, dist_pred.distribution)], + dim=1, + ).sum(dim=1) + + # Use the PyTorch kl_divergence implementation + else: + return th.distributions.kl_divergence(dist_true.distribution, dist_pred.distribution) diff --git a/stable_baselines3/common/env_util.py b/stable_baselines3/common/env_util.py index 177e744..520c50a 100644 --- a/stable_baselines3/common/env_util.py +++ b/stable_baselines3/common/env_util.py @@ -46,6 +46,7 @@ def make_vec_env( vec_env_cls: Optional[Type[Union[DummyVecEnv, SubprocVecEnv]]] = None, vec_env_kwargs: Optional[Dict[str, Any]] = None, monitor_kwargs: Optional[Dict[str, Any]] = None, + wrapper_kwargs: Optional[Dict[str, Any]] = None, ) -> VecEnv: """ Create a wrapped, monitored ``VecEnv``. @@ -65,11 +66,13 @@ def make_vec_env( :param vec_env_cls: A custom ``VecEnv`` class constructor. Default: None. :param vec_env_kwargs: Keyword arguments to pass to the ``VecEnv`` class constructor. :param monitor_kwargs: Keyword arguments to pass to the ``Monitor`` class constructor. + :param wrapper_kwargs: Keyword arguments to pass to the ``Wrapper`` class constructor. :return: The wrapped environment """ env_kwargs = {} if env_kwargs is None else env_kwargs vec_env_kwargs = {} if vec_env_kwargs is None else vec_env_kwargs monitor_kwargs = {} if monitor_kwargs is None else monitor_kwargs + wrapper_kwargs = {} if wrapper_kwargs is None else wrapper_kwargs def make_env(rank): def _init(): @@ -89,7 +92,7 @@ def make_vec_env( env = Monitor(env, filename=monitor_path, **monitor_kwargs) # Optionally, wrap the environment with the provided wrapper if wrapper_class is not None: - env = wrapper_class(env) + env = wrapper_class(env, **wrapper_kwargs) return env return _init diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 23a7fa8..1d497a0 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -1.1.0a8 +1.1.0a9 diff --git a/tests/test_distributions.py b/tests/test_distributions.py index 82c8405..db8e76c 100644 --- a/tests/test_distributions.py +++ b/tests/test_distributions.py @@ -1,3 +1,5 @@ +from copy import deepcopy + import pytest import torch as th @@ -10,6 +12,7 @@ from stable_baselines3.common.distributions import ( SquashedDiagGaussianDistribution, StateDependentNoiseDistribution, TanhBijector, + kl_divergence, ) from stable_baselines3.common.utils import set_random_seed @@ -77,13 +80,13 @@ def test_entropy(dist): # The entropy can be approximated by averaging the negative log likelihood # mean negative log likelihood == differential entropy set_random_seed(1) - state = th.rand(N_SAMPLES, N_FEATURES) - deterministic_actions = th.rand(N_SAMPLES, N_ACTIONS) + deterministic_actions = th.rand(1, N_ACTIONS).repeat(N_SAMPLES, 1) _, log_std = dist.proba_distribution_net(N_FEATURES, log_std_init=th.log(th.tensor(0.2))) if isinstance(dist, DiagGaussianDistribution): dist = dist.proba_distribution(deterministic_actions, log_std) else: + state = th.rand(1, N_FEATURES).repeat(N_SAMPLES, 1) dist.sample_weights(log_std, batch_size=N_SAMPLES) dist = dist.proba_distribution(deterministic_actions, log_std, state) @@ -111,3 +114,76 @@ def test_categorical(dist, CAT_ACTIONS): entropy = dist.entropy() log_prob = dist.log_prob(actions) assert th.allclose(entropy.mean(), -log_prob.mean(), rtol=5e-3) + + +@pytest.mark.parametrize( + "dist_type", + [ + BernoulliDistribution(N_ACTIONS).proba_distribution(th.rand(N_ACTIONS)), + CategoricalDistribution(N_ACTIONS).proba_distribution(th.rand(N_ACTIONS)), + DiagGaussianDistribution(N_ACTIONS).proba_distribution(th.rand(N_ACTIONS), th.rand(N_ACTIONS)), + MultiCategoricalDistribution([N_ACTIONS, N_ACTIONS]).proba_distribution(th.rand(1, sum([N_ACTIONS, N_ACTIONS]))), + SquashedDiagGaussianDistribution(N_ACTIONS).proba_distribution(th.rand(N_ACTIONS), th.rand(N_ACTIONS)), + StateDependentNoiseDistribution(N_ACTIONS).proba_distribution( + th.rand(N_ACTIONS), th.rand([N_ACTIONS, N_ACTIONS]), th.rand([N_ACTIONS, N_ACTIONS]) + ), + ], +) +def test_kl_divergence(dist_type): + set_random_seed(8) + # Test 1: same distribution should have KL Div = 0 + dist1 = dist_type + dist2 = dist_type + # PyTorch implementation of kl_divergence doesn't sum across dimensions + assert th.allclose(kl_divergence(dist1, dist2).sum(), th.tensor(0.0)) + + # Test 2: KL Div = E(Unbiased approx KL Div) + if isinstance(dist_type, CategoricalDistribution): + dist1 = dist_type.proba_distribution(th.rand(N_ACTIONS).repeat(N_SAMPLES, 1)) + # deepcopy needed to assign new memory to new distribution instance + dist2 = deepcopy(dist_type).proba_distribution(th.rand(N_ACTIONS).repeat(N_SAMPLES, 1)) + elif isinstance(dist_type, DiagGaussianDistribution) or isinstance(dist_type, SquashedDiagGaussianDistribution): + mean_actions1 = th.rand(1).repeat(N_SAMPLES, 1) + log_std1 = th.rand(1).repeat(N_SAMPLES, 1) + mean_actions2 = th.rand(1).repeat(N_SAMPLES, 1) + log_std2 = th.rand(1).repeat(N_SAMPLES, 1) + dist1 = dist_type.proba_distribution(mean_actions1, log_std1) + dist2 = deepcopy(dist_type).proba_distribution(mean_actions2, log_std2) + elif isinstance(dist_type, BernoulliDistribution): + dist1 = dist_type.proba_distribution(th.rand(1).repeat(N_SAMPLES, 1)) + dist2 = deepcopy(dist_type).proba_distribution(th.rand(1).repeat(N_SAMPLES, 1)) + elif isinstance(dist_type, MultiCategoricalDistribution): + dist1 = dist_type.proba_distribution(th.rand(1, sum([N_ACTIONS, N_ACTIONS])).repeat(N_SAMPLES, 1)) + dist2 = deepcopy(dist_type).proba_distribution(th.rand(1, sum([N_ACTIONS, N_ACTIONS])).repeat(N_SAMPLES, 1)) + elif isinstance(dist_type, StateDependentNoiseDistribution): + dist1 = StateDependentNoiseDistribution(1) + dist2 = deepcopy(dist1) + state = th.rand(1, N_FEATURES).repeat(N_SAMPLES, 1) + mean_actions1 = th.rand(1).repeat(N_SAMPLES, 1) + mean_actions2 = th.rand(1).repeat(N_SAMPLES, 1) + _, log_std = dist1.proba_distribution_net(N_FEATURES, log_std_init=th.log(th.tensor(0.2))) + dist1.sample_weights(log_std, batch_size=N_SAMPLES) + dist2.sample_weights(log_std, batch_size=N_SAMPLES) + dist1 = dist1.proba_distribution(mean_actions1, log_std, state) + dist2 = dist2.proba_distribution(mean_actions2, log_std, state) + + full_kl_div = kl_divergence(dist1, dist2).mean(dim=0) + actions = dist1.get_actions() + approx_kl_div = (dist1.log_prob(actions) - dist2.log_prob(actions)).mean(dim=0) + + assert th.allclose(full_kl_div, approx_kl_div, rtol=5e-2) + + # Test 3 Sanity test with easy Bernoulli distribution + if isinstance(dist_type, BernoulliDistribution): + dist1 = BernoulliDistribution(1).proba_distribution(th.tensor([0.3])) + dist2 = BernoulliDistribution(1).proba_distribution(th.tensor([0.65])) + + full_kl_div = kl_divergence(dist1, dist2) + + actions = th.tensor([0.0, 1.0]) + ad_hoc_kl = th.sum( + th.exp(dist1.distribution.log_prob(actions)) + * (dist1.distribution.log_prob(actions) - dist2.distribution.log_prob(actions)) + ) + + assert th.allclose(full_kl_div, ad_hoc_kl) diff --git a/tests/test_utils.py b/tests/test_utils.py index d9473f5..8aecc54 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -7,7 +7,7 @@ import pytest import torch as th from stable_baselines3 import A2C, PPO -from stable_baselines3.common.atari_wrappers import ClipRewardEnv +from stable_baselines3.common.atari_wrappers import ClipRewardEnv, MaxAndSkipEnv from stable_baselines3.common.env_util import is_wrapped, make_atari_env, make_vec_env, unwrap_wrapper from stable_baselines3.common.evaluation import evaluate_policy from stable_baselines3.common.monitor import Monitor @@ -70,6 +70,11 @@ def test_vec_env_kwargs(): assert env.get_attr("goal_velocity")[0] == 0.11 +def test_vec_env_wrapper_kwargs(): + env = make_vec_env("MountainCarContinuous-v0", n_envs=1, seed=0, wrapper_class=MaxAndSkipEnv, wrapper_kwargs={"skip": 3}) + assert env.get_attr("_skip")[0] == 3 + + def test_vec_env_monitor_kwargs(): env = make_vec_env("MountainCarContinuous-v0", n_envs=1, seed=0, monitor_kwargs={"allow_early_resets": False}) assert env.get_attr("allow_early_resets")[0] is False