Support for MultiBinary / MultiDiscrete spaces (#13)

* multicategorical dist and test

* fixed List annotation

* bernoulli dist and test

* added distributions to preprocessing (needs testing)

* fixed and tested distributions

* added changelog and fixed ppo policy

* minor fix

* dist fixes, added test_spaces

* clean up

* modified changelog

* additional fixes

* minor changelog mod

* hot encoding fix, flake8 clean up

* lint tests

* preprocessing fix

* fixed bernoulli bug

* removed commented prints

* Update changelog.rst

* included suggested modifications

* linting fix

* increased space dim

* Update doc and tests

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
Roland Gavrilescu 2020-05-18 13:42:13 +01:00 committed by GitHub
parent 15ff6d47ee
commit 91adefdb4b
GPG key ID: 4AEE18F83AFDEB23
17 changed files with 293 additions and 91 deletions


@ -162,8 +162,8 @@ All the following examples can be executed online using Google colab notebooks:
  | **Name** | **Recurrent** | `Box` | `Discrete` | `MultiDiscrete` | `MultiBinary` | **Multi Processing** |
  | ------------------- | ------------------ | ------------------ | ------------------ | ------------------- | ------------------ | --------------------------------- |
- | A2C | :x: | :heavy_check_mark: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: |
- | PPO | :x: | :heavy_check_mark: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: |
+ | A2C | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
+ | PPO | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
  | SAC | :x: | :heavy_check_mark: | :x: | :x: | :x: | :x: |
  | TD3 | :x: | :heavy_check_mark: | :x: | :x: | :x: | :x: |


@ -5,14 +5,14 @@ This table displays the rl algorithms that are implemented in the Stable Baselin
along with some useful characteristics: support for discrete/continuous actions, multiprocessing.
-============ =========== ============ ================
-Name         ``Box``     ``Discrete`` Multi Processing
-============ =========== ============ ================
-A2C          ✔️          ✔️           ✔️
-PPO          ✔️          ✔️           ✔️
-SAC          ✔️          ❌            ❌
-TD3          ✔️          ❌            ❌
-============ =========== ============ ================
+============ =========== ============ ================= =============== ================
+Name         ``Box``     ``Discrete`` ``MultiDiscrete`` ``MultiBinary`` Multi Processing
+============ =========== ============ ================= =============== ================
+A2C          ✔️          ✔️           ✔️                ✔️              ✔️
+PPO          ✔️          ✔️           ✔️                ✔️              ✔️
+SAC          ✔️          ❌            ❌                 ❌               ❌
+TD3          ✔️          ❌            ❌                 ❌               ❌
+============ =========== ============ ================= =============== ================
.. note::


@ -3,10 +3,9 @@
Changelog
==========
-Pre-Release 0.6.0a8 (WIP)
+Pre-Release 0.6.0a9 (WIP)
------------------------------
Breaking Changes:
^^^^^^^^^^^^^^^^^
- Remove State-Dependent Exploration (SDE) support for ``TD3``
@ -17,6 +16,8 @@ New Features:
- Added ``VecCheckNan`` and ``VecVideoRecorder`` (Sync with Stable Baselines)
- Added determinism tests
- Added ``cmd_utils`` and ``atari_wrappers``
- Added support for ``MultiDiscrete`` and ``MultiBinary`` observation spaces (@rolandgvc)
- Added ``MultiCategorical`` and ``Bernoulli`` distributions for PPO/A2C (@rolandgvc)
Bug Fixes:
^^^^^^^^^^
@ -227,4 +228,4 @@ And all the contributors:
@XMaster96 @kantneel @Pastafarianist @GerardMaggiolino @PatrickWalter214 @yutingsz @sc420 @Aaahh @billtubbs
@Miffyli @dwiel @miguelrass @qxcv @jaberkow @eavelardev @ruifeng96150 @pedrohbtp @srivatsankrishnan @evilsocket
@MarvineGothic @jdossgollin @SyllogismRXS @rusu24edward @jbulow @Antymon @seheevic @justinkterry @edbeeching
-@flodorner @KuKuXia @NeoExtended @solliet @mmcenta @richardwu @kinalmehta
+@flodorner @KuKuXia @NeoExtended @solliet @mmcenta @richardwu @kinalmehta @rolandgvc


@ -28,10 +28,10 @@ Can I use?
 ============= ====== ===========
 Space         Action Observation
 ============= ====== ===========
-Discrete      ❌      ❌
+Discrete      ✔️     ✔️
 Box           ✔️     ✔️
-MultiDiscrete ❌      ❌
-MultiBinary   ❌      ❌
+MultiDiscrete ✔️     ✔️
+MultiBinary   ✔️     ✔️
 ============= ====== ===========


@ -38,10 +38,10 @@ Can I use?
 ============= ====== ===========
 Space         Action Observation
 ============= ====== ===========
-Discrete      ❌      ❌
+Discrete      ✔️     ✔️
 Box           ✔️     ✔️
-MultiDiscrete ❌      ❌
-MultiBinary   ❌      ❌
+MultiDiscrete ✔️     ✔️
+MultiBinary   ✔️     ✔️
 ============= ====== ===========
Example


@ -58,10 +58,10 @@ Can I use?
 ============= ====== ===========
 Space         Action Observation
 ============= ====== ===========
-Discrete      ❌
+Discrete      ❌      ✔️
 Box           ✔️     ✔️
-MultiDiscrete ❌
-MultiBinary   ❌
+MultiDiscrete ❌      ✔️
+MultiBinary   ❌      ✔️
 ============= ====== ===========


@ -50,10 +50,10 @@ Can I use?
 ============= ====== ===========
 Space         Action Observation
 ============= ====== ===========
-Discrete      ❌
+Discrete      ❌      ✔️
 Box           ✔️     ✔️
-MultiDiscrete ❌
-MultiBinary   ❌
+MultiDiscrete ❌      ✔️
+MultiBinary   ❌      ✔️
 ============= ====== ===========


@ -20,6 +20,7 @@ class BaseBuffer(object):
to which the values will be converted
:param n_envs: (int) Number of parallel environments
"""
def __init__(self,
buffer_size: int,
observation_space: spaces.Space,


@ -1,9 +1,8 @@
-from typing import Optional, Tuple, Dict, Any
+from typing import Optional, Tuple, Dict, Any, List
 import gym
 import torch as th
 import torch.nn as nn
-from torch.distributions import Normal, Categorical
+from torch.distributions import Normal, Categorical, Bernoulli
from gym import spaces
from stable_baselines3.common.preprocessing import get_action_dim
@ -88,7 +87,7 @@ def sum_independent_dims(tensor: th.Tensor) -> th.Tensor:
     :return: (th.Tensor) shape: (n_batch,)
     """
     if len(tensor.shape) > 1:
-        tensor = tensor.sum(axis=1)
+        tensor = tensor.sum(dim=1)
     else:
         tensor = tensor.sum()
     return tensor
@ -292,6 +291,114 @@ class CategoricalDistribution(Distribution):
return self.distribution.log_prob(actions)
class MultiCategoricalDistribution(Distribution):
"""
MultiCategorical distribution for multi discrete actions.
:param action_dims: (List[int]) List of sizes of discrete action spaces
"""
def __init__(self, action_dims: List[int]):
super(MultiCategoricalDistribution, self).__init__()
self.action_dims = action_dims
self.distributions = None
def proba_distribution_net(self, latent_dim: int) -> nn.Module:
"""
Create the layer that represents the distribution:
it will be the logits (flattened) of the MultiCategorical distribution.
You can then get probabilities using a softmax on each sub-space.
:param latent_dim: (int) Dimension of the last layer
of the policy network (before the action layer)
:return: (nn.Linear)
"""
action_logits = nn.Linear(latent_dim, sum(self.action_dims))
return action_logits
def proba_distribution(self, action_logits: th.Tensor) -> 'MultiCategoricalDistribution':
self.distributions = [Categorical(logits=split) for split in th.split(action_logits, tuple(self.action_dims), dim=1)]
return self
def mode(self) -> th.Tensor:
return th.stack([th.argmax(dist.probs, dim=1) for dist in self.distributions], dim=1)
def sample(self) -> th.Tensor:
return th.stack([dist.sample() for dist in self.distributions], dim=1)
def entropy(self) -> th.Tensor:
return th.stack([dist.entropy() for dist in self.distributions], dim=1).sum(dim=1)
def actions_from_params(self, action_logits: th.Tensor,
deterministic: bool = False) -> th.Tensor:
# Update the proba distribution
self.proba_distribution(action_logits)
return self.get_actions(deterministic=deterministic)
def log_prob_from_params(self, action_logits: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
actions = self.actions_from_params(action_logits)
log_prob = self.log_prob(actions)
return actions, log_prob
def log_prob(self, actions: th.Tensor) -> th.Tensor:
# Extract each discrete action and compute log prob for their respective distributions
return th.stack([dist.log_prob(action) for dist, action in zip(self.distributions,
th.unbind(actions, dim=1))], dim=1).sum(dim=1)
class BernoulliDistribution(Distribution):
"""
Bernoulli distribution for MultiBinary action spaces.
:param action_dims: (int) Number of binary actions
"""
def __init__(self, action_dims: int):
super(BernoulliDistribution, self).__init__()
self.distribution = None
self.action_dims = action_dims
def proba_distribution_net(self, latent_dim: int) -> nn.Module:
"""
Create the layer that represents the distribution:
it will be the logits of the Bernoulli distribution.
:param latent_dim: (int) Dimension of the last layer
of the policy network (before the action layer)
:return: (nn.Linear)
"""
action_logits = nn.Linear(latent_dim, self.action_dims)
return action_logits
def proba_distribution(self, action_logits: th.Tensor) -> 'BernoulliDistribution':
self.distribution = Bernoulli(logits=action_logits)
return self
def mode(self) -> th.Tensor:
return th.round(self.distribution.probs)
def sample(self) -> th.Tensor:
return self.distribution.sample()
def entropy(self) -> th.Tensor:
return self.distribution.entropy().sum(dim=1)
def actions_from_params(self, action_logits: th.Tensor,
deterministic: bool = False) -> th.Tensor:
# Update the proba distribution
self.proba_distribution(action_logits)
return self.get_actions(deterministic=deterministic)
def log_prob_from_params(self, action_logits: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
actions = self.actions_from_params(action_logits)
log_prob = self.log_prob(actions)
return actions, log_prob
def log_prob(self, actions: th.Tensor) -> th.Tensor:
return self.distribution.log_prob(actions).sum(dim=1)
class StateDependentNoiseDistribution(Distribution):
"""
Distribution class for using generalized State Dependent Exploration (gSDE).
@ -551,10 +658,10 @@ def make_proba_distribution(action_space: gym.spaces.Space,
         return DiagGaussianDistribution(get_action_dim(action_space), **dist_kwargs)
     elif isinstance(action_space, spaces.Discrete):
         return CategoricalDistribution(action_space.n, **dist_kwargs)
-    # elif isinstance(action_space, spaces.MultiDiscrete):
-    #     return MultiCategoricalDistribution(action_space.nvec, **dist_kwargs)
-    # elif isinstance(action_space, spaces.MultiBinary):
-    #     return BernoulliDistribution(action_space.n, **dist_kwargs)
+    elif isinstance(action_space, spaces.MultiDiscrete):
+        return MultiCategoricalDistribution(action_space.nvec, **dist_kwargs)
+    elif isinstance(action_space, spaces.MultiBinary):
+        return BernoulliDistribution(action_space.n, **dist_kwargs)
     else:
         raise NotImplementedError("Error: probability distribution, not implemented for action space "
                                   f"of type {type(action_space)}."
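Editor's sketch, not part of the commit: the MultiCategorical logic above (one flat logits vector, split per sub-space, per-sub-space softmax, summed log-probabilities) can be illustrated in plain Python without torch. The helper names `split_logits`, `log_softmax`, `multi_categorical_log_prob` and `multi_categorical_mode` are hypothetical and only mirror what `th.split`, `Categorical.log_prob` and `th.argmax` do in the diff; a single sample is handled instead of a batch.

```python
import math
from typing import List, Sequence


def split_logits(flat_logits: Sequence[float], action_dims: Sequence[int]) -> List[List[float]]:
    """Cut one flat logits vector into one chunk per discrete sub-space."""
    chunks, start = [], 0
    for dim in action_dims:
        chunks.append(list(flat_logits[start:start + dim]))
        start += dim
    return chunks


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax over a single sub-space."""
    top = max(logits)
    log_norm = top + math.log(sum(math.exp(x - top) for x in logits))
    return [x - log_norm for x in logits]


def multi_categorical_log_prob(flat_logits: Sequence[float],
                               action_dims: Sequence[int],
                               actions: Sequence[int]) -> float:
    """Joint log-probability: sub-spaces are independent, so log-probs add up."""
    chunks = split_logits(flat_logits, action_dims)
    return sum(log_softmax(chunk)[action] for chunk, action in zip(chunks, actions))


def multi_categorical_mode(flat_logits: Sequence[float], action_dims: Sequence[int]) -> List[int]:
    """Most likely action: argmax inside each sub-space."""
    chunks = split_logits(flat_logits, action_dims)
    return [max(range(len(chunk)), key=chunk.__getitem__) for chunk in chunks]


# Two sub-spaces of sizes 2 and 3, hence 5 logits in total (assumed values).
flat_logits = [0.0, 0.0, 1.0, 2.0, 3.0]
action_dims = [2, 3]
mode = multi_categorical_mode(flat_logits, action_dims)
log_prob = multi_categorical_log_prob(flat_logits, action_dims, mode)
```

The output layer therefore needs `sum(action_dims)` units, which matches the `nn.Linear(latent_dim, sum(self.action_dims))` layer in the diff.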


@ -206,8 +206,8 @@ class BasePolicy(nn.Module):
         # Handle the different cases for images
         # as PyTorch uses channel-first format
         if is_image_space(self.observation_space):
-            if (observation.shape == self.observation_space.shape or
-                    observation.shape[1:] == self.observation_space.shape):
+            if (observation.shape == self.observation_space.shape
+                    or observation.shape[1:] == self.observation_space.shape):
                 pass
             else:
                 # Try to re-order the channels
@ -279,9 +279,9 @@ class BasePolicy(nn.Module):
             elif observation.shape[1:] == observation_space.shape:
                 return True
             else:
-                raise ValueError("Error: Unexpected observation shape {} for ".format(observation.shape) +
-                                 "Box environment, please use {} ".format(observation_space.shape) +
-                                 "or (n_env, {}) for the observation shape."
+                raise ValueError(f"Error: Unexpected observation shape {observation.shape} for "
+                                 + f"Box environment, please use {observation_space.shape} "
+                                 + "or (n_env, {}) for the observation shape."
                                  .format(", ".join(map(str, observation_space.shape))))
         elif isinstance(observation_space, gym.spaces.Discrete):
             if observation.shape == ():  # A numpy array of a number, has shape empty tuple '()'
@ -289,30 +289,30 @@ class BasePolicy(nn.Module):
             elif len(observation.shape) == 1:
                 return True
             else:
-                raise ValueError("Error: Unexpected observation shape {} for ".format(observation.shape) +
-                                 "Discrete environment, please use (1,) or (n_env, 1) for the observation shape.")
-        # TODO: add support for MultiDiscrete and MultiBinary observation spaces
-        # elif isinstance(observation_space, gym.spaces.MultiDiscrete):
-        #     if observation.shape == (len(observation_space.nvec),):
-        #         return False
-        #     elif len(observation.shape) == 2 and observation.shape[1] == len(observation_space.nvec):
-        #         return True
-        #     else:
-        #         raise ValueError("Error: Unexpected observation shape {} for MultiDiscrete ".format(observation.shape) +
-        #                          "environment, please use ({},) or ".format(len(observation_space.nvec)) +
-        #                          "(n_env, {}) for the observation shape.".format(len(observation_space.nvec)))
-        # elif isinstance(observation_space, gym.spaces.MultiBinary):
-        #     if observation.shape == (observation_space.n,):
-        #         return False
-        #     elif len(observation.shape) == 2 and observation.shape[1] == observation_space.n:
-        #         return True
-        #     else:
-        #         raise ValueError("Error: Unexpected observation shape {} for MultiBinary ".format(observation.shape) +
-        #                          "environment, please use ({},) or ".format(observation_space.n) +
-        #                          "(n_env, {}) for the observation shape.".format(observation_space.n))
+                raise ValueError(f"Error: Unexpected observation shape {observation.shape} for "
+                                 + "Discrete environment, please use (1,) or (n_env, 1) for the observation shape.")
+        elif isinstance(observation_space, gym.spaces.MultiDiscrete):
+            if observation.shape == (len(observation_space.nvec),):
+                return False
+            elif len(observation.shape) == 2 and observation.shape[1] == len(observation_space.nvec):
+                return True
+            else:
+                raise ValueError(f"Error: Unexpected observation shape {observation.shape} for MultiDiscrete "
+                                 + f"environment, please use ({len(observation_space.nvec)},) or "
+                                 + f"(n_env, {len(observation_space.nvec)}) for the observation shape.")
+        elif isinstance(observation_space, gym.spaces.MultiBinary):
+            if observation.shape == (observation_space.n,):
+                return False
+            elif len(observation.shape) == 2 and observation.shape[1] == observation_space.n:
+                return True
+            else:
+                raise ValueError(f"Error: Unexpected observation shape {observation.shape} for MultiBinary "
+                                 + f"environment, please use ({observation_space.n},) or "
+                                 + f"(n_env, {observation_space.n}) for the observation shape.")
         else:
-            raise ValueError("Error: Cannot determine if the observation is vectorized with the space type {}."
-                             .format(observation_space))
+            raise ValueError("Error: Cannot determine if the observation is vectorized "
+                             + f"with the space type {observation_space}.")
     def _get_data(self) -> Dict[str, Any]:
         """
@ -447,7 +447,7 @@ def get_policy_from_name(base_policy_type: Type[BasePolicy], name: str) -> Type[
         raise ValueError(f"Error: the policy type {base_policy_type} is not registered!")
     if name not in _policy_registry[base_policy_type]:
         raise ValueError(f"Error: unknown policy type {name},"
-                         "the only registed policy type are: {list(_policy_registry[base_policy_type].keys())}!")
+                         f"the only registered policy type are: {list(_policy_registry[base_policy_type].keys())}!")
     return _policy_registry[base_policy_type][name]
@ -460,14 +460,10 @@ def register_policy(name: str, policy: Type[BasePolicy]) -> None:
     :param policy: (Type[BasePolicy]) the policy class
     """
     sub_class = None
-    # For building the doc
-    try:
-        for cls in BasePolicy.__subclasses__():
-            if issubclass(policy, cls):
-                sub_class = cls
-                break
-    except AttributeError:
-        sub_class = str(th.random.randint(100))
+    for cls in BasePolicy.__subclasses__():
+        if issubclass(policy, cls):
+            sub_class = cls
+            break
     if sub_class is None:
         raise ValueError(f"Error: the policy {policy} is not of any known subclasses of BasePolicy!")
@ -511,7 +507,6 @@ class MlpExtractor(nn.Module):
device: Union[th.device, str] = 'auto'):
super(MlpExtractor, self).__init__()
device = get_device(device)
shared_net, policy_net, value_net = [], [], []
policy_only_layers = [] # Layer sizes of the network that only belongs to the policy network
value_only_layers = [] # Layer sizes of the network that only belongs to the value network
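Editor's sketch, not part of the commit: the MultiDiscrete and MultiBinary branches of the vectorization check in this file follow the same rule, which depends only on the observation shape and the number of features (`len(nvec)` or `n`). The function name `is_vectorized` and its arguments are hypothetical; shapes are plain tuples rather than numpy arrays.

```python
from typing import Tuple


def is_vectorized(obs_shape: Tuple[int, ...], n_features: int) -> bool:
    """Return False for a single observation, True for a batch of observations."""
    if obs_shape == (n_features,):
        # One observation: shape (n_features,)
        return False
    if len(obs_shape) == 2 and obs_shape[1] == n_features:
        # A batch coming from several environments: shape (n_env, n_features)
        return True
    raise ValueError(f"Unexpected observation shape {obs_shape}, "
                     f"expected ({n_features},) or (n_env, {n_features})")


single = is_vectorized((3,), 3)
batch = is_vectorized((5, 3), 3)
```

Any other shape is rejected with a `ValueError`, which is the behaviour the new `elif` branches implement for both space types.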


@ -1,9 +1,9 @@
from typing import Tuple
import numpy as np
import torch as th
import torch.nn.functional as F
from gym import spaces
import numpy as np
def is_image_space(observation_space: spaces.Space,
@ -62,11 +62,21 @@ def preprocess_obs(obs: th.Tensor, observation_space: spaces.Space,
if is_image_space(observation_space) and normalize_images:
return obs.float() / 255.0
return obs.float()
elif isinstance(observation_space, spaces.Discrete):
# One hot encoding and convert to float to avoid errors
return F.one_hot(obs.long(), num_classes=observation_space.n).float()
elif isinstance(observation_space, spaces.MultiDiscrete):
# Tensor concatenation of one hot encodings of each Categorical sub-space
return th.cat([F.one_hot(obs_.long(), num_classes=int(observation_space.nvec[idx])).float()
for idx, obs_ in enumerate(th.split(obs.long(), 1, dim=1))],
dim=-1).view(obs.shape[0], sum(observation_space.nvec))
elif isinstance(observation_space, spaces.MultiBinary):
return obs.float()
else:
# TODO: Multidiscrete, Binary, MultiBinary, Tuple, Dict
raise NotImplementedError()
@ -82,8 +92,13 @@ def get_obs_shape(observation_space: spaces.Space) -> Tuple[int, ...]:
elif isinstance(observation_space, spaces.Discrete):
# Observation is an int
return 1,
elif isinstance(observation_space, spaces.MultiDiscrete):
# Number of discrete features
return int(len(observation_space.nvec)),
elif isinstance(observation_space, spaces.MultiBinary):
# Number of binary features
return int(observation_space.n),
else:
# TODO: Multidiscrete, Binary, MultiBinary, Tuple, Dict
raise NotImplementedError()
@ -95,8 +110,13 @@ def get_flattened_obs_dim(observation_space: spaces.Space) -> int:
     :param observation_space: (spaces.Space)
     :return: (int)
     """
-    # Use Gym internal method
-    return spaces.utils.flatdim(observation_space)
+    # See issue https://github.com/openai/gym/issues/1915
+    # it may be a problem for Dict/Tuple spaces too...
+    if isinstance(observation_space, spaces.MultiDiscrete):
+        return sum(observation_space.nvec)
+    else:
+        # Use Gym internal method
+        return spaces.utils.flatdim(observation_space)
def get_action_dim(action_space: spaces.Space) -> int:
@ -111,5 +131,11 @@ def get_action_dim(action_space: spaces.Space) -> int:
elif isinstance(action_space, spaces.Discrete):
# Action is an int
return 1
elif isinstance(action_space, spaces.MultiDiscrete):
# Number of discrete actions
return int(len(action_space.nvec))
elif isinstance(action_space, spaces.MultiBinary):
# Number of binary actions
return int(action_space.n)
else:
raise NotImplementedError()
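Editor's sketch, not part of the commit: the MultiDiscrete preprocessing above one-hot encodes every sub-space and concatenates the results, which is why the flattened observation dimension is `sum(nvec)` rather than `len(nvec)`. The helpers `one_hot` and `encode_multidiscrete` are hypothetical plain-Python stand-ins for `F.one_hot` and `th.cat`, applied to a single observation.

```python
from typing import List, Sequence


def one_hot(index: int, num_classes: int) -> List[float]:
    """One-hot vector of length num_classes with a 1.0 at position index."""
    if not 0 <= index < num_classes:
        raise ValueError(f"index {index} is out of range for {num_classes} classes")
    return [1.0 if i == index else 0.0 for i in range(num_classes)]


def encode_multidiscrete(obs: Sequence[int], nvec: Sequence[int]) -> List[float]:
    """Concatenate the one-hot encoding of each discrete sub-space."""
    encoded: List[float] = []
    for value, num_classes in zip(obs, nvec):
        encoded.extend(one_hot(value, num_classes))
    return encoded


# A MultiDiscrete([4, 3]) observation (assumed values): 2 raw features, 7 encoded features.
nvec = [4, 3]
encoded = encode_multidiscrete([2, 0], nvec)
```

A MultiBinary observation needs no such step: it is already a vector of zeros and ones and is only cast to float.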


@ -11,6 +11,7 @@ from stable_baselines3.common.policies import (BasePolicy, register_policy, MlpE
BaseFeaturesExtractor, FlattenExtractor)
from stable_baselines3.common.distributions import (make_proba_distribution, Distribution,
DiagGaussianDistribution, CategoricalDistribution,
MultiCategoricalDistribution, BernoulliDistribution,
StateDependentNoiseDistribution)
@ -178,6 +179,10 @@ class PPOPolicy(BasePolicy):
log_std_init=self.log_std_init)
elif isinstance(self.action_dist, CategoricalDistribution):
self.action_net = self.action_dist.proba_distribution_net(latent_dim=latent_dim_pi)
elif isinstance(self.action_dist, MultiCategoricalDistribution):
self.action_net = self.action_dist.proba_distribution_net(latent_dim=latent_dim_pi)
elif isinstance(self.action_dist, BernoulliDistribution):
self.action_net = self.action_dist.proba_distribution_net(latent_dim=latent_dim_pi)
self.value_net = nn.Linear(self.mlp_extractor.latent_dim_vf, 1)
# Init weights: use orthogonal initialization
@ -226,6 +231,7 @@ class PPOPolicy(BasePolicy):
# Preprocess the observation if needed
features = self.extract_features(obs)
latent_pi, latent_vf = self.mlp_extractor(features)
# Features for sde
latent_sde = latent_pi
if self.sde_features_extractor is not None:
@ -245,11 +251,15 @@ class PPOPolicy(BasePolicy):
if isinstance(self.action_dist, DiagGaussianDistribution):
return self.action_dist.proba_distribution(mean_actions, self.log_std)
elif isinstance(self.action_dist, CategoricalDistribution):
# Here mean_actions are the logits before the softmax
return self.action_dist.proba_distribution(action_logits=mean_actions)
elif isinstance(self.action_dist, MultiCategoricalDistribution):
# Here mean_actions are the flattened logits
return self.action_dist.proba_distribution(action_logits=mean_actions)
elif isinstance(self.action_dist, BernoulliDistribution):
# Here mean_actions are the logits (before rounding to get the binary actions)
return self.action_dist.proba_distribution(action_logits=mean_actions)
elif isinstance(self.action_dist, StateDependentNoiseDistribution):
return self.action_dist.proba_distribution(mean_actions, self.log_std, latent_sde)
else:


@ -157,7 +157,6 @@ class PPO(BaseRLModel):
callback.on_rollout_start()
while n_steps < n_rollout_steps:
if self.use_sde and self.sde_sample_freq > 0 and n_steps % self.sde_sample_freq == 0:
# Sample a new noise matrix
self.policy.reset_noise(env.num_envs)
@ -213,7 +212,6 @@ class PPO(BaseRLModel):
approx_kl_divs = []
# Do a complete pass on the rollout buffer
for rollout_data in self.rollout_buffer.get(batch_size):
actions = rollout_data.actions
if isinstance(self.action_space, spaces.Discrete):
# Convert discrete action from float to long


@ -1 +1 @@
-0.6.0a8
+0.6.0a9


@ -4,7 +4,8 @@ import torch as th
 from stable_baselines3 import A2C, PPO
 from stable_baselines3.common.distributions import (DiagGaussianDistribution, TanhBijector,
                                                     StateDependentNoiseDistribution,
-                                                    CategoricalDistribution, SquashedDiagGaussianDistribution)
+                                                    CategoricalDistribution, SquashedDiagGaussianDistribution,
+                                                    MultiCategoricalDistribution, BernoulliDistribution)
 from stable_baselines3.common.utils import set_random_seed
@ -85,15 +86,21 @@ def test_entropy(dist):
     assert th.allclose(entropy.mean(), -log_prob.mean(), rtol=5e-3)
-def test_categorical():
+categorical_params = [
+    (CategoricalDistribution(N_ACTIONS), N_ACTIONS),
+    (MultiCategoricalDistribution([2, 3]), sum([2, 3])),
+    (BernoulliDistribution(N_ACTIONS), N_ACTIONS)
+]
+@pytest.mark.parametrize("dist, CAT_ACTIONS", categorical_params)
+def test_categorical(dist, CAT_ACTIONS):
     # The entropy can be approximated by averaging the negative log likelihood
     # mean negative log likelihood == entropy
-    dist = CategoricalDistribution(N_ACTIONS)
     set_random_seed(1)
-    action_logits = th.rand(N_SAMPLES, N_ACTIONS)
+    action_logits = th.rand(N_SAMPLES, CAT_ACTIONS)
     dist = dist.proba_distribution(action_logits)
     actions = dist.get_actions()
     entropy = dist.entropy()
     log_prob = dist.log_prob(actions)
-    assert th.allclose(entropy.mean(), -log_prob.mean(), rtol=2e-4)
+    assert th.allclose(entropy.mean(), -log_prob.mean(), rtol=5e-3)
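Editor's sketch, not part of the commit: the test above relies on the identity "mean negative log likelihood == entropy" when actions are sampled from the distribution itself. The plain-Python example below checks that identity for a single Bernoulli variable by Monte Carlo; the probability `p = 0.3`, the sample count and the helper names are assumptions chosen for illustration.

```python
import math
import random


def bernoulli_entropy(p: float) -> float:
    """Exact entropy (in nats) of a Bernoulli(p) variable, for 0 < p < 1."""
    return -(p * math.log(p) + (1.0 - p) * math.log(1.0 - p))


def mean_negative_log_likelihood(p: float, n_samples: int, seed: int = 0) -> float:
    """Monte Carlo estimate of E[-log p(x)] with x sampled from Bernoulli(p)."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_samples):
        sample = 1 if rng.random() < p else 0
        total -= math.log(p if sample == 1 else 1.0 - p)
    return total / n_samples


p = 0.3
exact = bernoulli_entropy(p)
estimate = mean_negative_log_likelihood(p, n_samples=20_000)
```

The estimate only matches the exact value up to sampling noise, which is presumably why the test compares the two with a relative tolerance instead of exact equality.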


@ -2,17 +2,27 @@ import numpy as np
import pytest
from stable_baselines3 import A2C, PPO, SAC, TD3
-from stable_baselines3.common.identity_env import IdentityEnvBox, IdentityEnv
+from stable_baselines3.common.identity_env import (IdentityEnvBox, IdentityEnv,
+                                                   IdentityEnvMultiBinary, IdentityEnvMultiDiscrete)
+from stable_baselines3.common.vec_env import DummyVecEnv
 from stable_baselines3.common.evaluation import evaluate_policy
 from stable_baselines3.common.noise import NormalActionNoise
+DIM = 4
 @pytest.mark.parametrize("model_class", [A2C, PPO])
-def test_discrete(model_class):
-    env = IdentityEnv(10)
-    model = model_class('MlpPolicy', env, gamma=0.5, seed=0).learn(3000)
+@pytest.mark.parametrize("env", [IdentityEnv(DIM), IdentityEnvMultiDiscrete(DIM), IdentityEnvMultiBinary(DIM)])
+def test_discrete(model_class, env):
+    env = DummyVecEnv([lambda: env])
+    model = model_class('MlpPolicy', env, gamma=0.5, seed=1).learn(3000)
     evaluate_policy(model, env, n_eval_episodes=20, reward_threshold=90)
     obs = env.reset()
     assert np.shape(model.predict(obs)[0]) == np.shape(obs)
@pytest.mark.parametrize("model_class", [A2C, PPO, SAC, TD3])

tests/test_spaces.py (new file, 47 lines)

@ -0,0 +1,47 @@
import numpy as np
import pytest
import gym
from stable_baselines3 import SAC, TD3
from stable_baselines3.common.evaluation import evaluate_policy
class DummyMultiDiscreteSpace(gym.Env):
def __init__(self, nvec):
super(DummyMultiDiscreteSpace, self).__init__()
self.observation_space = gym.spaces.MultiDiscrete(nvec)
self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
def reset(self):
return self.observation_space.sample()
def step(self, action):
return self.observation_space.sample(), 0.0, False, {}
class DummyMultiBinary(gym.Env):
def __init__(self, n):
super(DummyMultiBinary, self).__init__()
self.observation_space = gym.spaces.MultiBinary(n)
self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
def reset(self):
return self.observation_space.sample()
def step(self, action):
return self.observation_space.sample(), 0.0, False, {}
@pytest.mark.parametrize("model_class", [SAC, TD3])
@pytest.mark.parametrize("env", [DummyMultiDiscreteSpace([4, 3]), DummyMultiBinary(8)])
def test_identity_spaces(model_class, env):
"""
Additional tests for SAC/TD3 to check observation space support
for MultiDiscrete and MultiBinary.
"""
env = gym.wrappers.TimeLimit(env, max_episode_steps=100)
model = model_class("MlpPolicy", env, gamma=0.5, seed=1, policy_kwargs=dict(net_arch=[64]))
model.learn(total_timesteps=500)
evaluate_policy(model, env, n_eval_episodes=5)