mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-07-03 03:59:13 +00:00
Support for MultiBinary / MultiDiscrete spaces (#13)
* multicategorical dist and test * fixed List annotation * bernoulli dist and test * added distributions to preprocessing (needs testing) * fixed and tested distributions * added changelog and fixed ppo policy * minor fix * dist fixes, added test_spaces * clean up * modified changelog * additional fixes * minor changelog mod * hot encoding fix, flake8 clean up * lint tests * preprocessing fix * fixed bernoulli bug * removed commented prints * Update changelog.rst * included suggested modifications * linting fix * increased space dim * Update doc and tests Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
This commit is contained in:
parent
15ff6d47ee
commit
91adefdb4b
17 changed files with 293 additions and 91 deletions
|
|
@ -162,8 +162,8 @@ All the following examples can be executed online using Google colab notebooks:
|
|||
|
||||
| **Name** | **Recurrent** | `Box` | `Discrete` | `MultiDiscrete` | `MultiBinary` | **Multi Processing** |
|
||||
| ------------------- | ------------------ | ------------------ | ------------------ | ------------------- | ------------------ | --------------------------------- |
|
||||
| A2C | :x: | :heavy_check_mark: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: |
|
||||
| PPO | :x: | :heavy_check_mark: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: |
|
||||
| A2C | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
|
||||
| PPO | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
|
||||
| SAC | :x: | :heavy_check_mark: | :x: | :x: | :x: | :x: |
|
||||
| TD3 | :x: | :heavy_check_mark: | :x: | :x: | :x: | :x: |
|
||||
|
||||
|
|
|
|||
|
|
@ -5,14 +5,14 @@ This table displays the rl algorithms that are implemented in the Stable Baselin
|
|||
along with some useful characteristics: support for discrete/continuous actions, multiprocessing.
|
||||
|
||||
|
||||
============ =========== ============ ================
|
||||
Name ``Box`` ``Discrete`` Multi Processing
|
||||
============ =========== ============ ================
|
||||
A2C ✔️ ✔️ ✔️
|
||||
PPO ✔️ ✔️ ✔️
|
||||
SAC ✔️ ❌ ❌
|
||||
TD3 ✔️ ❌ ❌
|
||||
============ =========== ============ ================
|
||||
============ =========== ============ ================= =============== ================
|
||||
Name ``Box`` ``Discrete`` ``MultiDiscrete`` ``MultiBinary`` Multi Processing
|
||||
============ =========== ============ ================= =============== ================
|
||||
A2C ✔️ ✔️ ✔️ ✔️ ✔️
|
||||
PPO ✔️ ✔️ ✔️ ✔️ ✔️
|
||||
SAC ✔️ ❌ ❌ ❌ ❌
|
||||
TD3 ✔️ ❌ ❌ ❌ ❌
|
||||
============ =========== ============ ================= =============== ================
|
||||
|
||||
|
||||
.. note::
|
||||
|
|
|
|||
|
|
@ -3,10 +3,9 @@
|
|||
Changelog
|
||||
==========
|
||||
|
||||
Pre-Release 0.6.0a8 (WIP)
|
||||
Pre-Release 0.6.0a9 (WIP)
|
||||
------------------------------
|
||||
|
||||
|
||||
Breaking Changes:
|
||||
^^^^^^^^^^^^^^^^^
|
||||
- Remove State-Dependent Exploration (SDE) support for ``TD3``
|
||||
|
|
@ -17,6 +16,8 @@ New Features:
|
|||
- Added ``VecCheckNan`` and ``VecVideoRecorder`` (Sync with Stable Baselines)
|
||||
- Added determinism tests
|
||||
- Added ``cmd_utils`` and ``atari_wrappers``
|
||||
- Added support for ``MultiDiscrete`` and ``MultiBinary`` observation spaces (@rolandgvc)
|
||||
- Added ``MultiCategorical`` and ``Bernoulli`` distributions for PPO/A2C (@rolandgvc)
|
||||
|
||||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
|
|
@ -227,4 +228,4 @@ And all the contributors:
|
|||
@XMaster96 @kantneel @Pastafarianist @GerardMaggiolino @PatrickWalter214 @yutingsz @sc420 @Aaahh @billtubbs
|
||||
@Miffyli @dwiel @miguelrass @qxcv @jaberkow @eavelardev @ruifeng96150 @pedrohbtp @srivatsankrishnan @evilsocket
|
||||
@MarvineGothic @jdossgollin @SyllogismRXS @rusu24edward @jbulow @Antymon @seheevic @justinkterry @edbeeching
|
||||
@flodorner @KuKuXia @NeoExtended @solliet @mmcenta @richardwu @kinalmehta
|
||||
@flodorner @KuKuXia @NeoExtended @solliet @mmcenta @richardwu @kinalmehta @rolandgvc
|
||||
|
|
|
|||
|
|
@ -28,10 +28,10 @@ Can I use?
|
|||
============= ====== ===========
|
||||
Space Action Observation
|
||||
============= ====== ===========
|
||||
Discrete ❌ ❌
|
||||
Discrete ✔️ ✔️
|
||||
Box ✔️ ✔️
|
||||
MultiDiscrete ❌ ❌
|
||||
MultiBinary ❌ ❌
|
||||
MultiDiscrete ✔️ ✔️
|
||||
MultiBinary ✔️ ✔️
|
||||
============= ====== ===========
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -38,10 +38,10 @@ Can I use?
|
|||
============= ====== ===========
|
||||
Space Action Observation
|
||||
============= ====== ===========
|
||||
Discrete ❌ ❌
|
||||
Discrete ✔️ ✔️
|
||||
Box ✔️ ✔️
|
||||
MultiDiscrete ❌ ❌
|
||||
MultiBinary ❌ ❌
|
||||
MultiDiscrete ✔️ ✔️
|
||||
MultiBinary ✔️ ✔️
|
||||
============= ====== ===========
|
||||
|
||||
Example
|
||||
|
|
|
|||
|
|
@ -58,10 +58,10 @@ Can I use?
|
|||
============= ====== ===========
|
||||
Space Action Observation
|
||||
============= ====== ===========
|
||||
Discrete ❌ ❌
|
||||
Discrete ❌ ✔️
|
||||
Box ✔️ ✔️
|
||||
MultiDiscrete ❌ ❌
|
||||
MultiBinary ❌ ❌
|
||||
MultiDiscrete ❌ ✔️
|
||||
MultiBinary ❌ ✔️
|
||||
============= ====== ===========
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -50,10 +50,10 @@ Can I use?
|
|||
============= ====== ===========
|
||||
Space Action Observation
|
||||
============= ====== ===========
|
||||
Discrete ❌ ❌
|
||||
Discrete ❌ ✔️
|
||||
Box ✔️ ✔️
|
||||
MultiDiscrete ❌ ❌
|
||||
MultiBinary ❌ ❌
|
||||
MultiDiscrete ❌ ✔️
|
||||
MultiBinary ❌ ✔️
|
||||
============= ====== ===========
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ class BaseBuffer(object):
|
|||
to which the values will be converted
|
||||
:param n_envs: (int) Number of parallel environments
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
buffer_size: int,
|
||||
observation_space: spaces.Space,
|
||||
|
|
|
|||
|
|
@ -1,9 +1,8 @@
|
|||
from typing import Optional, Tuple, Dict, Any
|
||||
|
||||
from typing import Optional, Tuple, Dict, Any, List
|
||||
import gym
|
||||
import torch as th
|
||||
import torch.nn as nn
|
||||
from torch.distributions import Normal, Categorical
|
||||
from torch.distributions import Normal, Categorical, Bernoulli
|
||||
from gym import spaces
|
||||
|
||||
from stable_baselines3.common.preprocessing import get_action_dim
|
||||
|
|
@ -88,7 +87,7 @@ def sum_independent_dims(tensor: th.Tensor) -> th.Tensor:
|
|||
:return: (th.Tensor) shape: (n_batch,)
|
||||
"""
|
||||
if len(tensor.shape) > 1:
|
||||
tensor = tensor.sum(axis=1)
|
||||
tensor = tensor.sum(dim=1)
|
||||
else:
|
||||
tensor = tensor.sum()
|
||||
return tensor
|
||||
|
|
@ -292,6 +291,114 @@ class CategoricalDistribution(Distribution):
|
|||
return self.distribution.log_prob(actions)
|
||||
|
||||
|
||||
class MultiCategoricalDistribution(Distribution):
    """
    Joint distribution over several independent categorical sub-actions,
    one ``Categorical`` per entry of ``action_dims`` (multi discrete actions).

    The network outputs a single flat vector of logits which is cut into
    consecutive chunks, one chunk per sub-action.

    :param action_dims: (List[int]) List of sizes of discrete action spaces
    """

    def __init__(self, action_dims: List[int]):
        super().__init__()
        # One ``Categorical`` per sub-action, created by ``proba_distribution``
        self.distributions = None
        self.action_dims = action_dims

    def proba_distribution_net(self, latent_dim: int) -> nn.Module:
        """
        Build the output layer of the policy: a linear layer producing
        the concatenated (flattened) logits of every sub-action.
        A softmax over each chunk yields the per-sub-action probabilities.

        :param latent_dim: (int) Dimension of the last layer
            of the policy network (before the action layer)
        :return: (nn.Linear)
        """
        # Total number of logits = sum of the sizes of the sub-spaces
        n_logits = sum(self.action_dims)
        return nn.Linear(latent_dim, n_logits)

    def proba_distribution(self, action_logits: th.Tensor) -> 'MultiCategoricalDistribution':
        """
        Update the distribution from flattened logits.

        :param action_logits: (th.Tensor) logits, split along dim 1
            into chunks of sizes ``action_dims``
        :return: (MultiCategoricalDistribution) self
        """
        logits_per_dim = th.split(action_logits, tuple(self.action_dims), dim=1)
        self.distributions = [Categorical(logits=logits) for logits in logits_per_dim]
        return self

    def mode(self) -> th.Tensor:
        """
        :return: (th.Tensor) most likely index of every sub-action, stacked along dim 1
        """
        best_per_dim = [th.argmax(dist.probs, dim=1) for dist in self.distributions]
        return th.stack(best_per_dim, dim=1)

    def sample(self) -> th.Tensor:
        """
        :return: (th.Tensor) one sample from every sub-distribution, stacked along dim 1
        """
        drawn_per_dim = [dist.sample() for dist in self.distributions]
        return th.stack(drawn_per_dim, dim=1)

    def entropy(self) -> th.Tensor:
        """
        :return: (th.Tensor) sum of the entropies of the sub-distributions
        """
        entropy_per_dim = th.stack([dist.entropy() for dist in self.distributions], dim=1)
        return entropy_per_dim.sum(dim=1)

    def actions_from_params(self, action_logits: th.Tensor,
                            deterministic: bool = False) -> th.Tensor:
        """
        Refresh the distribution with new logits, then return actions from it.

        :param action_logits: (th.Tensor) flattened logits
        :param deterministic: (bool) forwarded to ``get_actions``
        :return: (th.Tensor) actions
        """
        # The distribution must be rebuilt before actions can be drawn
        self.proba_distribution(action_logits)
        return self.get_actions(deterministic=deterministic)

    def log_prob_from_params(self, action_logits: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
        """
        Draw actions for the given logits and compute their log-probability.

        :param action_logits: (th.Tensor) flattened logits
        :return: (Tuple[th.Tensor, th.Tensor]) actions and their log-probability
        """
        drawn_actions = self.actions_from_params(action_logits)
        return drawn_actions, self.log_prob(drawn_actions)

    def log_prob(self, actions: th.Tensor) -> th.Tensor:
        """
        :param actions: (th.Tensor) actions, one column (dim 1) per sub-action
        :return: (th.Tensor) sum over the sub-actions of the individual log-probabilities
        """
        # Pair every column of ``actions`` with the distribution of its sub-action
        actions_per_dim = th.unbind(actions, dim=1)
        log_prob_per_dim = [dist.log_prob(action)
                            for dist, action in zip(self.distributions, actions_per_dim)]
        return th.stack(log_prob_per_dim, dim=1).sum(dim=1)
|
||||
|
||||
|
||||
class BernoulliDistribution(Distribution):
    """
    Independent Bernoulli distributions, one per binary action
    (used for MultiBinary action spaces).

    :param action_dims: (int) Number of binary actions
    """

    def __init__(self, action_dims: int):
        super().__init__()
        self.action_dims = action_dims
        # Underlying ``Bernoulli``, created by ``proba_distribution``
        self.distribution = None

    def proba_distribution_net(self, latent_dim: int) -> nn.Module:
        """
        Build the output layer of the policy: a linear layer
        producing one Bernoulli logit per binary action.

        :param latent_dim: (int) Dimension of the last layer
            of the policy network (before the action layer)
        :return: (nn.Linear)
        """
        return nn.Linear(latent_dim, self.action_dims)

    def proba_distribution(self, action_logits: th.Tensor) -> 'BernoulliDistribution':
        """
        Update the distribution from new logits.

        :param action_logits: (th.Tensor) logits of the binary actions
        :return: (BernoulliDistribution) self
        """
        self.distribution = Bernoulli(logits=action_logits)
        return self

    def mode(self) -> th.Tensor:
        """
        :return: (th.Tensor) the probabilities rounded with ``th.round``
        """
        probabilities = self.distribution.probs
        return th.round(probabilities)

    def sample(self) -> th.Tensor:
        """
        :return: (th.Tensor) a sample drawn from the underlying ``Bernoulli``
        """
        return self.distribution.sample()

    def entropy(self) -> th.Tensor:
        """
        :return: (th.Tensor) entropy summed over dim 1 (the binary actions)
        """
        per_action_entropy = self.distribution.entropy()
        return per_action_entropy.sum(dim=1)

    def actions_from_params(self, action_logits: th.Tensor,
                            deterministic: bool = False) -> th.Tensor:
        """
        Refresh the distribution with new logits, then return actions from it.

        :param action_logits: (th.Tensor) logits of the binary actions
        :param deterministic: (bool) forwarded to ``get_actions``
        :return: (th.Tensor) actions
        """
        # The distribution must be rebuilt before actions can be drawn
        self.proba_distribution(action_logits)
        return self.get_actions(deterministic=deterministic)

    def log_prob_from_params(self, action_logits: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
        """
        Draw actions for the given logits and compute their log-probability.

        :param action_logits: (th.Tensor) logits of the binary actions
        :return: (Tuple[th.Tensor, th.Tensor]) actions and their log-probability
        """
        drawn_actions = self.actions_from_params(action_logits)
        return drawn_actions, self.log_prob(drawn_actions)

    def log_prob(self, actions: th.Tensor) -> th.Tensor:
        """
        :param actions: (th.Tensor) binary actions
        :return: (th.Tensor) log-probability summed over dim 1 (the binary actions)
        """
        per_action_log_prob = self.distribution.log_prob(actions)
        return per_action_log_prob.sum(dim=1)
|
||||
|
||||
|
||||
class StateDependentNoiseDistribution(Distribution):
|
||||
"""
|
||||
Distribution class for using generalized State Dependent Exploration (gSDE).
|
||||
|
|
@ -551,10 +658,10 @@ def make_proba_distribution(action_space: gym.spaces.Space,
|
|||
return DiagGaussianDistribution(get_action_dim(action_space), **dist_kwargs)
|
||||
elif isinstance(action_space, spaces.Discrete):
|
||||
return CategoricalDistribution(action_space.n, **dist_kwargs)
|
||||
# elif isinstance(action_space, spaces.MultiDiscrete):
|
||||
# return MultiCategoricalDistribution(action_space.nvec, **dist_kwargs)
|
||||
# elif isinstance(action_space, spaces.MultiBinary):
|
||||
# return BernoulliDistribution(action_space.n, **dist_kwargs)
|
||||
elif isinstance(action_space, spaces.MultiDiscrete):
|
||||
return MultiCategoricalDistribution(action_space.nvec, **dist_kwargs)
|
||||
elif isinstance(action_space, spaces.MultiBinary):
|
||||
return BernoulliDistribution(action_space.n, **dist_kwargs)
|
||||
else:
|
||||
raise NotImplementedError("Error: probability distribution, not implemented for action space"
|
||||
f"of type {type(action_space)}."
|
||||
|
|
|
|||
|
|
@ -206,8 +206,8 @@ class BasePolicy(nn.Module):
|
|||
# Handle the different cases for images
|
||||
# as PyTorch uses the channel-first format
|
||||
if is_image_space(self.observation_space):
|
||||
if (observation.shape == self.observation_space.shape or
|
||||
observation.shape[1:] == self.observation_space.shape):
|
||||
if (observation.shape == self.observation_space.shape
|
||||
or observation.shape[1:] == self.observation_space.shape):
|
||||
pass
|
||||
else:
|
||||
# Try to re-order the channels
|
||||
|
|
@ -279,9 +279,9 @@ class BasePolicy(nn.Module):
|
|||
elif observation.shape[1:] == observation_space.shape:
|
||||
return True
|
||||
else:
|
||||
raise ValueError("Error: Unexpected observation shape {} for ".format(observation.shape) +
|
||||
"Box environment, please use {} ".format(observation_space.shape) +
|
||||
"or (n_env, {}) for the observation shape."
|
||||
raise ValueError(f"Error: Unexpected observation shape {observation.shape} for "
|
||||
+ f"Box environment, please use {observation_space.shape} "
|
||||
+ "or (n_env, {}) for the observation shape."
|
||||
.format(", ".join(map(str, observation_space.shape))))
|
||||
elif isinstance(observation_space, gym.spaces.Discrete):
|
||||
if observation.shape == (): # A numpy array of a number, has shape empty tuple '()'
|
||||
|
|
@ -289,30 +289,30 @@ class BasePolicy(nn.Module):
|
|||
elif len(observation.shape) == 1:
|
||||
return True
|
||||
else:
|
||||
raise ValueError("Error: Unexpected observation shape {} for ".format(observation.shape) +
|
||||
"Discrete environment, please use (1,) or (n_env, 1) for the observation shape.")
|
||||
# TODO: add support for MultiDiscrete and MultiBinary observation spaces
|
||||
# elif isinstance(observation_space, gym.spaces.MultiDiscrete):
|
||||
# if observation.shape == (len(observation_space.nvec),):
|
||||
# return False
|
||||
# elif len(observation.shape) == 2 and observation.shape[1] == len(observation_space.nvec):
|
||||
# return True
|
||||
# else:
|
||||
# raise ValueError("Error: Unexpected observation shape {} for MultiDiscrete ".format(observation.shape) +
|
||||
# "environment, please use ({},) or ".format(len(observation_space.nvec)) +
|
||||
# "(n_env, {}) for the observation shape.".format(len(observation_space.nvec)))
|
||||
# elif isinstance(observation_space, gym.spaces.MultiBinary):
|
||||
# if observation.shape == (observation_space.n,):
|
||||
# return False
|
||||
# elif len(observation.shape) == 2 and observation.shape[1] == observation_space.n:
|
||||
# return True
|
||||
# else:
|
||||
# raise ValueError("Error: Unexpected observation shape {} for MultiBinary ".format(observation.shape) +
|
||||
# "environment, please use ({},) or ".format(observation_space.n) +
|
||||
# "(n_env, {}) for the observation shape.".format(observation_space.n))
|
||||
raise ValueError(f"Error: Unexpected observation shape {observation.shape} for "
|
||||
+ "Discrete environment, please use (1,) or (n_env, 1) for the observation shape.")
|
||||
|
||||
elif isinstance(observation_space, gym.spaces.MultiDiscrete):
|
||||
if observation.shape == (len(observation_space.nvec),):
|
||||
return False
|
||||
elif len(observation.shape) == 2 and observation.shape[1] == len(observation_space.nvec):
|
||||
return True
|
||||
else:
|
||||
raise ValueError(f"Error: Unexpected observation shape {observation.shape} for MultiDiscrete "
|
||||
+ f"environment, please use ({len(observation_space.nvec)},) or "
|
||||
+ f"(n_env, {len(observation_space.nvec)}) for the observation shape.")
|
||||
elif isinstance(observation_space, gym.spaces.MultiBinary):
|
||||
if observation.shape == (observation_space.n,):
|
||||
return False
|
||||
elif len(observation.shape) == 2 and observation.shape[1] == observation_space.n:
|
||||
return True
|
||||
else:
|
||||
raise ValueError(f"Error: Unexpected observation shape {observation.shape} for MultiBinary "
|
||||
+ f"environment, please use ({observation_space.n},) or "
|
||||
+ f"(n_env, {observation_space.n}) for the observation shape.")
|
||||
else:
|
||||
raise ValueError("Error: Cannot determine if the observation is vectorized with the space type {}."
|
||||
.format(observation_space))
|
||||
raise ValueError("Error: Cannot determine if the observation is vectorized "
|
||||
+ f" with the space type {observation_space}.")
|
||||
|
||||
def _get_data(self) -> Dict[str, Any]:
|
||||
"""
|
||||
|
|
@ -447,7 +447,7 @@ def get_policy_from_name(base_policy_type: Type[BasePolicy], name: str) -> Type[
|
|||
raise ValueError(f"Error: the policy type {base_policy_type} is not registered!")
|
||||
if name not in _policy_registry[base_policy_type]:
|
||||
raise ValueError(f"Error: unknown policy type {name},"
|
||||
"the only registed policy type are: {list(_policy_registry[base_policy_type].keys())}!")
|
||||
f"the only registed policy type are: {list(_policy_registry[base_policy_type].keys())}!")
|
||||
return _policy_registry[base_policy_type][name]
|
||||
|
||||
|
||||
|
|
@ -460,14 +460,10 @@ def register_policy(name: str, policy: Type[BasePolicy]) -> None:
|
|||
:param policy: (Type[BasePolicy]) the policy class
|
||||
"""
|
||||
sub_class = None
|
||||
# For building the doc
|
||||
try:
|
||||
for cls in BasePolicy.__subclasses__():
|
||||
if issubclass(policy, cls):
|
||||
sub_class = cls
|
||||
break
|
||||
except AttributeError:
|
||||
sub_class = str(th.random.randint(100))
|
||||
for cls in BasePolicy.__subclasses__():
|
||||
if issubclass(policy, cls):
|
||||
sub_class = cls
|
||||
break
|
||||
if sub_class is None:
|
||||
raise ValueError(f"Error: the policy {policy} is not of any known subclasses of BasePolicy!")
|
||||
|
||||
|
|
@ -511,7 +507,6 @@ class MlpExtractor(nn.Module):
|
|||
device: Union[th.device, str] = 'auto'):
|
||||
super(MlpExtractor, self).__init__()
|
||||
device = get_device(device)
|
||||
|
||||
shared_net, policy_net, value_net = [], [], []
|
||||
policy_only_layers = [] # Layer sizes of the network that only belongs to the policy network
|
||||
value_only_layers = [] # Layer sizes of the network that only belongs to the value network
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch as th
|
||||
import torch.nn.functional as F
|
||||
from gym import spaces
|
||||
import numpy as np
|
||||
|
||||
|
||||
def is_image_space(observation_space: spaces.Space,
|
||||
|
|
@ -62,11 +62,21 @@ def preprocess_obs(obs: th.Tensor, observation_space: spaces.Space,
|
|||
if is_image_space(observation_space) and normalize_images:
|
||||
return obs.float() / 255.0
|
||||
return obs.float()
|
||||
|
||||
elif isinstance(observation_space, spaces.Discrete):
|
||||
# One hot encoding and convert to float to avoid errors
|
||||
return F.one_hot(obs.long(), num_classes=observation_space.n).float()
|
||||
|
||||
elif isinstance(observation_space, spaces.MultiDiscrete):
|
||||
# Tensor concatenation of one hot encodings of each Categorical sub-space
|
||||
return th.cat([F.one_hot(obs_.long(), num_classes=int(observation_space.nvec[idx])).float()
|
||||
for idx, obs_ in enumerate(th.split(obs.long(), 1, dim=1))],
|
||||
dim=-1).view(obs.shape[0], sum(observation_space.nvec))
|
||||
|
||||
elif isinstance(observation_space, spaces.MultiBinary):
|
||||
return obs.float()
|
||||
|
||||
else:
|
||||
# TODO: Tuple, Dict (MultiDiscrete and MultiBinary are handled above)
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
|
|
@ -82,8 +92,13 @@ def get_obs_shape(observation_space: spaces.Space) -> Tuple[int, ...]:
|
|||
elif isinstance(observation_space, spaces.Discrete):
|
||||
# Observation is an int
|
||||
return 1,
|
||||
elif isinstance(observation_space, spaces.MultiDiscrete):
|
||||
# Number of discrete features
|
||||
return int(len(observation_space.nvec)),
|
||||
elif isinstance(observation_space, spaces.MultiBinary):
|
||||
# Number of binary features
|
||||
return int(observation_space.n),
|
||||
else:
|
||||
# TODO: Tuple, Dict (MultiDiscrete and MultiBinary are handled above)
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
|
|
@ -95,8 +110,13 @@ def get_flattened_obs_dim(observation_space: spaces.Space) -> int:
|
|||
:param observation_space: (spaces.Space)
|
||||
:return: (int)
|
||||
"""
|
||||
# Use Gym internal method
|
||||
return spaces.utils.flatdim(observation_space)
|
||||
# See issue https://github.com/openai/gym/issues/1915
|
||||
# it may be a problem for Dict/Tuple spaces too...
|
||||
if isinstance(observation_space, spaces.MultiDiscrete):
|
||||
return sum(observation_space.nvec)
|
||||
else:
|
||||
# Use Gym internal method
|
||||
return spaces.utils.flatdim(observation_space)
|
||||
|
||||
|
||||
def get_action_dim(action_space: spaces.Space) -> int:
|
||||
|
|
@ -111,5 +131,11 @@ def get_action_dim(action_space: spaces.Space) -> int:
|
|||
elif isinstance(action_space, spaces.Discrete):
|
||||
# Action is an int
|
||||
return 1
|
||||
elif isinstance(action_space, spaces.MultiDiscrete):
|
||||
# Number of discrete actions
|
||||
return int(len(action_space.nvec))
|
||||
elif isinstance(action_space, spaces.MultiBinary):
|
||||
# Number of binary actions
|
||||
return int(action_space.n)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ from stable_baselines3.common.policies import (BasePolicy, register_policy, MlpE
|
|||
BaseFeaturesExtractor, FlattenExtractor)
|
||||
from stable_baselines3.common.distributions import (make_proba_distribution, Distribution,
|
||||
DiagGaussianDistribution, CategoricalDistribution,
|
||||
MultiCategoricalDistribution, BernoulliDistribution,
|
||||
StateDependentNoiseDistribution)
|
||||
|
||||
|
||||
|
|
@ -178,6 +179,10 @@ class PPOPolicy(BasePolicy):
|
|||
log_std_init=self.log_std_init)
|
||||
elif isinstance(self.action_dist, CategoricalDistribution):
|
||||
self.action_net = self.action_dist.proba_distribution_net(latent_dim=latent_dim_pi)
|
||||
elif isinstance(self.action_dist, MultiCategoricalDistribution):
|
||||
self.action_net = self.action_dist.proba_distribution_net(latent_dim=latent_dim_pi)
|
||||
elif isinstance(self.action_dist, BernoulliDistribution):
|
||||
self.action_net = self.action_dist.proba_distribution_net(latent_dim=latent_dim_pi)
|
||||
|
||||
self.value_net = nn.Linear(self.mlp_extractor.latent_dim_vf, 1)
|
||||
# Init weights: use orthogonal initialization
|
||||
|
|
@ -226,6 +231,7 @@ class PPOPolicy(BasePolicy):
|
|||
# Preprocess the observation if needed
|
||||
features = self.extract_features(obs)
|
||||
latent_pi, latent_vf = self.mlp_extractor(features)
|
||||
|
||||
# Features for sde
|
||||
latent_sde = latent_pi
|
||||
if self.sde_features_extractor is not None:
|
||||
|
|
@ -245,11 +251,15 @@ class PPOPolicy(BasePolicy):
|
|||
|
||||
if isinstance(self.action_dist, DiagGaussianDistribution):
|
||||
return self.action_dist.proba_distribution(mean_actions, self.log_std)
|
||||
|
||||
elif isinstance(self.action_dist, CategoricalDistribution):
|
||||
# Here mean_actions are the logits before the softmax
|
||||
return self.action_dist.proba_distribution(action_logits=mean_actions)
|
||||
|
||||
elif isinstance(self.action_dist, MultiCategoricalDistribution):
|
||||
# Here mean_actions are the flattened logits
|
||||
return self.action_dist.proba_distribution(action_logits=mean_actions)
|
||||
elif isinstance(self.action_dist, BernoulliDistribution):
|
||||
# Here mean_actions are the logits (before rounding to get the binary actions)
|
||||
return self.action_dist.proba_distribution(action_logits=mean_actions)
|
||||
elif isinstance(self.action_dist, StateDependentNoiseDistribution):
|
||||
return self.action_dist.proba_distribution(mean_actions, self.log_std, latent_sde)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -157,7 +157,6 @@ class PPO(BaseRLModel):
|
|||
callback.on_rollout_start()
|
||||
|
||||
while n_steps < n_rollout_steps:
|
||||
|
||||
if self.use_sde and self.sde_sample_freq > 0 and n_steps % self.sde_sample_freq == 0:
|
||||
# Sample a new noise matrix
|
||||
self.policy.reset_noise(env.num_envs)
|
||||
|
|
@ -213,7 +212,6 @@ class PPO(BaseRLModel):
|
|||
approx_kl_divs = []
|
||||
# Do a complete pass on the rollout buffer
|
||||
for rollout_data in self.rollout_buffer.get(batch_size):
|
||||
|
||||
actions = rollout_data.actions
|
||||
if isinstance(self.action_space, spaces.Discrete):
|
||||
# Convert discrete action from float to long
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
0.6.0a8
|
||||
0.6.0a9
|
||||
|
|
|
|||
|
|
@ -4,7 +4,8 @@ import torch as th
|
|||
from stable_baselines3 import A2C, PPO
|
||||
from stable_baselines3.common.distributions import (DiagGaussianDistribution, TanhBijector,
|
||||
StateDependentNoiseDistribution,
|
||||
CategoricalDistribution, SquashedDiagGaussianDistribution)
|
||||
CategoricalDistribution, SquashedDiagGaussianDistribution,
|
||||
MultiCategoricalDistribution, BernoulliDistribution)
|
||||
from stable_baselines3.common.utils import set_random_seed
|
||||
|
||||
|
||||
|
|
@ -85,15 +86,21 @@ def test_entropy(dist):
|
|||
assert th.allclose(entropy.mean(), -log_prob.mean(), rtol=5e-3)
|
||||
|
||||
|
||||
def test_categorical():
|
||||
categorical_params = [
|
||||
(CategoricalDistribution(N_ACTIONS), N_ACTIONS),
|
||||
(MultiCategoricalDistribution([2, 3]), sum([2, 3])),
|
||||
(BernoulliDistribution(N_ACTIONS), N_ACTIONS)
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dist, CAT_ACTIONS", categorical_params)
|
||||
def test_categorical(dist, CAT_ACTIONS):
|
||||
# The entropy can be approximated by averaging the negative log likelihood
|
||||
# mean negative log likelihood == entropy
|
||||
dist = CategoricalDistribution(N_ACTIONS)
|
||||
set_random_seed(1)
|
||||
action_logits = th.rand(N_SAMPLES, N_ACTIONS)
|
||||
action_logits = th.rand(N_SAMPLES, CAT_ACTIONS)
|
||||
dist = dist.proba_distribution(action_logits)
|
||||
|
||||
actions = dist.get_actions()
|
||||
entropy = dist.entropy()
|
||||
log_prob = dist.log_prob(actions)
|
||||
assert th.allclose(entropy.mean(), -log_prob.mean(), rtol=2e-4)
|
||||
assert th.allclose(entropy.mean(), -log_prob.mean(), rtol=5e-3)
|
||||
|
|
|
|||
|
|
@ -2,17 +2,27 @@ import numpy as np
|
|||
import pytest
|
||||
|
||||
from stable_baselines3 import A2C, PPO, SAC, TD3
|
||||
from stable_baselines3.common.identity_env import IdentityEnvBox, IdentityEnv
|
||||
from stable_baselines3.common.identity_env import (IdentityEnvBox, IdentityEnv,
|
||||
IdentityEnvMultiBinary, IdentityEnvMultiDiscrete)
|
||||
|
||||
from stable_baselines3.common.vec_env import DummyVecEnv
|
||||
from stable_baselines3.common.evaluation import evaluate_policy
|
||||
from stable_baselines3.common.noise import NormalActionNoise
|
||||
|
||||
|
||||
DIM = 4
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", [A2C, PPO])
|
||||
def test_discrete(model_class):
|
||||
env = IdentityEnv(10)
|
||||
model = model_class('MlpPolicy', env, gamma=0.5, seed=0).learn(3000)
|
||||
@pytest.mark.parametrize("env", [IdentityEnv(DIM), IdentityEnvMultiDiscrete(DIM), IdentityEnvMultiBinary(DIM)])
|
||||
def test_discrete(model_class, env):
|
||||
env = DummyVecEnv([lambda: env])
|
||||
model = model_class('MlpPolicy', env, gamma=0.5, seed=1).learn(3000)
|
||||
|
||||
evaluate_policy(model, env, n_eval_episodes=20, reward_threshold=90)
|
||||
obs = env.reset()
|
||||
|
||||
assert np.shape(model.predict(obs)[0]) == np.shape(obs)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", [A2C, PPO, SAC, TD3])
|
||||
|
|
|
|||
47
tests/test_spaces.py
Normal file
47
tests/test_spaces.py
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
import numpy as np
|
||||
import pytest
|
||||
import gym
|
||||
|
||||
from stable_baselines3 import SAC, TD3
|
||||
from stable_baselines3.common.evaluation import evaluate_policy
|
||||
|
||||
|
||||
class DummyMultiDiscreteSpace(gym.Env):
    """
    Minimal environment exposing a ``MultiDiscrete`` observation space
    together with a 2-dimensional ``Box`` action space in [-1, 1].

    Observations are sampled from the observation space, the reward is
    always 0.0 and ``done`` is always False.

    :param nvec: sizes forwarded to ``gym.spaces.MultiDiscrete``
    """

    def __init__(self, nvec):
        super().__init__()
        self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
        self.observation_space = gym.spaces.MultiDiscrete(nvec)

    def reset(self):
        """Return an observation sampled from the observation space."""
        first_obs = self.observation_space.sample()
        return first_obs

    def step(self, action):
        """Ignore ``action`` and return a sampled observation with constant reward/done/info."""
        next_obs = self.observation_space.sample()
        reward, done, info = 0.0, False, {}
        return next_obs, reward, done, info
|
||||
|
||||
|
||||
class DummyMultiBinary(gym.Env):
    """
    Minimal environment exposing a ``MultiBinary`` observation space
    together with a 2-dimensional ``Box`` action space in [-1, 1].

    Observations are sampled from the observation space, the reward is
    always 0.0 and ``done`` is always False.

    :param n: size forwarded to ``gym.spaces.MultiBinary``
    """

    def __init__(self, n):
        super().__init__()
        self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
        self.observation_space = gym.spaces.MultiBinary(n)

    def reset(self):
        """Return an observation sampled from the observation space."""
        first_obs = self.observation_space.sample()
        return first_obs

    def step(self, action):
        """Ignore ``action`` and return a sampled observation with constant reward/done/info."""
        next_obs = self.observation_space.sample()
        reward, done, info = 0.0, False, {}
        return next_obs, reward, done, info
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", [SAC, TD3])
@pytest.mark.parametrize("env", [DummyMultiDiscreteSpace([4, 3]), DummyMultiBinary(8)])
def test_identity_spaces(model_class, env):
    """
    Smoke test for SAC/TD3: train briefly and evaluate on environments
    whose observation space is MultiDiscrete or MultiBinary.
    """
    # The dummy environments always return done=False, so cap the episode length
    time_limited_env = gym.wrappers.TimeLimit(env, max_episode_steps=100)

    small_net = dict(net_arch=[64])
    model = model_class("MlpPolicy", time_limited_env, gamma=0.5, seed=1, policy_kwargs=small_net)
    model.learn(total_timesteps=500)

    evaluate_policy(model, time_limited_env, n_eval_episodes=5)
|
||||
Loading…
Reference in a new issue