From 91adefdb4b9aff2e0c5f2dfad47877a74bcc0de6 Mon Sep 17 00:00:00 2001 From: Roland Gavrilescu Date: Mon, 18 May 2020 13:42:13 +0100 Subject: [PATCH] Support for MultiBinary / MultiDiscrete spaces (#13) * multicategorical dist and test * fixed List annotation * bernoulli dist and test * added distributions to preprocessing (needs testing) * fixed and tested distributions * added changelog and fixed ppo policy * minor fix * dist fixes, added test_spaces * clean up * modified changelog * additional fixes * minor changelog mod * hot encoding fix, flake8 clean up * lint tests * preprocessing fix * fixed bernoulli bug * removed commented prints * Update changelog.rst * included suggested modifications * linting fix * increased space dim * Update doc and tests Co-authored-by: Antonin RAFFIN --- README.md | 4 +- docs/guide/algos.rst | 16 +-- docs/misc/changelog.rst | 7 +- docs/modules/a2c.rst | 6 +- docs/modules/ppo.rst | 6 +- docs/modules/sac.rst | 6 +- docs/modules/td3.rst | 6 +- stable_baselines3/common/buffers.py | 1 + stable_baselines3/common/distributions.py | 123 ++++++++++++++++++++-- stable_baselines3/common/policies.py | 71 ++++++------- stable_baselines3/common/preprocessing.py | 36 ++++++- stable_baselines3/ppo/policies.py | 14 ++- stable_baselines3/ppo/ppo.py | 2 - stable_baselines3/version.txt | 2 +- tests/test_distributions.py | 19 ++-- tests/test_identity.py | 18 +++- tests/test_spaces.py | 47 +++++++++ 17 files changed, 293 insertions(+), 91 deletions(-) create mode 100644 tests/test_spaces.py diff --git a/README.md b/README.md index 309d524..7031d9f 100644 --- a/README.md +++ b/README.md @@ -162,8 +162,8 @@ All the following examples can be executed online using Google colab notebooks: | **Name** | **Recurrent** | `Box` | `Discrete` | `MultiDiscrete` | `MultiBinary` | **Multi Processing** | | ------------------- | ------------------ | ------------------ | ------------------ | ------------------- | ------------------ | --------------------------------- | 
-| A2C | :x: | :heavy_check_mark: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: | -| PPO | :x: | :heavy_check_mark: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: | +| A2C | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | +| PPO | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | | SAC | :x: | :heavy_check_mark: | :x: | :x: | :x: | :x: | | TD3 | :x: | :heavy_check_mark: | :x: | :x: | :x: | :x: | diff --git a/docs/guide/algos.rst b/docs/guide/algos.rst index 34b97ec..94dc43b 100644 --- a/docs/guide/algos.rst +++ b/docs/guide/algos.rst @@ -5,14 +5,14 @@ This table displays the rl algorithms that are implemented in the Stable Baselin along with some useful characteristics: support for discrete/continuous actions, multiprocessing. -============ =========== ============ ================ -Name ``Box`` ``Discrete`` Multi Processing -============ =========== ============ ================ -A2C ✔️ ✔️ ✔️ -PPO ✔️ ✔️ ✔️ -SAC ✔️ ❌ ❌ -TD3 ✔️ ❌ ❌ -============ =========== ============ ================ +============ =========== ============ ================= =============== ================ +Name ``Box`` ``Discrete`` ``MultiDiscrete`` ``MultiBinary`` Multi Processing +============ =========== ============ ================= =============== ================ +A2C ✔️ ✔️ ✔️ ✔️ ✔️ +PPO ✔️ ✔️ ✔️ ✔️ ✔️ +SAC ✔️ ❌ ❌ ❌ ❌ +TD3 ✔️ ❌ ❌ ❌ ❌ +============ =========== ============ ================= =============== ================ .. 
note:: diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index e270f2e..cbdf36f 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,10 +3,9 @@ Changelog ========== -Pre-Release 0.6.0a8 (WIP) +Pre-Release 0.6.0a9 (WIP) ------------------------------ - Breaking Changes: ^^^^^^^^^^^^^^^^^ - Remove State-Dependent Exploration (SDE) support for ``TD3`` @@ -17,6 +16,8 @@ New Features: - Added ``VecCheckNan`` and ``VecVideoRecorder`` (Sync with Stable Baselines) - Added determinism tests - Added ``cmd_utils`` and ``atari_wrappers`` +- Added support for ``MultiDiscrete`` and ``MultiBinary`` observation spaces (@rolandgvc) +- Added ``MultiCategorical`` and ``Bernoulli`` distributions for PPO/A2C (@rolandgvc) Bug Fixes: ^^^^^^^^^^ @@ -227,4 +228,4 @@ And all the contributors: @XMaster96 @kantneel @Pastafarianist @GerardMaggiolino @PatrickWalter214 @yutingsz @sc420 @Aaahh @billtubbs @Miffyli @dwiel @miguelrass @qxcv @jaberkow @eavelardev @ruifeng96150 @pedrohbtp @srivatsankrishnan @evilsocket @MarvineGothic @jdossgollin @SyllogismRXS @rusu24edward @jbulow @Antymon @seheevic @justinkterry @edbeeching -@flodorner @KuKuXia @NeoExtended @solliet @mmcenta @richardwu @kinalmehta +@flodorner @KuKuXia @NeoExtended @solliet @mmcenta @richardwu @kinalmehta @rolandgvc diff --git a/docs/modules/a2c.rst b/docs/modules/a2c.rst index 38374f7..096778b 100644 --- a/docs/modules/a2c.rst +++ b/docs/modules/a2c.rst @@ -28,10 +28,10 @@ Can I use? ============= ====== =========== Space Action Observation ============= ====== =========== -Discrete ❌ ❌ +Discrete ✔️ ✔️ Box ✔️ ✔️ -MultiDiscrete ❌ ❌ -MultiBinary ❌ ❌ +MultiDiscrete ✔️ ✔️ +MultiBinary ✔️ ✔️ ============= ====== =========== diff --git a/docs/modules/ppo.rst b/docs/modules/ppo.rst index fb83c89..22fdf15 100644 --- a/docs/modules/ppo.rst +++ b/docs/modules/ppo.rst @@ -38,10 +38,10 @@ Can I use? 
============= ====== =========== Space Action Observation ============= ====== =========== -Discrete ❌ ❌ +Discrete ✔️ ✔️ Box ✔️ ✔️ -MultiDiscrete ❌ ❌ -MultiBinary ❌ ❌ +MultiDiscrete ✔️ ✔️ +MultiBinary ✔️ ✔️ ============= ====== =========== Example diff --git a/docs/modules/sac.rst b/docs/modules/sac.rst index 359df4b..4e77788 100644 --- a/docs/modules/sac.rst +++ b/docs/modules/sac.rst @@ -58,10 +58,10 @@ Can I use? ============= ====== =========== Space Action Observation ============= ====== =========== -Discrete ❌ ❌ +Discrete ❌ ✔️ Box ✔️ ✔️ -MultiDiscrete ❌ ❌ -MultiBinary ❌ ❌ +MultiDiscrete ❌ ✔️ +MultiBinary ❌ ✔️ ============= ====== =========== diff --git a/docs/modules/td3.rst b/docs/modules/td3.rst index 02ae391..86a939d 100644 --- a/docs/modules/td3.rst +++ b/docs/modules/td3.rst @@ -50,10 +50,10 @@ Can I use? ============= ====== =========== Space Action Observation ============= ====== =========== -Discrete ❌ ❌ +Discrete ❌ ✔️ Box ✔️ ✔️ -MultiDiscrete ❌ ❌ -MultiBinary ❌ ❌ +MultiDiscrete ❌ ✔️ +MultiBinary ❌ ✔️ ============= ====== =========== diff --git a/stable_baselines3/common/buffers.py b/stable_baselines3/common/buffers.py index 4fb4422..e88110f 100644 --- a/stable_baselines3/common/buffers.py +++ b/stable_baselines3/common/buffers.py @@ -20,6 +20,7 @@ class BaseBuffer(object): to which the values will be converted :param n_envs: (int) Number of parallel environments """ + def __init__(self, buffer_size: int, observation_space: spaces.Space, diff --git a/stable_baselines3/common/distributions.py b/stable_baselines3/common/distributions.py index 18af8bd..f9bb16c 100644 --- a/stable_baselines3/common/distributions.py +++ b/stable_baselines3/common/distributions.py @@ -1,9 +1,8 @@ -from typing import Optional, Tuple, Dict, Any - +from typing import Optional, Tuple, Dict, Any, List import gym import torch as th import torch.nn as nn -from torch.distributions import Normal, Categorical +from torch.distributions import Normal, Categorical, Bernoulli from gym 
import spaces from stable_baselines3.common.preprocessing import get_action_dim @@ -88,7 +87,7 @@ def sum_independent_dims(tensor: th.Tensor) -> th.Tensor: :return: (th.Tensor) shape: (n_batch,) """ if len(tensor.shape) > 1: - tensor = tensor.sum(axis=1) + tensor = tensor.sum(dim=1) else: tensor = tensor.sum() return tensor @@ -292,6 +291,114 @@ class CategoricalDistribution(Distribution): return self.distribution.log_prob(actions) +class MultiCategoricalDistribution(Distribution): + """ + MultiCategorical distribution for multi discrete actions. + + :param action_dims: (List[int]) List of sizes of discrete action spaces + """ + + def __init__(self, action_dims: List[int]): + super(MultiCategoricalDistribution, self).__init__() + self.action_dims = action_dims + self.distributions = None + + def proba_distribution_net(self, latent_dim: int) -> nn.Module: + """ + Create the layer that represents the distribution: + it will be the logits (flattened) of the MultiCategorical distribution. + You can then get probabilities using a softmax on each sub-space. 
+ + :param latent_dim: (int) Dimension of the last layer + of the policy network (before the action layer) + :return: (nn.Linear) + """ + + action_logits = nn.Linear(latent_dim, sum(self.action_dims)) + return action_logits + + def proba_distribution(self, action_logits: th.Tensor) -> 'MultiCategoricalDistribution': + self.distributions = [Categorical(logits=split) for split in th.split(action_logits, tuple(self.action_dims), dim=1)] + return self + + def mode(self) -> th.Tensor: + return th.stack([th.argmax(dist.probs, dim=1) for dist in self.distributions], dim=1) + + def sample(self) -> th.Tensor: + return th.stack([dist.sample() for dist in self.distributions], dim=1) + + def entropy(self) -> th.Tensor: + return th.stack([dist.entropy() for dist in self.distributions], dim=1).sum(dim=1) + + def actions_from_params(self, action_logits: th.Tensor, + deterministic: bool = False) -> th.Tensor: + # Update the proba distribution + self.proba_distribution(action_logits) + return self.get_actions(deterministic=deterministic) + + def log_prob_from_params(self, action_logits: th.Tensor) -> Tuple[th.Tensor, th.Tensor]: + actions = self.actions_from_params(action_logits) + log_prob = self.log_prob(actions) + return actions, log_prob + + def log_prob(self, actions: th.Tensor) -> th.Tensor: + # Extract each discrete action and compute log prob for their respective distributions + return th.stack([dist.log_prob(action) for dist, action in zip(self.distributions, + th.unbind(actions, dim=1))], dim=1).sum(dim=1) + + +class BernoulliDistribution(Distribution): + """ + Bernoulli distribution for MultiBinary action spaces. 
+ + :param action_dims: (int) Number of binary actions + """ + + def __init__(self, action_dims: int): + super(BernoulliDistribution, self).__init__() + self.distribution = None + self.action_dims = action_dims + + def proba_distribution_net(self, latent_dim: int) -> nn.Module: + """ + Create the layer that represents the distribution: + it will be the logits of the Bernoulli distribution. + + :param latent_dim: (int) Dimension of the last layer + of the policy network (before the action layer) + :return: (nn.Linear) + """ + action_logits = nn.Linear(latent_dim, self.action_dims) + return action_logits + + def proba_distribution(self, action_logits: th.Tensor) -> 'BernoulliDistribution': + self.distribution = Bernoulli(logits=action_logits) + return self + + def mode(self) -> th.Tensor: + return th.round(self.distribution.probs) + + def sample(self) -> th.Tensor: + return self.distribution.sample() + + def entropy(self) -> th.Tensor: + return self.distribution.entropy().sum(dim=1) + + def actions_from_params(self, action_logits: th.Tensor, + deterministic: bool = False) -> th.Tensor: + # Update the proba distribution + self.proba_distribution(action_logits) + return self.get_actions(deterministic=deterministic) + + def log_prob_from_params(self, action_logits: th.Tensor) -> Tuple[th.Tensor, th.Tensor]: + actions = self.actions_from_params(action_logits) + log_prob = self.log_prob(actions) + return actions, log_prob + + def log_prob(self, actions: th.Tensor) -> th.Tensor: + return self.distribution.log_prob(actions).sum(dim=1) + + class StateDependentNoiseDistribution(Distribution): """ Distribution class for using generalized State Dependent Exploration (gSDE). 
@@ -551,10 +658,10 @@ def make_proba_distribution(action_space: gym.spaces.Space, return DiagGaussianDistribution(get_action_dim(action_space), **dist_kwargs) elif isinstance(action_space, spaces.Discrete): return CategoricalDistribution(action_space.n, **dist_kwargs) - # elif isinstance(action_space, spaces.MultiDiscrete): - # return MultiCategoricalDistribution(action_space.nvec, **dist_kwargs) - # elif isinstance(action_space, spaces.MultiBinary): - # return BernoulliDistribution(action_space.n, **dist_kwargs) + elif isinstance(action_space, spaces.MultiDiscrete): + return MultiCategoricalDistribution(action_space.nvec, **dist_kwargs) + elif isinstance(action_space, spaces.MultiBinary): + return BernoulliDistribution(action_space.n, **dist_kwargs) else: raise NotImplementedError("Error: probability distribution, not implemented for action space" f"of type {type(action_space)}." diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py index 7ae79e2..187ac8b 100644 --- a/stable_baselines3/common/policies.py +++ b/stable_baselines3/common/policies.py @@ -206,8 +206,8 @@ class BasePolicy(nn.Module): # Handle the different cases for images # as PyTorch use channel first format if is_image_space(self.observation_space): - if (observation.shape == self.observation_space.shape or - observation.shape[1:] == self.observation_space.shape): + if (observation.shape == self.observation_space.shape + or observation.shape[1:] == self.observation_space.shape): pass else: # Try to re-order the channels @@ -279,9 +279,9 @@ class BasePolicy(nn.Module): elif observation.shape[1:] == observation_space.shape: return True else: - raise ValueError("Error: Unexpected observation shape {} for ".format(observation.shape) + - "Box environment, please use {} ".format(observation_space.shape) + - "or (n_env, {}) for the observation shape." 
+ raise ValueError(f"Error: Unexpected observation shape {observation.shape} for " + + f"Box environment, please use {observation_space.shape} " + + "or (n_env, {}) for the observation shape." .format(", ".join(map(str, observation_space.shape)))) elif isinstance(observation_space, gym.spaces.Discrete): if observation.shape == (): # A numpy array of a number, has shape empty tuple '()' @@ -289,30 +289,30 @@ class BasePolicy(nn.Module): elif len(observation.shape) == 1: return True else: - raise ValueError("Error: Unexpected observation shape {} for ".format(observation.shape) + - "Discrete environment, please use (1,) or (n_env, 1) for the observation shape.") - # TODO: add support for MultiDiscrete and MultiBinary observation spaces - # elif isinstance(observation_space, gym.spaces.MultiDiscrete): - # if observation.shape == (len(observation_space.nvec),): - # return False - # elif len(observation.shape) == 2 and observation.shape[1] == len(observation_space.nvec): - # return True - # else: - # raise ValueError("Error: Unexpected observation shape {} for MultiDiscrete ".format(observation.shape) + - # "environment, please use ({},) or ".format(len(observation_space.nvec)) + - # "(n_env, {}) for the observation shape.".format(len(observation_space.nvec))) - # elif isinstance(observation_space, gym.spaces.MultiBinary): - # if observation.shape == (observation_space.n,): - # return False - # elif len(observation.shape) == 2 and observation.shape[1] == observation_space.n: - # return True - # else: - # raise ValueError("Error: Unexpected observation shape {} for MultiBinary ".format(observation.shape) + - # "environment, please use ({},) or ".format(observation_space.n) + - # "(n_env, {}) for the observation shape.".format(observation_space.n)) + raise ValueError(f"Error: Unexpected observation shape {observation.shape} for " + + "Discrete environment, please use (1,) or (n_env, 1) for the observation shape.") + + elif isinstance(observation_space, 
gym.spaces.MultiDiscrete): + if observation.shape == (len(observation_space.nvec),): + return False + elif len(observation.shape) == 2 and observation.shape[1] == len(observation_space.nvec): + return True + else: + raise ValueError(f"Error: Unexpected observation shape {observation.shape} for MultiDiscrete " + + f"environment, please use ({len(observation_space.nvec)},) or " + + f"(n_env, {len(observation_space.nvec)}) for the observation shape.") + elif isinstance(observation_space, gym.spaces.MultiBinary): + if observation.shape == (observation_space.n,): + return False + elif len(observation.shape) == 2 and observation.shape[1] == observation_space.n: + return True + else: + raise ValueError(f"Error: Unexpected observation shape {observation.shape} for MultiBinary " + + f"environment, please use ({observation_space.n},) or " + + f"(n_env, {observation_space.n}) for the observation shape.") else: - raise ValueError("Error: Cannot determine if the observation is vectorized with the space type {}." 
- .format(observation_space)) + raise ValueError("Error: Cannot determine if the observation is vectorized " + + f" with the space type {observation_space}.") def _get_data(self) -> Dict[str, Any]: """ @@ -447,7 +447,7 @@ def get_policy_from_name(base_policy_type: Type[BasePolicy], name: str) -> Type[ raise ValueError(f"Error: the policy type {base_policy_type} is not registered!") if name not in _policy_registry[base_policy_type]: raise ValueError(f"Error: unknown policy type {name}," - "the only registed policy type are: {list(_policy_registry[base_policy_type].keys())}!") + f"the only registed policy type are: {list(_policy_registry[base_policy_type].keys())}!") return _policy_registry[base_policy_type][name] @@ -460,14 +460,10 @@ def register_policy(name: str, policy: Type[BasePolicy]) -> None: :param policy: (Type[BasePolicy]) the policy class """ sub_class = None - # For building the doc - try: - for cls in BasePolicy.__subclasses__(): - if issubclass(policy, cls): - sub_class = cls - break - except AttributeError: - sub_class = str(th.random.randint(100)) + for cls in BasePolicy.__subclasses__(): + if issubclass(policy, cls): + sub_class = cls + break if sub_class is None: raise ValueError(f"Error: the policy {policy} is not of any known subclasses of BasePolicy!") @@ -511,7 +507,6 @@ class MlpExtractor(nn.Module): device: Union[th.device, str] = 'auto'): super(MlpExtractor, self).__init__() device = get_device(device) - shared_net, policy_net, value_net = [], [], [] policy_only_layers = [] # Layer sizes of the network that only belongs to the policy network value_only_layers = [] # Layer sizes of the network that only belongs to the value network diff --git a/stable_baselines3/common/preprocessing.py b/stable_baselines3/common/preprocessing.py index 434355d..849756f 100644 --- a/stable_baselines3/common/preprocessing.py +++ b/stable_baselines3/common/preprocessing.py @@ -1,9 +1,9 @@ from typing import Tuple -import numpy as np import torch as th import 
torch.nn.functional as F from gym import spaces +import numpy as np def is_image_space(observation_space: spaces.Space, @@ -62,11 +62,21 @@ def preprocess_obs(obs: th.Tensor, observation_space: spaces.Space, if is_image_space(observation_space) and normalize_images: return obs.float() / 255.0 return obs.float() + elif isinstance(observation_space, spaces.Discrete): # One hot encoding and convert to float to avoid errors return F.one_hot(obs.long(), num_classes=observation_space.n).float() + + elif isinstance(observation_space, spaces.MultiDiscrete): + # Tensor concatenation of one hot encodings of each Categorical sub-space + return th.cat([F.one_hot(obs_.long(), num_classes=int(observation_space.nvec[idx])).float() + for idx, obs_ in enumerate(th.split(obs.long(), 1, dim=1))], + dim=-1).view(obs.shape[0], sum(observation_space.nvec)) + + elif isinstance(observation_space, spaces.MultiBinary): + return obs.float() + else: - # TODO: Multidiscrete, Binary, MultiBinary, Tuple, Dict raise NotImplementedError() @@ -82,8 +92,13 @@ def get_obs_shape(observation_space: spaces.Space) -> Tuple[int, ...]: elif isinstance(observation_space, spaces.Discrete): # Observation is an int return 1, + elif isinstance(observation_space, spaces.MultiDiscrete): + # Number of discrete features + return int(len(observation_space.nvec)), + elif isinstance(observation_space, spaces.MultiBinary): + # Number of binary features + return int(observation_space.n), else: - # TODO: Multidiscrete, Binary, MultiBinary, Tuple, Dict raise NotImplementedError() @@ -95,8 +110,13 @@ def get_flattened_obs_dim(observation_space: spaces.Space) -> int: :param observation_space: (spaces.Space) :return: (int) """ - # Use Gym internal method - return spaces.utils.flatdim(observation_space) + # See issue https://github.com/openai/gym/issues/1915 + # it may be a problem for Dict/Tuple spaces too... 
+ if isinstance(observation_space, spaces.MultiDiscrete): + return sum(observation_space.nvec) + else: + # Use Gym internal method + return spaces.utils.flatdim(observation_space) def get_action_dim(action_space: spaces.Space) -> int: @@ -111,5 +131,11 @@ def get_action_dim(action_space: spaces.Space) -> int: elif isinstance(action_space, spaces.Discrete): # Action is an int return 1 + elif isinstance(action_space, spaces.MultiDiscrete): + # Number of discrete actions + return int(len(action_space.nvec)) + elif isinstance(action_space, spaces.MultiBinary): + # Number of binary actions + return int(action_space.n) else: raise NotImplementedError() diff --git a/stable_baselines3/ppo/policies.py b/stable_baselines3/ppo/policies.py index 41ec5d3..79625cf 100644 --- a/stable_baselines3/ppo/policies.py +++ b/stable_baselines3/ppo/policies.py @@ -11,6 +11,7 @@ from stable_baselines3.common.policies import (BasePolicy, register_policy, MlpE BaseFeaturesExtractor, FlattenExtractor) from stable_baselines3.common.distributions import (make_proba_distribution, Distribution, DiagGaussianDistribution, CategoricalDistribution, + MultiCategoricalDistribution, BernoulliDistribution, StateDependentNoiseDistribution) @@ -178,6 +179,10 @@ class PPOPolicy(BasePolicy): log_std_init=self.log_std_init) elif isinstance(self.action_dist, CategoricalDistribution): self.action_net = self.action_dist.proba_distribution_net(latent_dim=latent_dim_pi) + elif isinstance(self.action_dist, MultiCategoricalDistribution): + self.action_net = self.action_dist.proba_distribution_net(latent_dim=latent_dim_pi) + elif isinstance(self.action_dist, BernoulliDistribution): + self.action_net = self.action_dist.proba_distribution_net(latent_dim=latent_dim_pi) self.value_net = nn.Linear(self.mlp_extractor.latent_dim_vf, 1) # Init weights: use orthogonal initialization @@ -226,6 +231,7 @@ class PPOPolicy(BasePolicy): # Preprocess the observation if needed features = self.extract_features(obs) latent_pi, latent_vf 
= self.mlp_extractor(features) + # Features for sde latent_sde = latent_pi if self.sde_features_extractor is not None: @@ -245,11 +251,15 @@ class PPOPolicy(BasePolicy): if isinstance(self.action_dist, DiagGaussianDistribution): return self.action_dist.proba_distribution(mean_actions, self.log_std) - elif isinstance(self.action_dist, CategoricalDistribution): # Here mean_actions are the logits before the softmax return self.action_dist.proba_distribution(action_logits=mean_actions) - + elif isinstance(self.action_dist, MultiCategoricalDistribution): + # Here mean_actions are the flattened logits + return self.action_dist.proba_distribution(action_logits=mean_actions) + elif isinstance(self.action_dist, BernoulliDistribution): + # Here mean_actions are the logits (before rounding to get the binary actions) + return self.action_dist.proba_distribution(action_logits=mean_actions) elif isinstance(self.action_dist, StateDependentNoiseDistribution): return self.action_dist.proba_distribution(mean_actions, self.log_std, latent_sde) else: diff --git a/stable_baselines3/ppo/ppo.py b/stable_baselines3/ppo/ppo.py index 2359865..c5ca153 100644 --- a/stable_baselines3/ppo/ppo.py +++ b/stable_baselines3/ppo/ppo.py @@ -157,7 +157,6 @@ class PPO(BaseRLModel): callback.on_rollout_start() while n_steps < n_rollout_steps: - if self.use_sde and self.sde_sample_freq > 0 and n_steps % self.sde_sample_freq == 0: # Sample a new noise matrix self.policy.reset_noise(env.num_envs) @@ -213,7 +212,6 @@ class PPO(BaseRLModel): approx_kl_divs = [] # Do a complete pass on the rollout buffer for rollout_data in self.rollout_buffer.get(batch_size): - actions = rollout_data.actions if isinstance(self.action_space, spaces.Discrete): # Convert discrete action from float to long diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index df3ddb5..21c9503 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -0.6.0a8 +0.6.0a9 diff --git 
a/tests/test_distributions.py b/tests/test_distributions.py index 150d55e..0461e17 100644 --- a/tests/test_distributions.py +++ b/tests/test_distributions.py @@ -4,7 +4,8 @@ import torch as th from stable_baselines3 import A2C, PPO from stable_baselines3.common.distributions import (DiagGaussianDistribution, TanhBijector, StateDependentNoiseDistribution, - CategoricalDistribution, SquashedDiagGaussianDistribution) + CategoricalDistribution, SquashedDiagGaussianDistribution, + MultiCategoricalDistribution, BernoulliDistribution) from stable_baselines3.common.utils import set_random_seed @@ -85,15 +86,21 @@ def test_entropy(dist): assert th.allclose(entropy.mean(), -log_prob.mean(), rtol=5e-3) -def test_categorical(): +categorical_params = [ + (CategoricalDistribution(N_ACTIONS), N_ACTIONS), + (MultiCategoricalDistribution([2, 3]), sum([2, 3])), + (BernoulliDistribution(N_ACTIONS), N_ACTIONS) +] + + +@pytest.mark.parametrize("dist, CAT_ACTIONS", categorical_params) +def test_categorical(dist, CAT_ACTIONS): # The entropy can be approximated by averaging the negative log likelihood # mean negative log likelihood == entropy - dist = CategoricalDistribution(N_ACTIONS) set_random_seed(1) - action_logits = th.rand(N_SAMPLES, N_ACTIONS) + action_logits = th.rand(N_SAMPLES, CAT_ACTIONS) dist = dist.proba_distribution(action_logits) - actions = dist.get_actions() entropy = dist.entropy() log_prob = dist.log_prob(actions) - assert th.allclose(entropy.mean(), -log_prob.mean(), rtol=2e-4) + assert th.allclose(entropy.mean(), -log_prob.mean(), rtol=5e-3) diff --git a/tests/test_identity.py b/tests/test_identity.py index d937c7e..b41b70c 100644 --- a/tests/test_identity.py +++ b/tests/test_identity.py @@ -2,17 +2,27 @@ import numpy as np import pytest from stable_baselines3 import A2C, PPO, SAC, TD3 -from stable_baselines3.common.identity_env import IdentityEnvBox, IdentityEnv +from stable_baselines3.common.identity_env import (IdentityEnvBox, IdentityEnv, + 
IdentityEnvMultiBinary, IdentityEnvMultiDiscrete) + +from stable_baselines3.common.vec_env import DummyVecEnv from stable_baselines3.common.evaluation import evaluate_policy from stable_baselines3.common.noise import NormalActionNoise +DIM = 4 + + @pytest.mark.parametrize("model_class", [A2C, PPO]) -def test_discrete(model_class): - env = IdentityEnv(10) - model = model_class('MlpPolicy', env, gamma=0.5, seed=0).learn(3000) +@pytest.mark.parametrize("env", [IdentityEnv(DIM), IdentityEnvMultiDiscrete(DIM), IdentityEnvMultiBinary(DIM)]) +def test_discrete(model_class, env): + env = DummyVecEnv([lambda: env]) + model = model_class('MlpPolicy', env, gamma=0.5, seed=1).learn(3000) evaluate_policy(model, env, n_eval_episodes=20, reward_threshold=90) + obs = env.reset() + + assert np.shape(model.predict(obs)[0]) == np.shape(obs) @pytest.mark.parametrize("model_class", [A2C, PPO, SAC, TD3]) diff --git a/tests/test_spaces.py b/tests/test_spaces.py new file mode 100644 index 0000000..dfd4a60 --- /dev/null +++ b/tests/test_spaces.py @@ -0,0 +1,47 @@ +import numpy as np +import pytest +import gym + +from stable_baselines3 import SAC, TD3 +from stable_baselines3.common.evaluation import evaluate_policy + + +class DummyMultiDiscreteSpace(gym.Env): + def __init__(self, nvec): + super(DummyMultiDiscreteSpace, self).__init__() + self.observation_space = gym.spaces.MultiDiscrete(nvec) + self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32) + + def reset(self): + return self.observation_space.sample() + + def step(self, action): + return self.observation_space.sample(), 0.0, False, {} + + +class DummyMultiBinary(gym.Env): + def __init__(self, n): + super(DummyMultiBinary, self).__init__() + self.observation_space = gym.spaces.MultiBinary(n) + self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32) + + def reset(self): + return self.observation_space.sample() + + def step(self, action): + return self.observation_space.sample(), 0.0, 
False, {} + + +@pytest.mark.parametrize("model_class", [SAC, TD3]) +@pytest.mark.parametrize("env", [DummyMultiDiscreteSpace([4, 3]), DummyMultiBinary(8)]) +def test_identity_spaces(model_class, env): + """ + Additional tests for SAC/TD3 to check observation space support + for MultiDiscrete and MultiBinary. + """ + env = gym.wrappers.TimeLimit(env, max_episode_steps=100) + + model = model_class("MlpPolicy", env, gamma=0.5, seed=1, policy_kwargs=dict(net_arch=[64])) + model.learn(total_timesteps=500) + + evaluate_policy(model, env, n_eval_episodes=5)