Merge branch 'base-class-review' of github.com:DLR-RM/stable-baselines3 into base-class-review

This commit is contained in:
Adam Gleave 2020-07-07 18:57:35 -07:00
commit bf73f01ee9
8 changed files with 114 additions and 118 deletions

View file

@ -17,6 +17,7 @@ notebooks:
- `Monitor Training and Plotting`_
- `Atari Games`_
- `RL Baselines zoo`_
- `PyBullet`_
.. - `Hindsight Experience Replay`_
@ -27,6 +28,7 @@ notebooks:
.. _Atari Games: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/atari_games.ipynb
.. _Hindsight Experience Replay: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/stable_baselines_her.ipynb
.. _RL Baselines zoo: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/rl-baselines-zoo.ipynb
.. _PyBullet: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/pybullet.ipynb
.. |colab| image:: ../_static/img/colab.svg
@ -291,7 +293,7 @@ PyBullet: Normalizing input features
Normalizing input features may be essential to successful training of an RL agent
(by default, images are scaled but not other types of input),
for instance when training on `PyBullet <https://github.com/bulletphysics/bullet3/>`_ environments. For that, a wrapper exists and
for instance when training on `PyBullet <https://github.com/bulletphysics/bullet3/>`__ environments. For that, a wrapper exists and
will compute a running average and standard deviation of input features (it can do the same for rewards).
@ -300,6 +302,10 @@ will compute a running average and standard deviation of input features (it can
you need to install pybullet with ``pip install pybullet``
.. image:: ../_static/img/colab-badge.svg
:target: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/pybullet.ipynb
.. code-block:: python
import gym

View file

@ -3,13 +3,15 @@
Changelog
==========
Pre-Release 0.8.0a2 (WIP)
Pre-Release 0.8.0a3 (WIP)
------------------------------
Breaking Changes:
^^^^^^^^^^^^^^^^^
- ``AtariWrapper`` and other Atari wrappers were updated to match SB2 ones
- ``save_replay_buffer`` now receives as argument the file path instead of the folder path (@tirafesi)
- Refactored the ``Critic`` class for ``TD3`` and ``SAC``: it is now called ``ContinuousCritic``
and has an additional parameter ``n_critics``
New Features:
^^^^^^^^^^^^^
@ -40,6 +42,7 @@ Documentation:
- Updated notebook links
- Fixed a typo in the section of Enjoy a Trained Agent, in RL Baselines3 Zoo README. (@blurLake)
- Added Unity reacher to the projects page (@koulakis)
- Added PyBullet colab notebook
@ -342,4 +345,4 @@ And all the contributors:
@Miffyli @dwiel @miguelrass @qxcv @jaberkow @eavelardev @ruifeng96150 @pedrohbtp @srivatsankrishnan @evilsocket
@MarvineGothic @jdossgollin @SyllogismRXS @rusu24edward @jbulow @Antymon @seheevic @justinkterry @edbeeching
@flodorner @KuKuXia @NeoExtended @PartiallyTyped @mmcenta @richardwu @kinalmehta @rolandgvc @tkelestemur @mloo3
@tirafesi @blurLake @koulakis
@tirafesi @blurLake @koulakis

View file

@ -10,7 +10,7 @@ import torch as th
import torch.nn as nn
import numpy as np
from stable_baselines3.common.preprocessing import preprocess_obs, is_image_space
from stable_baselines3.common.preprocessing import preprocess_obs, is_image_space, get_action_dim
from stable_baselines3.common.torch_layers import (FlattenExtractor, BaseFeaturesExtractor, create_mlp,
NatureCNN, MlpExtractor)
from stable_baselines3.common.utils import get_device, is_vectorized_observation
@ -644,6 +644,74 @@ class ActorCriticCnnPolicy(ActorCriticPolicy):
optimizer_kwargs)
class ContinuousCritic(BaseModel):
    """
    Critic network(s) for DDPG/SAC/TD3.

    Each network models the action-state value function (Q-value function).
    Unlike the A2C/PPO critics, the continuous action is an input here:
    it is concatenated with the extracted state features and the result is
    fed to a network that outputs a single value, Q(s, a).
    More recent algorithms such as SAC/TD3 build several such networks
    in order to obtain different estimates.

    The default is two critic networks, used to reduce overestimation
    through clipped Q-learning (cf. TD3 paper).

    :param observation_space: (gym.spaces.Space) Observation space
    :param action_space: (gym.spaces.Space) Action space
    :param net_arch: ([int]) Network architecture
    :param features_extractor: (nn.Module) Network to extract features
        (a CNN when using images, a nn.Flatten() layer otherwise)
    :param features_dim: (int) Number of features
    :param activation_fn: (Type[nn.Module]) Activation function
    :param normalize_images: (bool) Whether to normalize images or not,
         dividing by 255.0 (True by default)
    :param device: (Union[th.device, str]) Device on which the code should run.
    :param n_critics: (int) Number of critic networks to create.
    """

    def __init__(self, observation_space: gym.spaces.Space,
                 action_space: gym.spaces.Space,
                 net_arch: List[int],
                 features_extractor: nn.Module,
                 features_dim: int,
                 activation_fn: Type[nn.Module] = nn.ReLU,
                 normalize_images: bool = True,
                 device: Union[th.device, str] = 'auto',
                 n_critics: int = 2):
        super().__init__(observation_space, action_space,
                         features_extractor=features_extractor,
                         normalize_images=normalize_images,
                         device=device)

        # Every Q-network receives the extracted features concatenated with the action.
        q_input_dim = features_dim + get_action_dim(self.action_space)

        self.n_critics = n_critics
        self.q_networks = []
        for idx in range(n_critics):
            q_net = nn.Sequential(*create_mlp(q_input_dim, 1, net_arch, activation_fn))
            # A plain Python list does not register submodules, hence the explicit add_module call.
            self.add_module(f'qf{idx}', q_net)
            self.q_networks.append(q_net)

    def forward(self, obs: th.Tensor, actions: th.Tensor) -> Tuple[th.Tensor, ...]:
        """
        Compute the estimate of every Q-network for the given (obs, actions) pairs.

        :param obs: (th.Tensor) Observations
        :param actions: (th.Tensor) Actions, concatenated with the features along dim 1
        :return: (Tuple[th.Tensor, ...]) One output per critic network
        """
        # Learn the features extractor using the policy loss only:
        # no gradient is tracked through the feature extraction here.
        with th.no_grad():
            features = self.extract_features(obs)
        qvalue_input = th.cat([features, actions], dim=1)
        q_values = [q_net(qvalue_input) for q_net in self.q_networks]
        return tuple(q_values)

    def q1_forward(self, obs: th.Tensor, actions: th.Tensor) -> th.Tensor:
        """
        Only predict the Q-value using the first network.
        This reduces computation when not all the estimates are needed
        (e.g. when updating the policy in TD3).

        :param obs: (th.Tensor) Observations
        :param actions: (th.Tensor) Actions, concatenated with the features along dim 1
        :return: (th.Tensor) Output of the first critic network
        """
        with th.no_grad():
            features = self.extract_features(obs)
        first_q_net = self.q_networks[0]
        return first_q_net(th.cat([features, actions], dim=1))
