diff --git a/docs/guide/examples.rst b/docs/guide/examples.rst
index 0e3242b..a6b8040 100644
--- a/docs/guide/examples.rst
+++ b/docs/guide/examples.rst
@@ -17,6 +17,7 @@ notebooks:
- `Monitor Training and Plotting`_
- `Atari Games`_
- `RL Baselines zoo`_
+- `PyBullet`_
.. - `Hindsight Experience Replay`_
@@ -27,6 +28,7 @@ notebooks:
.. _Atari Games: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/atari_games.ipynb
.. _Hindsight Experience Replay: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/stable_baselines_her.ipynb
.. _RL Baselines zoo: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/rl-baselines-zoo.ipynb
+.. _PyBullet: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/pybullet.ipynb
.. |colab| image:: ../_static/img/colab.svg
@@ -291,7 +293,7 @@ PyBullet: Normalizing input features
Normalizing input features may be essential to successful training of an RL agent
(by default, images are scaled but not other types of input),
-for instance when training on `PyBullet `_ environments. For that, a wrapper exists and
+for instance when training on `PyBullet `__ environments. For that, a wrapper exists and
will compute a running average and standard deviation of input features (it can do the same for rewards).
@@ -300,6 +302,10 @@ will compute a running average and standard deviation of input features (it can
you need to install pybullet with ``pip install pybullet``
+.. image:: ../_static/img/colab-badge.svg
+ :target: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/pybullet.ipynb
+
+
.. code-block:: python
import gym
diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst
index 45d78d3..0ca6fa9 100644
--- a/docs/misc/changelog.rst
+++ b/docs/misc/changelog.rst
@@ -3,13 +3,15 @@
Changelog
==========
-Pre-Release 0.8.0a2 (WIP)
+Pre-Release 0.8.0a3 (WIP)
------------------------------
Breaking Changes:
^^^^^^^^^^^^^^^^^
- ``AtariWrapper`` and other Atari wrappers were updated to match SB2 ones
- ``save_replay_buffer`` now receives as argument the file path instead of the folder path (@tirafesi)
+- Refactored ``Critic`` class for ``TD3`` and ``SAC``, it is now called ``ContinuousCritic``
+ and has an additional parameter ``n_critics``
New Features:
^^^^^^^^^^^^^
@@ -40,6 +42,7 @@ Documentation:
- Updated notebook links
- Fixed a typo in the section of Enjoy a Trained Agent, in RL Baselines3 Zoo README. (@blurLake)
- Added Unity reacher to the projects page (@koulakis)
+- Added PyBullet colab notebook
@@ -342,4 +345,4 @@ And all the contributors:
@Miffyli @dwiel @miguelrass @qxcv @jaberkow @eavelardev @ruifeng96150 @pedrohbtp @srivatsankrishnan @evilsocket
@MarvineGothic @jdossgollin @SyllogismRXS @rusu24edward @jbulow @Antymon @seheevic @justinkterry @edbeeching
@flodorner @KuKuXia @NeoExtended @PartiallyTyped @mmcenta @richardwu @kinalmehta @rolandgvc @tkelestemur @mloo3
-@tirafesi @blurLake @koulakis
+@tirafesi @blurLake @koulakis
diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py
index 16ba652..164a247 100644
--- a/stable_baselines3/common/policies.py
+++ b/stable_baselines3/common/policies.py
@@ -10,7 +10,7 @@ import torch as th
import torch.nn as nn
import numpy as np
-from stable_baselines3.common.preprocessing import preprocess_obs, is_image_space
+from stable_baselines3.common.preprocessing import preprocess_obs, is_image_space, get_action_dim
from stable_baselines3.common.torch_layers import (FlattenExtractor, BaseFeaturesExtractor, create_mlp,
NatureCNN, MlpExtractor)
from stable_baselines3.common.utils import get_device, is_vectorized_observation
@@ -644,6 +644,74 @@ class ActorCriticCnnPolicy(ActorCriticPolicy):
optimizer_kwargs)
+class ContinuousCritic(BaseModel):
+ """
+ Critic network(s) for DDPG/SAC/TD3.
+ It represents the action-state value function (Q-value function).
+ Compared to A2C/PPO critics, this one represents the Q-value
+ and takes the continuous action as input. It is concatenated with the state
+ and then fed to the network which outputs a single value: Q(s, a).
+ For more recent algorithms like SAC/TD3, multiple networks
+ are created to give different estimates.
+
+ By default, it creates two critic networks used to reduce overestimation
+ thanks to clipped Q-learning (cf TD3 paper).
+
+ :param observation_space: (gym.spaces.Space) Obervation space
+ :param action_space: (gym.spaces.Space) Action space
+ :param net_arch: ([int]) Network architecture
+ :param features_extractor: (nn.Module) Network to extract features
+ (a CNN when using images, a nn.Flatten() layer otherwise)
+ :param features_dim: (int) Number of features
+ :param activation_fn: (Type[nn.Module]) Activation function
+ :param normalize_images: (bool) Whether to normalize images or not,
+ dividing by 255.0 (True by default)
+ :param device: (Union[th.device, str]) Device on which the code should run.
+ :param n_critics: (int) Number of critic networks to create.
+ """
+
+ def __init__(self, observation_space: gym.spaces.Space,
+ action_space: gym.spaces.Space,
+ net_arch: List[int],
+ features_extractor: nn.Module,
+ features_dim: int,
+ activation_fn: Type[nn.Module] = nn.ReLU,
+ normalize_images: bool = True,
+ device: Union[th.device, str] = 'auto',
+ n_critics: int = 2):
+ super().__init__(observation_space, action_space,
+ features_extractor=features_extractor,
+ normalize_images=normalize_images,
+ device=device)
+
+ action_dim = get_action_dim(self.action_space)
+
+ self.n_critics = n_critics
+ self.q_networks = []
+ for idx in range(n_critics):
+ q_net = create_mlp(features_dim + action_dim, 1, net_arch, activation_fn)
+ q_net = nn.Sequential(*q_net)
+ self.add_module(f'qf{idx}', q_net)
+ self.q_networks.append(q_net)
+
+ def forward(self, obs: th.Tensor, actions: th.Tensor) -> Tuple[th.Tensor, ...]:
+ # Learn the features extractor using the policy loss only
+ with th.no_grad():
+ features = self.extract_features(obs)
+ qvalue_input = th.cat([features, actions], dim=1)
+ return tuple(q_net(qvalue_input) for q_net in self.q_networks)
+
+ def q1_forward(self, obs: th.Tensor, actions: th.Tensor) -> th.Tensor:
+ """
+ Only predict the Q-value using the first network.
+ This allows to reduce computation when all the estimates are not needed
+ (e.g. when updating the policy in TD3).
+ """
+ with th.no_grad():
+ features = self.extract_features(obs)
+ return self.q_networks[0](th.cat([features, actions], dim=1))
+
+
def create_sde_features_extractor(features_dim: int,
sde_net_arch: List[int],
activation_fn: Type[nn.Module]) -> Tuple[nn.Sequential, int]:
diff --git a/stable_baselines3/sac/policies.py b/stable_baselines3/sac/policies.py
index e408622..e409293 100644
--- a/stable_baselines3/sac/policies.py
+++ b/stable_baselines3/sac/policies.py
@@ -5,7 +5,7 @@ import torch as th
import torch.nn as nn
from stable_baselines3.common.preprocessing import get_action_dim
-from stable_baselines3.common.policies import BaseModel, BasePolicy, register_policy, create_sde_features_extractor
+from stable_baselines3.common.policies import BasePolicy, register_policy, create_sde_features_extractor, ContinuousCritic
from stable_baselines3.common.torch_layers import create_mlp, NatureCNN, BaseFeaturesExtractor, FlattenExtractor
from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution
@@ -179,54 +179,6 @@ class Actor(BasePolicy):
return self.forward(observation, deterministic)
-class Critic(BaseModel):
- """
- Critic network (q-value function) for SAC.
-
- :param observation_space: (gym.spaces.Space) Obervation space
- :param action_space: (gym.spaces.Space) Action space
- :param net_arch: ([int]) Network architecture
- :param features_extractor: (nn.Module) Network to extract features
- (a CNN when using images, a nn.Flatten() layer otherwise)
- :param features_dim: (int) Number of features
- :param activation_fn: (Type[nn.Module]) Activation function
- :param normalize_images: (bool) Whether to normalize images or not,
- dividing by 255.0 (True by default)
- :param device: (Union[th.device, str]) Device on which the code should run.
- """
-
- def __init__(self, observation_space: gym.spaces.Space,
- action_space: gym.spaces.Space,
- net_arch: List[int],
- features_extractor: nn.Module,
- features_dim: int,
- activation_fn: Type[nn.Module] = nn.ReLU,
- normalize_images: bool = True,
- device: Union[th.device, str] = 'auto'):
- super(Critic, self).__init__(observation_space, action_space,
- features_extractor=features_extractor,
- normalize_images=normalize_images,
- device=device)
-
- action_dim = get_action_dim(self.action_space)
-
- q1_net = create_mlp(features_dim + action_dim, 1, net_arch, activation_fn)
- self.q1_net = nn.Sequential(*q1_net)
-
- q2_net = create_mlp(features_dim + action_dim, 1, net_arch, activation_fn)
- self.q2_net = nn.Sequential(*q2_net)
-
- self.q_networks = [self.q1_net, self.q2_net]
-
- def forward(self, obs: th.Tensor, action: th.Tensor) -> List[th.Tensor]:
- # Learn the features extractor using the policy loss only
- # this is much faster
- with th.no_grad():
- features = self.extract_features(obs)
- qvalue_input = th.cat([features, action], dim=1)
- return [q_net(qvalue_input) for q_net in self.q_networks]
-
-
class SACPolicy(BasePolicy):
"""
Policy class (with both actor and critic) for SAC.
@@ -255,6 +207,7 @@ class SACPolicy(BasePolicy):
``th.optim.Adam`` by default
:param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments,
excluding the learning rate, to pass to the optimizer
+ :param n_critics: (int) Number of critic networks to create.
"""
def __init__(self, observation_space: gym.spaces.Space,
@@ -272,7 +225,8 @@ class SACPolicy(BasePolicy):
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
normalize_images: bool = True,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
- optimizer_kwargs: Optional[Dict[str, Any]] = None):
+ optimizer_kwargs: Optional[Dict[str, Any]] = None,
+ n_critics: int = 2):
super(SACPolicy, self).__init__(observation_space, action_space,
device,
features_extractor_class,
@@ -313,6 +267,9 @@ class SACPolicy(BasePolicy):
'clip_mean': clip_mean
}
self.actor_kwargs.update(sde_kwargs)
+ self.critic_kwargs = self.net_args.copy()
+ self.critic_kwargs.update({'n_critics': n_critics})
+
self.actor, self.actor_target = None, None
self.critic, self.critic_target = None, None
@@ -345,6 +302,7 @@ class SACPolicy(BasePolicy):
sde_net_arch=self.actor_kwargs['sde_net_arch'],
use_expln=self.actor_kwargs['use_expln'],
clip_mean=self.actor_kwargs['clip_mean'],
+ n_critics=self.critic_kwargs['n_critics'],
lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone
optimizer_class=self.optimizer_class,
optimizer_kwargs=self.optimizer_kwargs,
@@ -364,8 +322,8 @@ class SACPolicy(BasePolicy):
def make_actor(self) -> Actor:
return Actor(**self.actor_kwargs).to(self.device)
- def make_critic(self) -> Critic:
- return Critic(**self.net_args).to(self.device)
+ def make_critic(self) -> ContinuousCritic:
+ return ContinuousCritic(**self.critic_kwargs).to(self.device)
def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor:
return self._predict(obs, deterministic=deterministic)
@@ -403,6 +361,7 @@ class CnnPolicy(SACPolicy):
``th.optim.Adam`` by default
:param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments,
excluding the learning rate, to pass to the optimizer
+ :param n_critics: (int) Number of critic networks to create.
"""
def __init__(self, observation_space: gym.spaces.Space,
@@ -420,7 +379,8 @@ class CnnPolicy(SACPolicy):
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
normalize_images: bool = True,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
- optimizer_kwargs: Optional[Dict[str, Any]] = None):
+ optimizer_kwargs: Optional[Dict[str, Any]] = None,
+ n_critics: int = 2):
super(CnnPolicy, self).__init__(observation_space,
action_space,
lr_schedule,
@@ -436,7 +396,8 @@ class CnnPolicy(SACPolicy):
features_extractor_kwargs,
normalize_images,
optimizer_class,
- optimizer_kwargs)
+ optimizer_kwargs,
+ n_critics)
register_policy("MlpPolicy", MlpPolicy)
diff --git a/stable_baselines3/sac/sac.py b/stable_baselines3/sac/sac.py
index 203abc4..04e20fa 100644
--- a/stable_baselines3/sac/sac.py
+++ b/stable_baselines3/sac/sac.py
@@ -118,7 +118,7 @@ class SAC(OffPolicyAlgorithm):
def _setup_model(self) -> None:
super(SAC, self)._setup_model()
self._create_aliases()
-
+ assert self.critic.n_critics == 2, "SAC only supports `n_critics=2` for now"
# Target entropy is used when learning the entropy coefficient
if self.target_entropy == 'auto':
# automatically set target entropy if needed
diff --git a/stable_baselines3/td3/policies.py b/stable_baselines3/td3/policies.py
index addadf5..325640f 100644
--- a/stable_baselines3/td3/policies.py
+++ b/stable_baselines3/td3/policies.py
@@ -1,11 +1,11 @@
-from typing import Optional, List, Tuple, Callable, Union, Type, Any, Dict
+from typing import Optional, List, Callable, Union, Type, Any, Dict
import gym
import torch as th
import torch.nn as nn
from stable_baselines3.common.preprocessing import get_action_dim
-from stable_baselines3.common.policies import BaseModel, BasePolicy, register_policy
+from stable_baselines3.common.policies import BasePolicy, register_policy, ContinuousCritic
from stable_baselines3.common.torch_layers import create_mlp, NatureCNN, BaseFeaturesExtractor, FlattenExtractor
@@ -71,57 +71,6 @@ class Actor(BasePolicy):
return self.forward(observation, deterministic=deterministic)
-class Critic(BaseModel):
- """
- Critic network for TD3,
- in fact it represents the action-state value function (Q-value function)
-
- :param observation_space: (gym.spaces.Space) Obervation space
- :param action_space: (gym.spaces.Space) Action space
- :param net_arch: ([int]) Network architecture
- :param features_extractor: (nn.Module) Network to extract features
- (a CNN when using images, a nn.Flatten() layer otherwise)
- :param features_dim: (int) Number of features
- :param activation_fn: (Type[nn.Module]) Activation function
- :param normalize_images: (bool) Whether to normalize images or not,
- dividing by 255.0 (True by default)
- :param device: (Union[th.device, str]) Device on which the code should run.
- """
-
- def __init__(self, observation_space: gym.spaces.Space,
- action_space: gym.spaces.Space,
- net_arch: List[int],
- features_extractor: nn.Module,
- features_dim: int,
- activation_fn: Type[nn.Module] = nn.ReLU,
- normalize_images: bool = True,
- device: Union[th.device, str] = 'auto'):
- super(Critic, self).__init__(observation_space, action_space,
- features_extractor=features_extractor,
- normalize_images=normalize_images,
- device=device)
-
- action_dim = get_action_dim(self.action_space)
-
- q1_net = create_mlp(features_dim + action_dim, 1, net_arch, activation_fn)
- self.q1_net = nn.Sequential(*q1_net)
-
- q2_net = create_mlp(features_dim + action_dim, 1, net_arch, activation_fn)
- self.q2_net = nn.Sequential(*q2_net)
-
- def forward(self, obs: th.Tensor, actions: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
- # Learn the features extractor using the policy loss only
- with th.no_grad():
- features = self.extract_features(obs)
- qvalue_input = th.cat([features, actions], dim=1)
- return self.q1_net(qvalue_input), self.q2_net(qvalue_input)
-
- def q1_forward(self, obs: th.Tensor, actions: th.Tensor) -> th.Tensor:
- with th.no_grad():
- features = self.extract_features(obs)
- return self.q1_net(th.cat([features, actions], dim=1))
-
-
class TD3Policy(BasePolicy):
"""
Policy class (with both actor and critic) for TD3.
@@ -141,6 +90,7 @@ class TD3Policy(BasePolicy):
``th.optim.Adam`` by default
:param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments,
excluding the learning rate, to pass to the optimizer
+ :param n_critics: (int) Number of critic networks to create.
"""
def __init__(self, observation_space: gym.spaces.Space,
@@ -153,7 +103,8 @@ class TD3Policy(BasePolicy):
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
normalize_images: bool = True,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
- optimizer_kwargs: Optional[Dict[str, Any]] = None):
+ optimizer_kwargs: Optional[Dict[str, Any]] = None,
+ n_critics: int = 2):
super(TD3Policy, self).__init__(observation_space, action_space,
device,
features_extractor_class,
@@ -185,6 +136,8 @@ class TD3Policy(BasePolicy):
'normalize_images': normalize_images,
'device': device
}
+ self.critic_kwargs = self.net_args.copy()
+ self.critic_kwargs.update({'n_critics': n_critics})
self.actor, self.actor_target = None, None
self.critic, self.critic_target = None, None
@@ -208,6 +161,7 @@ class TD3Policy(BasePolicy):
data.update(dict(
net_arch=self.net_args['net_arch'],
activation_fn=self.net_args['activation_fn'],
+ n_critics=self.critic_kwargs['n_critics'],
lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone
optimizer_class=self.optimizer_class,
optimizer_kwargs=self.optimizer_kwargs,
@@ -219,8 +173,8 @@ class TD3Policy(BasePolicy):
def make_actor(self) -> Actor:
return Actor(**self.net_args).to(self.device)
- def make_critic(self) -> Critic:
- return Critic(**self.net_args).to(self.device)
+ def make_critic(self) -> ContinuousCritic:
+ return ContinuousCritic(**self.critic_kwargs).to(self.device)
def forward(self, observation: th.Tensor, deterministic: bool = False):
return self._predict(observation, deterministic=deterministic)
@@ -251,6 +205,7 @@ class CnnPolicy(TD3Policy):
``th.optim.Adam`` by default
:param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments,
excluding the learning rate, to pass to the optimizer
+ :param n_critics: (int) Number of critic networks to create.
"""
def __init__(self, observation_space: gym.spaces.Space,
@@ -263,7 +218,8 @@ class CnnPolicy(TD3Policy):
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
normalize_images: bool = True,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
- optimizer_kwargs: Optional[Dict[str, Any]] = None):
+ optimizer_kwargs: Optional[Dict[str, Any]] = None,
+ n_critics: int = 2):
super(CnnPolicy, self).__init__(observation_space,
action_space,
lr_schedule,
@@ -274,7 +230,8 @@ class CnnPolicy(TD3Policy):
features_extractor_kwargs,
normalize_images,
optimizer_class,
- optimizer_kwargs)
+ optimizer_kwargs,
+ n_critics)
register_policy("MlpPolicy", MlpPolicy)
diff --git a/stable_baselines3/td3/td3.py b/stable_baselines3/td3/td3.py
index 459e240..13bcc98 100644
--- a/stable_baselines3/td3/td3.py
+++ b/stable_baselines3/td3/td3.py
@@ -96,6 +96,7 @@ class TD3(OffPolicyAlgorithm):
def _setup_model(self) -> None:
super(TD3, self)._setup_model()
self._create_aliases()
+ assert self.critic.n_critics == 2, "TD3 only supports `n_critics=2` for now"
def _create_aliases(self) -> None:
self.actor = self.policy.actor
diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt
index 8db4718..8369211 100644
--- a/stable_baselines3/version.txt
+++ b/stable_baselines3/version.txt
@@ -1 +1 @@
-0.8.0a2
+0.8.0a3