Update docs (custom policy, type hints) (#167)

* Change import

* Update custom policy doc

* Re-enable sphinx_autodoc_typehints

* Update docker image

* Attempt to fix read the doc build error

* Add sphinx_autodoc_typehints to read the doc env

* Fix pip version

* Add full custom policy example

* Fix
This commit is contained in:
Antonin RAFFIN 2020-09-29 19:41:14 +02:00 committed by GitHub
parent 8b16324ba7
commit 2c924f52f5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
17 changed files with 206 additions and 37 deletions

View file

@ -1,4 +1,4 @@
image: stablebaselines/stable-baselines3-cpu:0.9.0a1
image: stablebaselines/stable-baselines3-cpu:0.9.0a2
type-check:
script:

View file

@ -4,13 +4,14 @@ channels:
- defaults
dependencies:
- cpuonly=1.0=0
- pip=20.0
- pip=20.2
- python=3.6
- pytorch=1.5.0=py3.6_cpu_0
- pip:
- gym==0.17.2
- gym>=0.17.2
- cloudpickle
- opencv-python-headless
- pandas
- numpy
- matplotlib
- sphinx_autodoc_typehints

View file

@ -72,7 +72,7 @@ release = __version__
# ones.
extensions = [
"sphinx.ext.autodoc",
# 'sphinx_autodoc_typehints',
"sphinx_autodoc_typehints",
"sphinx.ext.autosummary",
"sphinx.ext.mathjax",
"sphinx.ext.ifconfig",
@ -128,7 +128,7 @@ html_logo = "_static/img/logo.png"
def setup(app):
app.add_stylesheet("css/baselines_theme.css")
app.add_css_file("css/baselines_theme.css")
# Theme options are theme-specific and customize the look and feel of a theme

View file

@ -1,11 +1,15 @@
.. _custom_policy:
Custom Policy Network
---------------------
=====================
Stable Baselines3 provides policy networks for images (CnnPolicies)
and other type of input features (MlpPolicies).
Custom Policy Architecture
^^^^^^^^^^^^^^^^^^^^^^^^^^
One way of customising the policy network architecture is to pass arguments when creating the model,
using ``policy_kwargs`` parameter:
@ -41,6 +45,68 @@ You can also easily define a custom architecture for the policy (or value) netwo
``policy_kwargs`` is particularly useful when doing hyperparameter search.
Custom Feature Extractor
^^^^^^^^^^^^^^^^^^^^^^^^
If you want to have a custom feature extractor (e.g. custom CNN when using images), you can define class
that derives from ``BaseFeaturesExtractor`` and then pass it to the model when training.
.. code-block:: python
import gym
import torch as th
import torch.nn as nn
from stable_baselines3 import PPO
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
class CustomCNN(BaseFeaturesExtractor):
"""
:param observation_space: (gym.Space)
:param features_dim: (int) Number of features extracted.
This corresponds to the number of unit for the last layer.
"""
def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 256):
super(CustomCNN, self).__init__(observation_space, features_dim)
# We assume CxHxW images (channels first)
# Re-ordering will be done by pre-preprocessing or wrapper
n_input_channels = observation_space.shape[0]
self.cnn = nn.Sequential(
nn.Conv2d(n_input_channels, 32, kernel_size=8, stride=4, padding=0),
nn.ReLU(),
nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0),
nn.ReLU(),
nn.Flatten(),
)
# Compute shape by doing one forward pass
with th.no_grad():
n_flatten = self.cnn(
th.as_tensor(observation_space.sample()[None]).float()
).shape[1]
self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU())
def forward(self, observations: th.Tensor) -> th.Tensor:
return self.linear(self.cnn(observations))
policy_kwargs = dict(
features_extractor_class=CustomCNN,
features_extractor_kwargs=dict(features_dim=128),
)
model = PPO("CnnPolicy", "BreakoutNoFrameskip-v4", policy_kwargs=policy_kwargs, verbose=1)
model.learn(1000)
On-Policy Algorithms
^^^^^^^^^^^^^^^^^^^^
Shared Networks
---------------
The ``net_arch`` parameter of ``A2C`` and ``PPO`` policies allows to specify the amount and size of the hidden layers and how many
of them are shared between the policy network and the value network. It is assumed to be a list with the following
@ -99,7 +165,102 @@ Initially shared then diverging: ``[128, dict(vf=[256], pi=[16])]``
action value
Advanced Example
~~~~~~~~~~~~~~~~
If your task requires even more granular control over the policy architecture, you can redefine the policy directly.
If your task requires even more granular control over the policy/value architecture, you can redefine the policy directly:
**TODO**
.. code-block:: python
from typing import Callable, Dict, List, Optional, Tuple, Type, Union
import gym
import torch as th
from torch import nn
from stable_baselines3 import PPO
from stable_baselines3.common.policies import ActorCriticPolicy
class CustomNetwork(nn.Module):
"""
Custom network for policy and value function.
It receives as input the features extracted by the feature extractor.
:param feature_dim: dimension of the features extracted with the features_extractor (e.g. features from a CNN)
:param last_layer_dim_pi: (int) number of units for the last layer of the policy network
:param last_layer_dim_vf: (int) number of units for the last layer of the value network
"""
def __init__(
self,
feature_dim: int,
last_layer_dim_pi: int = 64,
last_layer_dim_vf: int = 64,
):
super(CustomNetwork, self).__init__()
# IMPORTANT:
# Save output dimensions, used to create the distributions
self.latent_dim_pi = last_layer_dim_pi
self.latent_dim_vf = last_layer_dim_vf
# Policy network
self.policy_net = nn.Sequential(
nn.Linear(feature_dim, last_layer_dim_pi), nn.ReLU()
)
# Value network
self.value_net = nn.Sequential(
nn.Linear(feature_dim, last_layer_dim_vf), nn.ReLU()
)
def forward(self, features: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
"""
:return: (th.Tensor, th.Tensor) latent_policy, latent_value of the specified network.
If all layers are shared, then ``latent_policy == latent_value``
"""
return self.policy_net(features), self.value_net(features)
class CustomActorCriticPolicy(ActorCriticPolicy):
def __init__(
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
lr_schedule: Callable[[float], float],
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
activation_fn: Type[nn.Module] = nn.Tanh,
*args,
**kwargs,
):
super(CustomActorCriticPolicy, self).__init__(
observation_space,
action_space,
lr_schedule,
net_arch,
activation_fn,
# Pass remaining arguments to base class
*args,
**kwargs,
)
# Disable orthogonal initialization
self.ortho_init = False
def _build_mlp_extractor(self) -> None:
self.mlp_extractor = CustomNetwork(self.features_dim)
model = PPO(CustomActorCriticPolicy, "CartPole-v1", verbose=1)
model.learn(5000)
.. TODO (see https://github.com/DLR-RM/stable-baselines3/issues/113)
.. Off-Policy Algorithms
.. ^^^^^^^^^^^^^^^^^^^^^
..
.. If you need a network architecture that is different for the actor and the critic when using ``SAC``, ``DDPG`` or ``TD3``,
.. you can easily redefine the actor class for instance.

View file

@ -12,6 +12,13 @@ Breaking Changes:
- Rename ``BaseClass.get_torch_variables`` -> ``BaseClass._get_torch_save_params`` and
``BaseClass.excluded_save_params`` -> ``BaseClass._excluded_save_params``
- Renamed saved items ``tensors`` to ``pytorch_variables`` for clarity
- ``make_atari_env``, ``make_vec_env`` and ``set_random_seed`` must be imported with (and not directly from ``stable_baselines3.common``):
.. code-block:: python
from stable_baselines3.common.cmd_util import make_atari_env, make_vec_env
from stable_baselines3.common.utils import set_random_seed
New Features:
^^^^^^^^^^^^^
@ -47,6 +54,8 @@ Others:
Documentation:
^^^^^^^^^^^^^^
- Added ``StopTrainingOnMaxEpisodes`` details and example (@xicocaio)
- Updated custom policy section (added custom feature extractor example)
- Re-enable ``sphinx_autodoc_typehints``

View file

@ -64,7 +64,7 @@ Example
from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise
env = gym.make('Pendulum-v0')
# The noise objects for TD3
n_actions = env.action_space.shape[-1]
action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))