def create_sde_features_extractor(features_dim: int,
sde_net_arch: List[int],
activation_fn: Type[nn.Module]) -> Tuple[nn.Sequential, int]:

View file

@ -5,7 +5,7 @@ import torch as th
import torch.nn as nn
from stable_baselines3.common.preprocessing import get_action_dim
from stable_baselines3.common.policies import BaseModel, BasePolicy, register_policy, create_sde_features_extractor
from stable_baselines3.common.policies import BasePolicy, register_policy, create_sde_features_extractor, ContinuousCritic
from stable_baselines3.common.torch_layers import create_mlp, NatureCNN, BaseFeaturesExtractor, FlattenExtractor
from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution
@ -179,54 +179,6 @@ class Actor(BasePolicy):
return self.forward(observation, deterministic)
class Critic(BaseModel):
    """
    Critic for SAC: two Q-value networks that share the same input.

    :param observation_space: (gym.spaces.Space) Observation space
    :param action_space: (gym.spaces.Space) Action space
    :param net_arch: ([int]) Network architecture
    :param features_extractor: (nn.Module) Network to extract features
        (a CNN when using images, a nn.Flatten() layer otherwise)
    :param features_dim: (int) Number of features
    :param activation_fn: (Type[nn.Module]) Activation function
    :param normalize_images: (bool) Whether to normalize images or not,
         dividing by 255.0 (True by default)
    :param device: (Union[th.device, str]) Device on which the code should run.
    """

    def __init__(self, observation_space: gym.spaces.Space,
                 action_space: gym.spaces.Space,
                 net_arch: List[int],
                 features_extractor: nn.Module,
                 features_dim: int,
                 activation_fn: Type[nn.Module] = nn.ReLU,
                 normalize_images: bool = True,
                 device: Union[th.device, str] = 'auto'):
        super(Critic, self).__init__(observation_space, action_space,
                                     features_extractor=features_extractor,
                                     normalize_images=normalize_images,
                                     device=device)

        # Both networks take the extracted features concatenated with the action.
        critic_input_dim = features_dim + get_action_dim(self.action_space)

        # q1 is built before q2, in that order.
        self.q1_net = nn.Sequential(*create_mlp(critic_input_dim, 1, net_arch, activation_fn))
        self.q2_net = nn.Sequential(*create_mlp(critic_input_dim, 1, net_arch, activation_fn))
        self.q_networks = [self.q1_net, self.q2_net]

    def forward(self, obs: th.Tensor, action: th.Tensor) -> List[th.Tensor]:
        """
        Evaluate every Q-network on the given (obs, action) pairs.

        :param obs: (th.Tensor) Observations
        :param action: (th.Tensor) Actions, concatenated with the features along dim 1
        :return: (List[th.Tensor]) One output per Q-network, in ``q_networks`` order
        """
        # Learn the features extractor using the policy loss only
        # this is much faster
        with th.no_grad():
            features = self.extract_features(obs)
        critic_input = th.cat([features, action], dim=1)

        q_values = []
        for q_net in self.q_networks:
            q_values.append(q_net(critic_input))
        return q_values
