From 2c924f52f51b6df0e4df0391d7be7c9f74cceb6c Mon Sep 17 00:00:00 2001
From: Antonin RAFFIN
Date: Tue, 29 Sep 2020 19:41:14 +0200
Subject: [PATCH] Update docs (custom policy, type hints) (#167)

* Change import
* Update custom policy doc
* Re-enable sphinx_autodoc_typehints
* Update docker image
* Attempt to fix read the doc build error
* Add sphinx_autodoc_typehints to read the doc env
* Fix pip version
* Add full custom policy example
* Fix
---
 .gitlab-ci.yml                            |   2 +-
 docs/conda_env.yml                        |   5 +-
 docs/conf.py                              |   4 +-
 docs/guide/custom_policy.rst              | 167 +++++++++++++++++++++-
 docs/misc/changelog.rst                   |   9 ++
 docs/modules/td3.rst                      |   2 +-
 setup.py                                  |   2 +-
 stable_baselines3/common/__init__.py      |   2 -
 stable_baselines3/common/callbacks.py     |  12 +-
 stable_baselines3/common/distributions.py |   2 +-
 stable_baselines3/common/evaluation.py    |   7 +-
 stable_baselines3/common/policies.py      |  17 ++-
 stable_baselines3/common/torch_layers.py  |   2 +-
 stable_baselines3/common/type_aliases.py  |   4 +-
 stable_baselines3/dqn/policies.py         |   2 +-
 stable_baselines3/sac/policies.py         |   2 +-
 stable_baselines3/td3/policies.py         |   2 +-
 17 files changed, 206 insertions(+), 37 deletions(-)

diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index 71826e9..4813df4 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -1,4 +1,4 @@
-image: stablebaselines/stable-baselines3-cpu:0.9.0a1
+image: stablebaselines/stable-baselines3-cpu:0.9.0a2
 
 type-check:
   script:
diff --git a/docs/conda_env.yml b/docs/conda_env.yml
index fcd09c9..9ea054e 100644
--- a/docs/conda_env.yml
+++ b/docs/conda_env.yml
@@ -4,13 +4,14 @@ channels:
   - defaults
 dependencies:
   - cpuonly=1.0=0
-  - pip=20.0
+  - pip=20.2
   - python=3.6
   - pytorch=1.5.0=py3.6_cpu_0
   - pip:
-    - gym==0.17.2
+    - gym>=0.17.2
     - cloudpickle
     - opencv-python-headless
     - pandas
     - numpy
     - matplotlib
+    - sphinx_autodoc_typehints
diff --git a/docs/conf.py b/docs/conf.py
index 320834c..088f8a0 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -72,7 +72,7 @@ release = __version__
 # ones.
 extensions = [
     "sphinx.ext.autodoc",
-    # 'sphinx_autodoc_typehints',
+    "sphinx_autodoc_typehints",
     "sphinx.ext.autosummary",
     "sphinx.ext.mathjax",
     "sphinx.ext.ifconfig",
@@ -128,7 +128,7 @@ html_logo = "_static/img/logo.png"
 
 
 def setup(app):
-    app.add_stylesheet("css/baselines_theme.css")
+    app.add_css_file("css/baselines_theme.css")
 
 
 # Theme options are theme-specific and customize the look and feel of a theme
diff --git a/docs/guide/custom_policy.rst b/docs/guide/custom_policy.rst
index 755baee..85ae801 100644
--- a/docs/guide/custom_policy.rst
+++ b/docs/guide/custom_policy.rst
@@ -1,11 +1,15 @@
 .. _custom_policy:
 
 Custom Policy Network
----------------------
+=====================
 
 Stable Baselines3 provides policy networks for images (CnnPolicies)
 and other type of input features (MlpPolicies).
 
+
+Custom Policy Architecture
+^^^^^^^^^^^^^^^^^^^^^^^^^^
+
 One way of customising the policy network architecture is to pass arguments when creating the model,
 using ``policy_kwargs`` parameter:
 
@@ -41,6 +45,68 @@ You can also easily define a custom architecture for the policy (or value) netwo
 
 ``policy_kwargs`` is particularly useful when doing hyperparameter search.
 
+Custom Feature Extractor
+^^^^^^^^^^^^^^^^^^^^^^^^
+
+If you want to have a custom feature extractor (e.g. custom CNN when using images), you can define a class
+that derives from ``BaseFeaturesExtractor`` and then pass it to the model when training.
+
+.. code-block:: python
+
+    import gym
+    import torch as th
+    import torch.nn as nn
+
+    from stable_baselines3 import PPO
+    from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
+
+
+    class CustomCNN(BaseFeaturesExtractor):
+        """
+        :param observation_space: (gym.Space)
+        :param features_dim: (int) Number of features extracted.
+            This corresponds to the number of units for the last layer.
+        """
+
+        def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 256):
+            super(CustomCNN, self).__init__(observation_space, features_dim)
+            # We assume CxHxW images (channels first)
+            # Re-ordering will be done by preprocessing or wrapper
+            n_input_channels = observation_space.shape[0]
+            self.cnn = nn.Sequential(
+                nn.Conv2d(n_input_channels, 32, kernel_size=8, stride=4, padding=0),
+                nn.ReLU(),
+                nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0),
+                nn.ReLU(),
+                nn.Flatten(),
+            )
+
+            # Compute shape by doing one forward pass
+            with th.no_grad():
+                n_flatten = self.cnn(
+                    th.as_tensor(observation_space.sample()[None]).float()
+                ).shape[1]
+
+            self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU())
+
+        def forward(self, observations: th.Tensor) -> th.Tensor:
+            return self.linear(self.cnn(observations))
+
+
+    policy_kwargs = dict(
+        features_extractor_class=CustomCNN,
+        features_extractor_kwargs=dict(features_dim=128),
+    )
+    model = PPO("CnnPolicy", "BreakoutNoFrameskip-v4", policy_kwargs=policy_kwargs, verbose=1)
+    model.learn(1000)
+
+
+On-Policy Algorithms
+^^^^^^^^^^^^^^^^^^^^
+
+Shared Networks
+---------------
 
 The ``net_arch`` parameter of ``A2C`` and ``PPO`` policies allows to specify the amount and size of the hidden layers and how
 many of them are shared between the policy network and the value network. It is assumed to be a list with the following
@@ -99,7 +165,102 @@ Initially shared then diverging: ``[128, dict(vf=[256], pi=[16])]``
                         action    value
 
 
+Advanced Example
+~~~~~~~~~~~~~~~~
 
-If your task requires even more granular control over the policy architecture, you can redefine the policy directly.
+If your task requires even more granular control over the policy/value architecture, you can redefine the policy directly:
 
-**TODO**
+
+.. code-block:: python
+
+    from typing import Callable, Dict, List, Optional, Tuple, Type, Union
+
+    import gym
+    import torch as th
+    from torch import nn
+
+    from stable_baselines3 import PPO
+    from stable_baselines3.common.policies import ActorCriticPolicy
+
+
+    class CustomNetwork(nn.Module):
+        """
+        Custom network for policy and value function.
+        It receives as input the features extracted by the feature extractor.
+
+        :param feature_dim: dimension of the features extracted with the features_extractor (e.g. features from a CNN)
+        :param last_layer_dim_pi: (int) number of units for the last layer of the policy network
+        :param last_layer_dim_vf: (int) number of units for the last layer of the value network
+        """
+
+        def __init__(
+            self,
+            feature_dim: int,
+            last_layer_dim_pi: int = 64,
+            last_layer_dim_vf: int = 64,
+        ):
+            super(CustomNetwork, self).__init__()
+
+            # IMPORTANT:
+            # Save output dimensions, used to create the distributions
+            self.latent_dim_pi = last_layer_dim_pi
+            self.latent_dim_vf = last_layer_dim_vf
+
+            # Policy network
+            self.policy_net = nn.Sequential(
+                nn.Linear(feature_dim, last_layer_dim_pi), nn.ReLU()
+            )
+            # Value network
+            self.value_net = nn.Sequential(
+                nn.Linear(feature_dim, last_layer_dim_vf), nn.ReLU()
+            )
+
+        def forward(self, features: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
+            """
+            :return: (th.Tensor, th.Tensor) latent_policy, latent_value of the specified network.
+                If all layers are shared, then ``latent_policy == latent_value``
+            """
+            return self.policy_net(features), self.value_net(features)
+
+
+    class CustomActorCriticPolicy(ActorCriticPolicy):
+        def __init__(
+            self,
+            observation_space: gym.spaces.Space,
+            action_space: gym.spaces.Space,
+            lr_schedule: Callable[[float], float],
+            net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
+            activation_fn: Type[nn.Module] = nn.Tanh,
+            *args,
+            **kwargs,
+        ):
+
+            super(CustomActorCriticPolicy, self).__init__(
+                observation_space,
+                action_space,
+                lr_schedule,
+                net_arch,
+                activation_fn,
+                # Pass remaining arguments to base class
+                *args,
+                **kwargs,
+            )
+            # Disable orthogonal initialization
+            self.ortho_init = False
+
+        def _build_mlp_extractor(self) -> None:
+            self.mlp_extractor = CustomNetwork(self.features_dim)
+
+
+    model = PPO(CustomActorCriticPolicy, "CartPole-v1", verbose=1)
+    model.learn(5000)
+
+
+.. TODO (see https://github.com/DLR-RM/stable-baselines3/issues/113)
+.. Off-Policy Algorithms
+.. ^^^^^^^^^^^^^^^^^^^^^
+..
+.. If you need a network architecture that is different for the actor and the critic when using ``SAC``, ``DDPG`` or ``TD3``,
+.. you can easily redefine the actor class for instance.
diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst
index 74abf17..1112255 100644
--- a/docs/misc/changelog.rst
+++ b/docs/misc/changelog.rst
@@ -12,6 +12,13 @@ Breaking Changes:
 - Rename ``BaseClass.get_torch_variables`` -> ``BaseClass._get_torch_save_params`` and
   ``BaseClass.excluded_save_params`` -> ``BaseClass._excluded_save_params``
 - Renamed saved items ``tensors`` to ``pytorch_variables`` for clarity
+- ``make_atari_env``, ``make_vec_env`` and ``set_random_seed`` must be imported from their submodules (and not directly from ``stable_baselines3.common``):
+
+.. code-block:: python
+
+    from stable_baselines3.common.cmd_util import make_atari_env, make_vec_env
+    from stable_baselines3.common.utils import set_random_seed
+
 
 New Features:
 ^^^^^^^^^^^^^
@@ -47,6 +54,8 @@ Others:
 Documentation:
 ^^^^^^^^^^^^^^
 - Added ``StopTrainingOnMaxEpisodes`` details and example (@xicocaio)
+- Updated custom policy section (added custom feature extractor example)
+- Re-enabled ``sphinx_autodoc_typehints``
diff --git a/docs/modules/td3.rst b/docs/modules/td3.rst
index 8902453..912fc1b 100644
--- a/docs/modules/td3.rst
+++ b/docs/modules/td3.rst
@@ -64,7 +64,7 @@ Example
   from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise
 
   env = gym.make('Pendulum-v0')
-  
+
   # The noise objects for TD3
   n_actions = env.action_space.shape[-1]
   action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))
diff --git a/setup.py b/setup.py
index 8401c7c..72146ad 100644
--- a/setup.py
+++ b/setup.py
@@ -103,7 +103,7 @@ setup(
             # For spelling
             "sphinxcontrib.spelling",
             # Type hints support
-            # 'sphinx-autodoc-typehints'
+            "sphinx-autodoc-typehints",
         ],
         "extra": [
             # For render
diff --git a/stable_baselines3/common/__init__.py b/stable_baselines3/common/__init__.py
index 275e3ad..e69de29 100644
--- a/stable_baselines3/common/__init__.py
+++ b/stable_baselines3/common/__init__.py
@@ -1,2 +0,0 @@
-from stable_baselines3.common.cmd_util import make_atari_env, make_vec_env
-from stable_baselines3.common.utils import set_random_seed
diff --git a/stable_baselines3/common/callbacks.py b/stable_baselines3/common/callbacks.py
index fe26199..d814fa6 100644
--- a/stable_baselines3/common/callbacks.py
+++ b/stable_baselines3/common/callbacks.py
@@ -1,5 +1,4 @@
 import os
-import typing
 import warnings
 from abc import ABC, abstractmethod
 from typing import Any, Dict, List, Optional, Union
@@ -7,13 +6,10 @@ from typing import Any, Dict, List, Optional, Union
 import gym
 import numpy as np
 
-from stable_baselines3.common import logger
+from stable_baselines3.common import base_class, logger  # pytype: disable=pyi-error
 from stable_baselines3.common.evaluation import evaluate_policy
 from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, sync_envs_normalization
 
-if typing.TYPE_CHECKING:
-    from stable_baselines3.common.base_class import BaseAlgorithm  # pytype: disable=pyi-error
-
 
 class BaseCallback(ABC):
     """
@@ -25,7 +21,7 @@ class BaseCallback(ABC):
     def __init__(self, verbose: int = 0):
         super(BaseCallback, self).__init__()
         # The RL model
-        self.model = None  # type: Optional[BaseAlgorithm]
+        self.model = None  # type: Optional[base_class.BaseAlgorithm]
         # An alias for self.model.get_env(), the environment used for training
         self.training_env = None  # type: Union[gym.Env, VecEnv, None]
         # Number of time the callback was called
@@ -41,7 +37,7 @@ class BaseCallback(ABC):
         self.parent = None  # type: Optional[BaseCallback]
 
     # Type hint as string to avoid circular import
-    def init_callback(self, model: "BaseAlgorithm") -> None:
+    def init_callback(self, model: "base_class.BaseAlgorithm") -> None:
         """
         Initialize the callback by saving references to the
         RL model and the training environment for convenience.
@@ -137,7 +133,7 @@ class EventCallback(BaseCallback):
         if callback is not None:
             self.callback.parent = self
 
-    def init_callback(self, model: "BaseAlgorithm") -> None:
+    def init_callback(self, model: "base_class.BaseAlgorithm") -> None:
         super(EventCallback, self).init_callback(model)
         if self.callback is not None:
             self.callback.init_callback(self.model)
diff --git a/stable_baselines3/common/distributions.py b/stable_baselines3/common/distributions.py
index f46691d..4470880 100644
--- a/stable_baselines3/common/distributions.py
+++ b/stable_baselines3/common/distributions.py
@@ -6,7 +6,7 @@ from typing import Any, Dict, List, Optional, Tuple
 import gym
 import torch as th
 from gym import spaces
-from torch import nn as nn
+from torch import nn
 from torch.distributions import Bernoulli, Categorical, Normal
 
 from stable_baselines3.common.preprocessing import get_action_dim
diff --git a/stable_baselines3/common/evaluation.py b/stable_baselines3/common/evaluation.py
index 0822c1c..327300d 100644
--- a/stable_baselines3/common/evaluation.py
+++ b/stable_baselines3/common/evaluation.py
@@ -1,17 +1,14 @@
-import typing
 from typing import Callable, List, Optional, Tuple, Union
 
 import gym
 import numpy as np
 
+from stable_baselines3.common import base_class
 from stable_baselines3.common.vec_env import VecEnv
 
-if typing.TYPE_CHECKING:
-    from stable_baselines3.common.base_class import BaseAlgorithm
-
 
 def evaluate_policy(
-    model: "BaseAlgorithm",
+    model: "base_class.BaseAlgorithm",
     env: Union[gym.Env, VecEnv],
     n_eval_episodes: int = 10,
     deterministic: bool = True,
diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py
index 0478342..8b5357c 100644
--- a/stable_baselines3/common/policies.py
+++ b/stable_baselines3/common/policies.py
@@ -8,7 +8,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
 import gym
 import numpy as np
 import torch as th
-from torch import nn as nn
+from torch import nn
 
 from stable_baselines3.common.distributions import (
     BernoulliDistribution,
@@ -439,6 +439,16 @@ class ActorCriticPolicy(BasePolicy):
         assert isinstance(self.action_dist, StateDependentNoiseDistribution), "reset_noise() is only available when using gSDE"
         self.action_dist.sample_weights(self.log_std, batch_size=n_envs)
 
+    def _build_mlp_extractor(self) -> None:
+        """
+        Create the policy and value networks.
+        Part of the layers can be shared.
+        """
+        # Note: If net_arch is None and some features extractor is used,
+        # net_arch here is an empty list and mlp_extractor does not
+        # really contain any layers (acts like an identity module).
+        self.mlp_extractor = MlpExtractor(self.features_dim, net_arch=self.net_arch, activation_fn=self.activation_fn)
+
     def _build(self, lr_schedule: Callable[[float], float]) -> None:
         """
         Create the networks and the optimizer.
@@ -446,10 +456,7 @@ class ActorCriticPolicy(BasePolicy):
         :param lr_schedule: (Callable) Learning rate schedule
             lr_schedule(1) is the initial learning rate
         """
-        # Note: If net_arch is None and some features extractor is used,
-        # net_arch here is an empty list and mlp_extractor does not
-        # really contain any layers (acts like an identity module).
-        self.mlp_extractor = MlpExtractor(self.features_dim, net_arch=self.net_arch, activation_fn=self.activation_fn)
+        self._build_mlp_extractor()
 
         latent_dim_pi = self.mlp_extractor.latent_dim_pi
diff --git a/stable_baselines3/common/torch_layers.py b/stable_baselines3/common/torch_layers.py
index 9429a86..07449e9 100644
--- a/stable_baselines3/common/torch_layers.py
+++ b/stable_baselines3/common/torch_layers.py
@@ -3,7 +3,7 @@ from typing import Dict, List, Tuple, Type, Union
 
 import gym
 import torch as th
-from torch import nn as nn
+from torch import nn
 
 from stable_baselines3.common.preprocessing import get_flattened_obs_dim, is_image_space
 from stable_baselines3.common.utils import get_device
diff --git a/stable_baselines3/common/type_aliases.py b/stable_baselines3/common/type_aliases.py
index e16f435..0d95d00 100644
--- a/stable_baselines3/common/type_aliases.py
+++ b/stable_baselines3/common/type_aliases.py
@@ -6,7 +6,7 @@ import gym
 import numpy as np
 import torch as th
 
-from stable_baselines3.common.callbacks import BaseCallback
+from stable_baselines3.common import callbacks
 from stable_baselines3.common.vec_env import VecEnv
 
 GymEnv = Union[gym.Env, VecEnv]
@@ -14,7 +14,7 @@ GymObs = Union[Tuple, Dict[str, Any], np.ndarray, int]
 GymStepReturn = Tuple[GymObs, float, bool, Dict]
 TensorDict = Dict[str, th.Tensor]
 OptimizerStateDict = Dict[str, Any]
-MaybeCallback = Union[None, Callable, List[BaseCallback], BaseCallback]
+MaybeCallback = Union[None, Callable, List[callbacks.BaseCallback], callbacks.BaseCallback]
 
 
 class RolloutBufferSamples(NamedTuple):
diff --git a/stable_baselines3/dqn/policies.py b/stable_baselines3/dqn/policies.py
index ebbcd34..f8ec737 100644
--- a/stable_baselines3/dqn/policies.py
+++ b/stable_baselines3/dqn/policies.py
@@ -2,7 +2,7 @@ from typing import Any, Callable, Dict, List, Optional, Type
 
 import gym
 import torch as th
-from torch import nn as nn
+from torch import nn
 
 from stable_baselines3.common.policies import BasePolicy, register_policy
 from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, FlattenExtractor, NatureCNN, create_mlp
diff --git a/stable_baselines3/sac/policies.py b/stable_baselines3/sac/policies.py
index 843ee06..6dd98e1 100644
--- a/stable_baselines3/sac/policies.py
+++ b/stable_baselines3/sac/policies.py
@@ -2,7 +2,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Type
 
 import gym
 import torch as th
-from torch import nn as nn
+from torch import nn
 
 from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution
 from stable_baselines3.common.policies import BasePolicy, ContinuousCritic, create_sde_features_extractor, register_policy
diff --git a/stable_baselines3/td3/policies.py b/stable_baselines3/td3/policies.py
index 2a60043..c5cdd8e 100644
--- a/stable_baselines3/td3/policies.py
+++ b/stable_baselines3/td3/policies.py
@@ -2,7 +2,7 @@ from typing import Any, Callable, Dict, List, Optional, Type
 
 import gym
 import torch as th
-from torch import nn as nn
+from torch import nn
 
 from stable_baselines3.common.policies import BasePolicy, ContinuousCritic, register_policy
 from stable_baselines3.common.preprocessing import get_action_dim
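The ``Shared Networks`` section patched above describes the ``net_arch`` format (e.g. initially shared then diverging: ``[128, dict(vf=[256], pi=[16])]``) but does not show how to pass it. A minimal sketch of that usage, assuming the SB3 0.9-era ``net_arch`` API documented in the patch; the environment and layer sizes are illustrative, not taken from this commit:

.. code-block:: python

    from stable_baselines3 import PPO

    # One shared 128-unit layer, then the network diverges:
    # the value branch gets a 256-unit layer, the policy branch a 16-unit layer
    policy_kwargs = dict(net_arch=[128, dict(vf=[256], pi=[16])])
    model = PPO("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs, verbose=1)
    model.learn(1000)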
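The recurring ``model: "BaseAlgorithm"`` to ``model: "base_class.BaseAlgorithm"`` change in ``callbacks.py``, ``evaluation.py`` and ``type_aliases.py`` replaces the removed ``typing.TYPE_CHECKING`` guards: importing the submodule and keeping the annotation as a string sidesteps the circular import while leaving the name resolvable for ``sphinx_autodoc_typehints``. A minimal single-file sketch of the forward-reference mechanism this relies on (names hypothetical, not from the patch):

.. code-block:: python

    def init_callback(model: "BaseAlgorithm") -> None:
        # At definition time the annotation is only the string "BaseAlgorithm";
        # nothing is looked up until a tool (e.g. Sphinx) introspects it later,
        # by which point the module graph has finished importing.
        model.setup()


    class BaseAlgorithm:  # defined after use, which is fine for a string annotation
        def setup(self) -> None:
            print("callback initialized")


    init_callback(BaseAlgorithm())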