View file

@ -103,7 +103,7 @@ setup(
# For spelling
"sphinxcontrib.spelling",
# Type hints support
# 'sphinx-autodoc-typehints'
"sphinx-autodoc-typehints",
],
"extra": [
# For render

View file

@ -1,2 +0,0 @@
from stable_baselines3.common.cmd_util import make_atari_env, make_vec_env
from stable_baselines3.common.utils import set_random_seed

View file

@ -1,5 +1,4 @@
import os
import typing
import warnings
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Union
@ -7,13 +6,10 @@ from typing import Any, Dict, List, Optional, Union
import gym
import numpy as np
from stable_baselines3.common import logger
from stable_baselines3.common import base_class, logger # pytype: disable=pyi-error
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, sync_envs_normalization
if typing.TYPE_CHECKING:
from stable_baselines3.common.base_class import BaseAlgorithm # pytype: disable=pyi-error
class BaseCallback(ABC):
"""
@ -25,7 +21,7 @@ class BaseCallback(ABC):
def __init__(self, verbose: int = 0):
super(BaseCallback, self).__init__()
# The RL model
self.model = None # type: Optional[BaseAlgorithm]
self.model = None # type: Optional[base_class.BaseAlgorithm]
# An alias for self.model.get_env(), the environment used for training
self.training_env = None # type: Union[gym.Env, VecEnv, None]
# Number of time the callback was called
@ -41,7 +37,7 @@ class BaseCallback(ABC):
self.parent = None # type: Optional[BaseCallback]
# Type hint as string to avoid circular import
def init_callback(self, model: "BaseAlgorithm") -> None:
def init_callback(self, model: "base_class.BaseAlgorithm") -> None:
"""
Initialize the callback by saving references to the
RL model and the training environment for convenience.
@ -137,7 +133,7 @@ class EventCallback(BaseCallback):
if callback is not None:
self.callback.parent = self
def init_callback(self, model: "BaseAlgorithm") -> None:
def init_callback(self, model: "base_class.BaseAlgorithm") -> None:
super(EventCallback, self).init_callback(model)
if self.callback is not None:
self.callback.init_callback(self.model)

View file

@ -6,7 +6,7 @@ from typing import Any, Dict, List, Optional, Tuple
import gym
import torch as th
from gym import spaces
from torch import nn as nn
from torch import nn
from torch.distributions import Bernoulli, Categorical, Normal
from stable_baselines3.common.preprocessing import get_action_dim

View file

@ -1,17 +1,14 @@
import typing
from typing import Callable, List, Optional, Tuple, Union
import gym
import numpy as np
from stable_baselines3.common import base_class
from stable_baselines3.common.vec_env import VecEnv
if typing.TYPE_CHECKING:
from stable_baselines3.common.base_class import BaseAlgorithm
def evaluate_policy(
model: "BaseAlgorithm",
model: "base_class.BaseAlgorithm",
env: Union[gym.Env, VecEnv],
n_eval_episodes: int = 10,
deterministic: bool = True,

View file

@ -8,7 +8,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
import gym
import numpy as np
import torch as th
from torch import nn as nn
from torch import nn
from stable_baselines3.common.distributions import (
BernoulliDistribution,
@ -439,6 +439,16 @@ class ActorCriticPolicy(BasePolicy):
assert isinstance(self.action_dist, StateDependentNoiseDistribution), "reset_noise() is only available when using gSDE"
self.action_dist.sample_weights(self.log_std, batch_size=n_envs)
def _build_mlp_extractor(self) -> None:
"""
Create the policy and value networks.
Part of the layers can be shared.
"""
# Note: If net_arch is None and some features extractor is used,
# net_arch here is an empty list and mlp_extractor does not
# really contain any layers (acts like an identity module).
self.mlp_extractor = MlpExtractor(self.features_dim, net_arch=self.net_arch, activation_fn=self.activation_fn)
def _build(self, lr_schedule: Callable[[float], float]) -> None:
"""
Create the networks and the optimizer.
@ -446,10 +456,7 @@ class ActorCriticPolicy(BasePolicy):
:param lr_schedule: (Callable) Learning rate schedule
lr_schedule(1) is the initial learning rate
"""
# Note: If net_arch is None and some features extractor is used,
# net_arch here is an empty list and mlp_extractor does not
# really contain any layers (acts like an identity module).
self.mlp_extractor = MlpExtractor(self.features_dim, net_arch=self.net_arch, activation_fn=self.activation_fn)
self._build_mlp_extractor()
latent_dim_pi = self.mlp_extractor.latent_dim_pi

View file

@ -3,7 +3,7 @@ from typing import Dict, List, Tuple, Type, Union
import gym
import torch as th
from torch import nn as nn
from torch import nn
from stable_baselines3.common.preprocessing import get_flattened_obs_dim, is_image_space
from stable_baselines3.common.utils import get_device

View file

@ -6,7 +6,7 @@ import gym
import numpy as np
import torch as th
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common import callbacks
from stable_baselines3.common.vec_env import VecEnv
GymEnv = Union[gym.Env, VecEnv]
@ -14,7 +14,7 @@ GymObs = Union[Tuple, Dict[str, Any], np.ndarray, int]
GymStepReturn = Tuple[GymObs, float, bool, Dict]
TensorDict = Dict[str, th.Tensor]
OptimizerStateDict = Dict[str, Any]
MaybeCallback = Union[None, Callable, List[BaseCallback], BaseCallback]
MaybeCallback = Union[None, Callable, List[callbacks.BaseCallback], callbacks.BaseCallback]
class RolloutBufferSamples(NamedTuple):

View file

@ -2,7 +2,7 @@ from typing import Any, Callable, Dict, List, Optional, Type
import gym
import torch as th
from torch import nn as nn
from torch import nn
from stable_baselines3.common.policies import BasePolicy, register_policy
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, FlattenExtractor, NatureCNN, create_mlp

View file

@ -2,7 +2,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Type
import gym
import torch as th
from torch import nn as nn
from torch import nn
from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution
from stable_baselines3.common.policies import BasePolicy, ContinuousCritic, create_sde_features_extractor, register_policy

View file

@ -2,7 +2,7 @@ from typing import Any, Callable, Dict, List, Optional, Type
import gym
import torch as th
from torch import nn as nn
from torch import nn
from stable_baselines3.common.policies import BasePolicy, ContinuousCritic, register_policy
from stable_baselines3.common.preprocessing import get_action_dim