class SACPolicy(BasePolicy):
"""
Policy class (with both actor and critic) for SAC.
@ -255,6 +207,7 @@ class SACPolicy(BasePolicy):
``th.optim.Adam`` by default
:param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments,
excluding the learning rate, to pass to the optimizer
:param n_critics: (int) Number of critic networks to create.
"""
def __init__(self, observation_space: gym.spaces.Space,
@ -272,7 +225,8 @@ class SACPolicy(BasePolicy):
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
normalize_images: bool = True,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None):
optimizer_kwargs: Optional[Dict[str, Any]] = None,
n_critics: int = 2):
super(SACPolicy, self).__init__(observation_space, action_space,
device,
features_extractor_class,
@ -313,6 +267,9 @@ class SACPolicy(BasePolicy):
'clip_mean': clip_mean
}
self.actor_kwargs.update(sde_kwargs)
self.critic_kwargs = self.net_args.copy()
self.critic_kwargs.update({'n_critics': n_critics})
self.actor, self.actor_target = None, None
self.critic, self.critic_target = None, None
@ -345,6 +302,7 @@ class SACPolicy(BasePolicy):
sde_net_arch=self.actor_kwargs['sde_net_arch'],
use_expln=self.actor_kwargs['use_expln'],
clip_mean=self.actor_kwargs['clip_mean'],
n_critics=self.critic_kwargs['n_critics'],
lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone
optimizer_class=self.optimizer_class,
optimizer_kwargs=self.optimizer_kwargs,
@ -364,8 +322,8 @@ class SACPolicy(BasePolicy):
def make_actor(self) -> Actor:
return Actor(**self.actor_kwargs).to(self.device)
def make_critic(self) -> Critic:
return Critic(**self.net_args).to(self.device)
def make_critic(self) -> ContinuousCritic:
return ContinuousCritic(**self.critic_kwargs).to(self.device)
def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor:
return self._predict(obs, deterministic=deterministic)
@ -403,6 +361,7 @@ class CnnPolicy(SACPolicy):
``th.optim.Adam`` by default
:param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments,
excluding the learning rate, to pass to the optimizer
:param n_critics: (int) Number of critic networks to create.
"""
def __init__(self, observation_space: gym.spaces.Space,
@ -420,7 +379,8 @@ class CnnPolicy(SACPolicy):
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
normalize_images: bool = True,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None):
optimizer_kwargs: Optional[Dict[str, Any]] = None,
n_critics: int = 2):
super(CnnPolicy, self).__init__(observation_space,
action_space,
lr_schedule,
@ -436,7 +396,8 @@ class CnnPolicy(SACPolicy):
features_extractor_kwargs,
normalize_images,
optimizer_class,
optimizer_kwargs)
optimizer_kwargs,
n_critics)
register_policy("MlpPolicy", MlpPolicy)

View file

@ -118,7 +118,7 @@ class SAC(OffPolicyAlgorithm):
def _setup_model(self) -> None:
super(SAC, self)._setup_model()
self._create_aliases()
assert self.critic.n_critics == 2, "SAC only supports `n_critics=2` for now"
# Target entropy is used when learning the entropy coefficient
if self.target_entropy == 'auto':
# automatically set target entropy if needed

View file

@ -1,11 +1,11 @@
from typing import Optional, List, Tuple, Callable, Union, Type, Any, Dict
from typing import Optional, List, Callable, Union, Type, Any, Dict
import gym
import torch as th
import torch.nn as nn
from stable_baselines3.common.preprocessing import get_action_dim
from stable_baselines3.common.policies import BaseModel, BasePolicy, register_policy
from stable_baselines3.common.policies import BasePolicy, register_policy, ContinuousCritic
from stable_baselines3.common.torch_layers import create_mlp, NatureCNN, BaseFeaturesExtractor, FlattenExtractor
@ -71,57 +71,6 @@ class Actor(BasePolicy):
return self.forward(observation, deterministic=deterministic)
class Critic(BaseModel):
    """
    Critic network for TD3.
    It represents the action-state value function (Q-value function)
    and is made of two Q-networks, ``q1_net`` and ``q2_net``.

    :param observation_space: (gym.spaces.Space) Observation space
    :param action_space: (gym.spaces.Space) Action space
    :param net_arch: ([int]) Network architecture
    :param features_extractor: (nn.Module) Network to extract features
        (a CNN when using images, a nn.Flatten() layer otherwise)
    :param features_dim: (int) Number of features
    :param activation_fn: (Type[nn.Module]) Activation function
    :param normalize_images: (bool) Whether to normalize images or not,
         dividing by 255.0 (True by default)
    :param device: (Union[th.device, str]) Device on which the code should run.
    """

    def __init__(self, observation_space: gym.spaces.Space,
                 action_space: gym.spaces.Space,
                 net_arch: List[int],
                 features_extractor: nn.Module,
                 features_dim: int,
                 activation_fn: Type[nn.Module] = nn.ReLU,
                 normalize_images: bool = True,
                 device: Union[th.device, str] = 'auto'):
        super(Critic, self).__init__(observation_space, action_space,
                                     features_extractor=features_extractor,
                                     normalize_images=normalize_images,
                                     device=device)

        # Both networks take the extracted features concatenated with the action.
        critic_input_dim = features_dim + get_action_dim(self.action_space)

        # q1 is built before q2, in that order.
        self.q1_net = nn.Sequential(*create_mlp(critic_input_dim, 1, net_arch, activation_fn))
        self.q2_net = nn.Sequential(*create_mlp(critic_input_dim, 1, net_arch, activation_fn))

    def forward(self, obs: th.Tensor, actions: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
        """
        Evaluate both Q-networks on the given (obs, actions) pairs.

        :param obs: (th.Tensor) Observations
        :param actions: (th.Tensor) Actions, concatenated with the features along dim 1
        :return: (Tuple[th.Tensor, th.Tensor]) Outputs of ``q1_net`` and ``q2_net``
        """
        # Learn the features extractor using the policy loss only
        with th.no_grad():
            features = self.extract_features(obs)
        critic_input = th.cat([features, actions], dim=1)
        q1_value = self.q1_net(critic_input)
        q2_value = self.q2_net(critic_input)
        return q1_value, q2_value

    def q1_forward(self, obs: th.Tensor, actions: th.Tensor) -> th.Tensor:
        """
        Evaluate only the first Q-network.

        :param obs: (th.Tensor) Observations
        :param actions: (th.Tensor) Actions, concatenated with the features along dim 1
        :return: (th.Tensor) Output of ``q1_net``
        """
        with th.no_grad():
            features = self.extract_features(obs)
        critic_input = th.cat([features, actions], dim=1)
        return self.q1_net(critic_input)
class TD3Policy(BasePolicy):
"""
Policy class (with both actor and critic) for TD3.
@ -141,6 +90,7 @@ class TD3Policy(BasePolicy):
``th.optim.Adam`` by default
:param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments,
excluding the learning rate, to pass to the optimizer
:param n_critics: (int) Number of critic networks to create.
"""
def __init__(self, observation_space: gym.spaces.Space,
@ -153,7 +103,8 @@ class TD3Policy(BasePolicy):
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
normalize_images: bool = True,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None):
optimizer_kwargs: Optional[Dict[str, Any]] = None,
n_critics: int = 2):
super(TD3Policy, self).__init__(observation_space, action_space,
device,
features_extractor_class,
@ -185,6 +136,8 @@ class TD3Policy(BasePolicy):
'normalize_images': normalize_images,
'device': device
}
self.critic_kwargs = self.net_args.copy()
self.critic_kwargs.update({'n_critics': n_critics})
self.actor, self.actor_target = None, None
self.critic, self.critic_target = None, None
@ -208,6 +161,7 @@ class TD3Policy(BasePolicy):
data.update(dict(
net_arch=self.net_args['net_arch'],
activation_fn=self.net_args['activation_fn'],
n_critics=self.critic_kwargs['n_critics'],
lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone
optimizer_class=self.optimizer_class,
optimizer_kwargs=self.optimizer_kwargs,
@ -219,8 +173,8 @@ class TD3Policy(BasePolicy):
def make_actor(self) -> Actor:
return Actor(**self.net_args).to(self.device)
def make_critic(self) -> Critic:
return Critic(**self.net_args).to(self.device)
def make_critic(self) -> ContinuousCritic:
return ContinuousCritic(**self.critic_kwargs).to(self.device)
def forward(self, observation: th.Tensor, deterministic: bool = False):
return self._predict(observation, deterministic=deterministic)
@ -251,6 +205,7 @@ class CnnPolicy(TD3Policy):
``th.optim.Adam`` by default
:param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments,
excluding the learning rate, to pass to the optimizer
:param n_critics: (int) Number of critic networks to create.
"""
def __init__(self, observation_space: gym.spaces.Space,
@ -263,7 +218,8 @@ class CnnPolicy(TD3Policy):
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
normalize_images: bool = True,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None):
optimizer_kwargs: Optional[Dict[str, Any]] = None,
n_critics: int = 2):
super(CnnPolicy, self).__init__(observation_space,
action_space,
lr_schedule,
@ -274,7 +230,8 @@ class CnnPolicy(TD3Policy):
features_extractor_kwargs,
normalize_images,
optimizer_class,
optimizer_kwargs)
optimizer_kwargs,
n_critics)
register_policy("MlpPolicy", MlpPolicy)

View file

@ -96,6 +96,7 @@ class TD3(OffPolicyAlgorithm):
def _setup_model(self) -> None:
super(TD3, self)._setup_model()
self._create_aliases()
assert self.critic.n_critics == 2, "TD3 only supports `n_critics=2` for now"
def _create_aliases(self) -> None:
self.actor = self.policy.actor

View file

@ -1 +1 @@
0.8.0a2
0.8.0a3