From 3756d05f7265b8023b1987d9d94a25ec15dbcc1f Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Tue, 7 Jul 2020 00:02:51 +0200 Subject: [PATCH] Refactored ContinuousCritic for SAC/TD3 (#78) * Refactored ContinuousCritic for SAC/TD3 * Address comments * Add pybullet notebook --- docs/guide/examples.rst | 8 ++- docs/misc/changelog.rst | 7 ++- stable_baselines3/common/policies.py | 70 +++++++++++++++++++++++++- stable_baselines3/sac/policies.py | 69 ++++++-------------------- stable_baselines3/sac/sac.py | 2 +- stable_baselines3/td3/policies.py | 73 ++++++---------------------- stable_baselines3/td3/td3.py | 1 + stable_baselines3/version.txt | 2 +- 8 files changed, 114 insertions(+), 118 deletions(-) diff --git a/docs/guide/examples.rst b/docs/guide/examples.rst index 0e3242b..a6b8040 100644 --- a/docs/guide/examples.rst +++ b/docs/guide/examples.rst @@ -17,6 +17,7 @@ notebooks: - `Monitor Training and Plotting`_ - `Atari Games`_ - `RL Baselines zoo`_ +- `PyBullet`_ .. - `Hindsight Experience Replay`_ @@ -27,6 +28,7 @@ notebooks: .. _Atari Games: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/atari_games.ipynb .. _Hindsight Experience Replay: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/stable_baselines_her.ipynb .. _RL Baselines zoo: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/rl-baselines-zoo.ipynb +.. _PyBullet: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/pybullet.ipynb .. |colab| image:: ../_static/img/colab.svg @@ -291,7 +293,7 @@ PyBullet: Normalizing input features Normalizing input features may be essential to successful training of an RL agent (by default, images are scaled but not other types of input), -for instance when training on `PyBullet `_ environments. For that, a wrapper exists and +for instance when training on `PyBullet `__ environments. 
For that, a wrapper exists and will compute a running average and standard deviation of input features (it can do the same for rewards). @@ -300,6 +302,10 @@ will compute a running average and standard deviation of input features (it can you need to install pybullet with ``pip install pybullet`` +.. image:: ../_static/img/colab-badge.svg + :target: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/pybullet.ipynb + + .. code-block:: python import gym diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 45d78d3..0ca6fa9 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,13 +3,15 @@ Changelog ========== -Pre-Release 0.8.0a2 (WIP) +Pre-Release 0.8.0a3 (WIP) ------------------------------ Breaking Changes: ^^^^^^^^^^^^^^^^^ - ``AtariWrapper`` and other Atari wrappers were updated to match SB2 ones - ``save_replay_buffer`` now receives as argument the file path instead of the folder path (@tirafesi) +- Refactored ``Critic`` class for ``TD3`` and ``SAC``, it is now called ``ContinuousCritic`` + and has an additional parameter ``n_critics`` New Features: ^^^^^^^^^^^^^ @@ -40,6 +42,7 @@ Documentation: - Updated notebook links - Fixed a typo in the section of Enjoy a Trained Agent, in RL Baselines3 Zoo README. 
(@blurLake) - Added Unity reacher to the projects page (@koulakis) +- Added PyBullet colab notebook @@ -342,4 +345,4 @@ And all the contributors: @Miffyli @dwiel @miguelrass @qxcv @jaberkow @eavelardev @ruifeng96150 @pedrohbtp @srivatsankrishnan @evilsocket @MarvineGothic @jdossgollin @SyllogismRXS @rusu24edward @jbulow @Antymon @seheevic @justinkterry @edbeeching @flodorner @KuKuXia @NeoExtended @PartiallyTyped @mmcenta @richardwu @kinalmehta @rolandgvc @tkelestemur @mloo3 -@tirafesi @blurLake @koulakis +@tirafesi @blurLake @koulakis diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py index 6d4c87f..ba4797f 100644 --- a/stable_baselines3/common/policies.py +++ b/stable_baselines3/common/policies.py @@ -6,7 +6,7 @@ import torch as th import torch.nn as nn import numpy as np -from stable_baselines3.common.preprocessing import preprocess_obs, is_image_space +from stable_baselines3.common.preprocessing import preprocess_obs, is_image_space, get_action_dim from stable_baselines3.common.torch_layers import (FlattenExtractor, BaseFeaturesExtractor, create_mlp, NatureCNN, MlpExtractor) from stable_baselines3.common.utils import get_device, is_vectorized_observation @@ -617,6 +617,74 @@ class ActorCriticCnnPolicy(ActorCriticPolicy): optimizer_kwargs) +class ContinuousCritic(BasePolicy): + """ + Critic network(s) for DDPG/SAC/TD3. + It represents the action-state value function (Q-value function). + Compared to A2C/PPO critics, this one represents the Q-value + and takes the continuous action as input. It is concatenated with the state + and then fed to the network which outputs a single value: Q(s, a). + For more recent algorithms like SAC/TD3, multiple networks + are created to give different estimates. + + By default, it creates two critic networks used to reduce overestimation + thanks to clipped Q-learning (cf TD3 paper). 
+ + :param observation_space: (gym.spaces.Space) Observation space + :param action_space: (gym.spaces.Space) Action space + :param net_arch: ([int]) Network architecture + :param features_extractor: (nn.Module) Network to extract features + (a CNN when using images, a nn.Flatten() layer otherwise) + :param features_dim: (int) Number of features + :param activation_fn: (Type[nn.Module]) Activation function + :param normalize_images: (bool) Whether to normalize images or not, + dividing by 255.0 (True by default) + :param device: (Union[th.device, str]) Device on which the code should run. + :param n_critics: (int) Number of critic networks to create. + """ + + def __init__(self, observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + net_arch: List[int], + features_extractor: nn.Module, + features_dim: int, + activation_fn: Type[nn.Module] = nn.ReLU, + normalize_images: bool = True, + device: Union[th.device, str] = 'auto', + n_critics: int = 2): + super().__init__(observation_space, action_space, + features_extractor=features_extractor, + normalize_images=normalize_images, + device=device) + + action_dim = get_action_dim(self.action_space) + + self.n_critics = n_critics + self.q_networks = [] + for idx in range(n_critics): + q_net = create_mlp(features_dim + action_dim, 1, net_arch, activation_fn) + q_net = nn.Sequential(*q_net) + self.add_module(f'qf{idx}', q_net) + self.q_networks.append(q_net) + + def forward(self, obs: th.Tensor, actions: th.Tensor) -> Tuple[th.Tensor, ...]: + # Learn the features extractor using the policy loss only + with th.no_grad(): + features = self.extract_features(obs) + qvalue_input = th.cat([features, actions], dim=1) + return tuple(q_net(qvalue_input) for q_net in self.q_networks) + + def q1_forward(self, obs: th.Tensor, actions: th.Tensor) -> th.Tensor: + """ + Only predict the Q-value using the first network. + This allows reducing computation when all the estimates are not needed + (e.g. 
when updating the policy in TD3). + """ + with th.no_grad(): + features = self.extract_features(obs) + return self.q_networks[0](th.cat([features, actions], dim=1)) + + def create_sde_features_extractor(features_dim: int, sde_net_arch: List[int], activation_fn: Type[nn.Module]) -> Tuple[nn.Sequential, int]: diff --git a/stable_baselines3/sac/policies.py b/stable_baselines3/sac/policies.py index 335bbd4..e409293 100644 --- a/stable_baselines3/sac/policies.py +++ b/stable_baselines3/sac/policies.py @@ -5,7 +5,7 @@ import torch as th import torch.nn as nn from stable_baselines3.common.preprocessing import get_action_dim -from stable_baselines3.common.policies import BasePolicy, register_policy, create_sde_features_extractor +from stable_baselines3.common.policies import BasePolicy, register_policy, create_sde_features_extractor, ContinuousCritic from stable_baselines3.common.torch_layers import create_mlp, NatureCNN, BaseFeaturesExtractor, FlattenExtractor from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution @@ -179,54 +179,6 @@ class Actor(BasePolicy): return self.forward(observation, deterministic) -class Critic(BasePolicy): - """ - Critic network (q-value function) for SAC. - - :param observation_space: (gym.spaces.Space) Obervation space - :param action_space: (gym.spaces.Space) Action space - :param net_arch: ([int]) Network architecture - :param features_extractor: (nn.Module) Network to extract features - (a CNN when using images, a nn.Flatten() layer otherwise) - :param features_dim: (int) Number of features - :param activation_fn: (Type[nn.Module]) Activation function - :param normalize_images: (bool) Whether to normalize images or not, - dividing by 255.0 (True by default) - :param device: (Union[th.device, str]) Device on which the code should run. 
- """ - - def __init__(self, observation_space: gym.spaces.Space, - action_space: gym.spaces.Space, - net_arch: List[int], - features_extractor: nn.Module, - features_dim: int, - activation_fn: Type[nn.Module] = nn.ReLU, - normalize_images: bool = True, - device: Union[th.device, str] = 'auto'): - super(Critic, self).__init__(observation_space, action_space, - features_extractor=features_extractor, - normalize_images=normalize_images, - device=device) - - action_dim = get_action_dim(self.action_space) - - q1_net = create_mlp(features_dim + action_dim, 1, net_arch, activation_fn) - self.q1_net = nn.Sequential(*q1_net) - - q2_net = create_mlp(features_dim + action_dim, 1, net_arch, activation_fn) - self.q2_net = nn.Sequential(*q2_net) - - self.q_networks = [self.q1_net, self.q2_net] - - def forward(self, obs: th.Tensor, action: th.Tensor) -> List[th.Tensor]: - # Learn the features extractor using the policy loss only - # this is much faster - with th.no_grad(): - features = self.extract_features(obs) - qvalue_input = th.cat([features, action], dim=1) - return [q_net(qvalue_input) for q_net in self.q_networks] - - class SACPolicy(BasePolicy): """ Policy class (with both actor and critic) for SAC. @@ -255,6 +207,7 @@ class SACPolicy(BasePolicy): ``th.optim.Adam`` by default :param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments, excluding the learning rate, to pass to the optimizer + :param n_critics: (int) Number of critic networks to create. 
""" def __init__(self, observation_space: gym.spaces.Space, @@ -272,7 +225,8 @@ class SACPolicy(BasePolicy): features_extractor_kwargs: Optional[Dict[str, Any]] = None, normalize_images: bool = True, optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, - optimizer_kwargs: Optional[Dict[str, Any]] = None): + optimizer_kwargs: Optional[Dict[str, Any]] = None, + n_critics: int = 2): super(SACPolicy, self).__init__(observation_space, action_space, device, features_extractor_class, @@ -313,6 +267,9 @@ class SACPolicy(BasePolicy): 'clip_mean': clip_mean } self.actor_kwargs.update(sde_kwargs) + self.critic_kwargs = self.net_args.copy() + self.critic_kwargs.update({'n_critics': n_critics}) + self.actor, self.actor_target = None, None self.critic, self.critic_target = None, None @@ -345,6 +302,7 @@ class SACPolicy(BasePolicy): sde_net_arch=self.actor_kwargs['sde_net_arch'], use_expln=self.actor_kwargs['use_expln'], clip_mean=self.actor_kwargs['clip_mean'], + n_critics=self.critic_kwargs['n_critics'], lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone optimizer_class=self.optimizer_class, optimizer_kwargs=self.optimizer_kwargs, @@ -364,8 +322,8 @@ class SACPolicy(BasePolicy): def make_actor(self) -> Actor: return Actor(**self.actor_kwargs).to(self.device) - def make_critic(self) -> Critic: - return Critic(**self.net_args).to(self.device) + def make_critic(self) -> ContinuousCritic: + return ContinuousCritic(**self.critic_kwargs).to(self.device) def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor: return self._predict(obs, deterministic=deterministic) @@ -403,6 +361,7 @@ class CnnPolicy(SACPolicy): ``th.optim.Adam`` by default :param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments, excluding the learning rate, to pass to the optimizer + :param n_critics: (int) Number of critic networks to create. 
""" def __init__(self, observation_space: gym.spaces.Space, @@ -420,7 +379,8 @@ class CnnPolicy(SACPolicy): features_extractor_kwargs: Optional[Dict[str, Any]] = None, normalize_images: bool = True, optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, - optimizer_kwargs: Optional[Dict[str, Any]] = None): + optimizer_kwargs: Optional[Dict[str, Any]] = None, + n_critics: int = 2): super(CnnPolicy, self).__init__(observation_space, action_space, lr_schedule, @@ -436,7 +396,8 @@ class CnnPolicy(SACPolicy): features_extractor_kwargs, normalize_images, optimizer_class, - optimizer_kwargs) + optimizer_kwargs, + n_critics) register_policy("MlpPolicy", MlpPolicy) diff --git a/stable_baselines3/sac/sac.py b/stable_baselines3/sac/sac.py index 203abc4..04e20fa 100644 --- a/stable_baselines3/sac/sac.py +++ b/stable_baselines3/sac/sac.py @@ -118,7 +118,7 @@ class SAC(OffPolicyAlgorithm): def _setup_model(self) -> None: super(SAC, self)._setup_model() self._create_aliases() - + assert self.critic.n_critics == 2, "SAC only supports `n_critics=2` for now" # Target entropy is used when learning the entropy coefficient if self.target_entropy == 'auto': # automatically set target entropy if needed diff --git a/stable_baselines3/td3/policies.py b/stable_baselines3/td3/policies.py index 7b863ad..325640f 100644 --- a/stable_baselines3/td3/policies.py +++ b/stable_baselines3/td3/policies.py @@ -1,11 +1,11 @@ -from typing import Optional, List, Tuple, Callable, Union, Type, Any, Dict +from typing import Optional, List, Callable, Union, Type, Any, Dict import gym import torch as th import torch.nn as nn from stable_baselines3.common.preprocessing import get_action_dim -from stable_baselines3.common.policies import BasePolicy, register_policy +from stable_baselines3.common.policies import BasePolicy, register_policy, ContinuousCritic from stable_baselines3.common.torch_layers import create_mlp, NatureCNN, BaseFeaturesExtractor, FlattenExtractor @@ -71,57 +71,6 @@ class 
Actor(BasePolicy): return self.forward(observation, deterministic=deterministic) -class Critic(BasePolicy): - """ - Critic network for TD3, - in fact it represents the action-state value function (Q-value function) - - :param observation_space: (gym.spaces.Space) Obervation space - :param action_space: (gym.spaces.Space) Action space - :param net_arch: ([int]) Network architecture - :param features_extractor: (nn.Module) Network to extract features - (a CNN when using images, a nn.Flatten() layer otherwise) - :param features_dim: (int) Number of features - :param activation_fn: (Type[nn.Module]) Activation function - :param normalize_images: (bool) Whether to normalize images or not, - dividing by 255.0 (True by default) - :param device: (Union[th.device, str]) Device on which the code should run. - """ - - def __init__(self, observation_space: gym.spaces.Space, - action_space: gym.spaces.Space, - net_arch: List[int], - features_extractor: nn.Module, - features_dim: int, - activation_fn: Type[nn.Module] = nn.ReLU, - normalize_images: bool = True, - device: Union[th.device, str] = 'auto'): - super(Critic, self).__init__(observation_space, action_space, - features_extractor=features_extractor, - normalize_images=normalize_images, - device=device) - - action_dim = get_action_dim(self.action_space) - - q1_net = create_mlp(features_dim + action_dim, 1, net_arch, activation_fn) - self.q1_net = nn.Sequential(*q1_net) - - q2_net = create_mlp(features_dim + action_dim, 1, net_arch, activation_fn) - self.q2_net = nn.Sequential(*q2_net) - - def forward(self, obs: th.Tensor, actions: th.Tensor) -> Tuple[th.Tensor, th.Tensor]: - # Learn the features extractor using the policy loss only - with th.no_grad(): - features = self.extract_features(obs) - qvalue_input = th.cat([features, actions], dim=1) - return self.q1_net(qvalue_input), self.q2_net(qvalue_input) - - def q1_forward(self, obs: th.Tensor, actions: th.Tensor) -> th.Tensor: - with th.no_grad(): - features = 
self.extract_features(obs) - return self.q1_net(th.cat([features, actions], dim=1)) - - class TD3Policy(BasePolicy): """ Policy class (with both actor and critic) for TD3. @@ -141,6 +90,7 @@ class TD3Policy(BasePolicy): ``th.optim.Adam`` by default :param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments, excluding the learning rate, to pass to the optimizer + :param n_critics: (int) Number of critic networks to create. """ def __init__(self, observation_space: gym.spaces.Space, @@ -153,7 +103,8 @@ class TD3Policy(BasePolicy): features_extractor_kwargs: Optional[Dict[str, Any]] = None, normalize_images: bool = True, optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, - optimizer_kwargs: Optional[Dict[str, Any]] = None): + optimizer_kwargs: Optional[Dict[str, Any]] = None, + n_critics: int = 2): super(TD3Policy, self).__init__(observation_space, action_space, device, features_extractor_class, @@ -185,6 +136,8 @@ class TD3Policy(BasePolicy): 'normalize_images': normalize_images, 'device': device } + self.critic_kwargs = self.net_args.copy() + self.critic_kwargs.update({'n_critics': n_critics}) self.actor, self.actor_target = None, None self.critic, self.critic_target = None, None @@ -208,6 +161,7 @@ class TD3Policy(BasePolicy): data.update(dict( net_arch=self.net_args['net_arch'], activation_fn=self.net_args['activation_fn'], + n_critics=self.critic_kwargs['n_critics'], lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone optimizer_class=self.optimizer_class, optimizer_kwargs=self.optimizer_kwargs, @@ -219,8 +173,8 @@ class TD3Policy(BasePolicy): def make_actor(self) -> Actor: return Actor(**self.net_args).to(self.device) - def make_critic(self) -> Critic: - return Critic(**self.net_args).to(self.device) + def make_critic(self) -> ContinuousCritic: + return ContinuousCritic(**self.critic_kwargs).to(self.device) def forward(self, observation: th.Tensor, deterministic: bool = False): return 
self._predict(observation, deterministic=deterministic) @@ -251,6 +205,7 @@ class CnnPolicy(TD3Policy): ``th.optim.Adam`` by default :param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments, excluding the learning rate, to pass to the optimizer + :param n_critics: (int) Number of critic networks to create. """ def __init__(self, observation_space: gym.spaces.Space, @@ -263,7 +218,8 @@ class CnnPolicy(TD3Policy): features_extractor_kwargs: Optional[Dict[str, Any]] = None, normalize_images: bool = True, optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, - optimizer_kwargs: Optional[Dict[str, Any]] = None): + optimizer_kwargs: Optional[Dict[str, Any]] = None, + n_critics: int = 2): super(CnnPolicy, self).__init__(observation_space, action_space, lr_schedule, @@ -274,7 +230,8 @@ class CnnPolicy(TD3Policy): features_extractor_kwargs, normalize_images, optimizer_class, - optimizer_kwargs) + optimizer_kwargs, + n_critics) register_policy("MlpPolicy", MlpPolicy) diff --git a/stable_baselines3/td3/td3.py b/stable_baselines3/td3/td3.py index 459e240..13bcc98 100644 --- a/stable_baselines3/td3/td3.py +++ b/stable_baselines3/td3/td3.py @@ -96,6 +96,7 @@ class TD3(OffPolicyAlgorithm): def _setup_model(self) -> None: super(TD3, self)._setup_model() self._create_aliases() + assert self.critic.n_critics == 2, "TD3 only supports `n_critics=2` for now" def _create_aliases(self) -> None: self.actor = self.policy.actor diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 8db4718..8369211 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -0.8.0a2 +0.8.0a3