Mirror of https://github.com/saymrwulf/stable-baselines3.git, synced 2026-05-14 20:58:03 +00:00
Update docs (custom policy, type hints) (#167)
* Change import
* Update custom policy doc
* Re-enable sphinx_autodoc_typehints
* Update docker image
* Attempt to fix read the doc build error
* Add sphinx_autodoc_typehints to read the doc env
* Fix pip version
* Add full custom policy example
* Fix
parent 8b16324ba7
commit 2c924f52f5
17 changed files with 206 additions and 37 deletions
@@ -1,4 +1,4 @@
-image: stablebaselines/stable-baselines3-cpu:0.9.0a1
+image: stablebaselines/stable-baselines3-cpu:0.9.0a2

 type-check:
   script:
@@ -4,13 +4,14 @@ channels:
   - defaults
 dependencies:
   - cpuonly=1.0=0
-  - pip=20.0
+  - pip=20.2
   - python=3.6
   - pytorch=1.5.0=py3.6_cpu_0
   - pip:
-    - gym==0.17.2
+    - gym>=0.17.2
     - cloudpickle
     - opencv-python-headless
     - pandas
     - numpy
     - matplotlib
+    - sphinx_autodoc_typehints
@@ -72,7 +72,7 @@ release = __version__
 # ones.
 extensions = [
     "sphinx.ext.autodoc",
-    # 'sphinx_autodoc_typehints',
+    "sphinx_autodoc_typehints",
     "sphinx.ext.autosummary",
     "sphinx.ext.mathjax",
     "sphinx.ext.ifconfig",

@@ -128,7 +128,7 @@ html_logo = "_static/img/logo.png"


 def setup(app):
-    app.add_stylesheet("css/baselines_theme.css")
+    app.add_css_file("css/baselines_theme.css")


 # Theme options are theme-specific and customize the look and feel of a theme
@@ -1,11 +1,15 @@
 .. _custom_policy:

 Custom Policy Network
----------------------
+=====================

 Stable Baselines3 provides policy networks for images (CnnPolicies)
 and other types of input features (MlpPolicies).


+Custom Policy Architecture
+^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+
 One way of customising the policy network architecture is to pass arguments when creating the model,
 using the ``policy_kwargs`` parameter:
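The code block that this colon announces falls outside the hunk's context. As a rough sketch (not part of this commit), such a call in the 0.9-era API might look like:

.. code-block:: python

    import torch as th

    from stable_baselines3 import PPO

    # Hypothetical example: separate 32-unit policy (pi) and value (vf)
    # branches with ReLU activations, configured without subclassing the policy.
    policy_kwargs = dict(activation_fn=th.nn.ReLU, net_arch=[dict(pi=[32, 32], vf=[32, 32])])
    model = PPO("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs, verbose=1)
    model.learn(1000)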
@@ -41,6 +45,68 @@ You can also easily define a custom architecture for the policy (or value) network:
 ``policy_kwargs`` is particularly useful when doing hyperparameter search.


+Custom Feature Extractor
+^^^^^^^^^^^^^^^^^^^^^^^^
+
+If you want to have a custom feature extractor (e.g. a custom CNN when using images), you can define a class
+that derives from ``BaseFeaturesExtractor`` and then pass it to the model when training.
+
+.. code-block:: python
+
+    import gym
+    import torch as th
+    import torch.nn as nn
+
+    from stable_baselines3 import PPO
+    from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
+
+
+    class CustomCNN(BaseFeaturesExtractor):
+        """
+        :param observation_space: (gym.Space)
+        :param features_dim: (int) Number of features extracted.
+            This corresponds to the number of units for the last layer.
+        """
+
+        def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 256):
+            super(CustomCNN, self).__init__(observation_space, features_dim)
+            # We assume CxHxW images (channels first)
+            # Re-ordering will be done by preprocessing or a wrapper
+            n_input_channels = observation_space.shape[0]
+            self.cnn = nn.Sequential(
+                nn.Conv2d(n_input_channels, 32, kernel_size=8, stride=4, padding=0),
+                nn.ReLU(),
+                nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0),
+                nn.ReLU(),
+                nn.Flatten(),
+            )
+
+            # Compute shape by doing one forward pass
+            with th.no_grad():
+                n_flatten = self.cnn(
+                    th.as_tensor(observation_space.sample()[None]).float()
+                ).shape[1]
+
+            self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU())
+
+        def forward(self, observations: th.Tensor) -> th.Tensor:
+            return self.linear(self.cnn(observations))
+
+
+    policy_kwargs = dict(
+        features_extractor_class=CustomCNN,
+        features_extractor_kwargs=dict(features_dim=128),
+    )
+    model = PPO("CnnPolicy", "BreakoutNoFrameskip-v4", policy_kwargs=policy_kwargs, verbose=1)
+    model.learn(1000)
+
+
 On-Policy Algorithms
 ^^^^^^^^^^^^^^^^^^^^

 Shared Networks
 ---------------

 The ``net_arch`` parameter of ``A2C`` and ``PPO`` policies lets you specify the number and size of the hidden layers and how many
 of them are shared between the policy network and the value network. It is assumed to be a list with the following
@@ -99,7 +165,102 @@ Initially shared then diverging: ``[128, dict(vf=[256], pi=[16])]``
           action                value

 Advanced Example
 ~~~~~~~~~~~~~~~~

-If your task requires even more granular control over the policy architecture, you can redefine the policy directly.
-
-**TODO**
+If your task requires even more granular control over the policy/value architecture, you can redefine the policy directly:
+
+.. code-block:: python
+
+    from typing import Callable, Dict, List, Optional, Tuple, Type, Union
+
+    import gym
+    import torch as th
+    from torch import nn
+
+    from stable_baselines3 import PPO
+    from stable_baselines3.common.policies import ActorCriticPolicy
+
+
+    class CustomNetwork(nn.Module):
+        """
+        Custom network for policy and value function.
+        It receives as input the features extracted by the feature extractor.
+
+        :param feature_dim: dimension of the features extracted with the features_extractor (e.g. features from a CNN)
+        :param last_layer_dim_pi: (int) number of units for the last layer of the policy network
+        :param last_layer_dim_vf: (int) number of units for the last layer of the value network
+        """
+
+        def __init__(
+            self,
+            feature_dim: int,
+            last_layer_dim_pi: int = 64,
+            last_layer_dim_vf: int = 64,
+        ):
+            super(CustomNetwork, self).__init__()
+
+            # IMPORTANT:
+            # Save output dimensions, used to create the distributions
+            self.latent_dim_pi = last_layer_dim_pi
+            self.latent_dim_vf = last_layer_dim_vf
+
+            # Policy network
+            self.policy_net = nn.Sequential(
+                nn.Linear(feature_dim, last_layer_dim_pi), nn.ReLU()
+            )
+            # Value network
+            self.value_net = nn.Sequential(
+                nn.Linear(feature_dim, last_layer_dim_vf), nn.ReLU()
+            )
+
+        def forward(self, features: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
+            """
+            :return: (th.Tensor, th.Tensor) latent_policy, latent_value of the specified network.
+                If all layers are shared, then ``latent_policy == latent_value``
+            """
+            return self.policy_net(features), self.value_net(features)
+
+
+    class CustomActorCriticPolicy(ActorCriticPolicy):
+        def __init__(
+            self,
+            observation_space: gym.spaces.Space,
+            action_space: gym.spaces.Space,
+            lr_schedule: Callable[[float], float],
+            net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
+            activation_fn: Type[nn.Module] = nn.Tanh,
+            *args,
+            **kwargs,
+        ):
+
+            super(CustomActorCriticPolicy, self).__init__(
+                observation_space,
+                action_space,
+                lr_schedule,
+                net_arch,
+                activation_fn,
+                # Pass remaining arguments to base class
+                *args,
+                **kwargs,
+            )
+            # Disable orthogonal initialization
+            self.ortho_init = False
+
+        def _build_mlp_extractor(self) -> None:
+            self.mlp_extractor = CustomNetwork(self.features_dim)
+
+
+    model = PPO(CustomActorCriticPolicy, "CartPole-v1", verbose=1)
+    model.learn(5000)
+
+
+.. TODO (see https://github.com/DLR-RM/stable-baselines3/issues/113)
+.. Off-Policy Algorithms
+.. ^^^^^^^^^^^^^^^^^^^^^
+..
+.. If you need a network architecture that is different for the actor and the critic when using ``SAC``, ``DDPG`` or ``TD3``,
+.. you can easily redefine the actor class, for instance.
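As a quick sketch (not part of this commit), the "initially shared then diverging" layout quoted in the hunk header above, ``[128, dict(vf=[256], pi=[16])]``, can be exercised directly, assuming the 0.9-era list syntax for ``net_arch``:

.. code-block:: python

    from stable_baselines3 import A2C

    # One shared 128-unit layer, then a 256-unit value branch
    # and a 16-unit policy branch.
    policy_kwargs = dict(net_arch=[128, dict(vf=[256], pi=[16])])
    model = A2C("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs, verbose=1)
    model.learn(1000)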
@@ -12,6 +12,13 @@ Breaking Changes:
 - Rename ``BaseClass.get_torch_variables`` -> ``BaseClass._get_torch_save_params`` and
   ``BaseClass.excluded_save_params`` -> ``BaseClass._excluded_save_params``
 - Renamed saved items ``tensors`` to ``pytorch_variables`` for clarity
+- ``make_atari_env``, ``make_vec_env`` and ``set_random_seed`` must be imported from their submodules (and not directly from ``stable_baselines3.common``):
+
+.. code-block:: python
+
+    from stable_baselines3.common.cmd_util import make_atari_env, make_vec_env
+    from stable_baselines3.common.utils import set_random_seed
+

 New Features:
 ^^^^^^^^^^^^^

@@ -47,6 +54,8 @@ Others:

 Documentation:
 ^^^^^^^^^^^^^^
 - Added ``StopTrainingOnMaxEpisodes`` details and example (@xicocaio)
+- Updated custom policy section (added custom feature extractor example)
+- Re-enable ``sphinx_autodoc_typehints``

@@ -64,7 +64,7 @@ Example
     from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise

     env = gym.make('Pendulum-v0')

     # The noise objects for TD3
     n_actions = env.action_space.shape[-1]
     action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))
setup.py

@@ -103,7 +103,7 @@ setup(
         # For spelling
         "sphinxcontrib.spelling",
         # Type hints support
-        # 'sphinx-autodoc-typehints'
+        "sphinx-autodoc-typehints",
     ],
     "extra": [
         # For render
@@ -1,2 +0,0 @@
-from stable_baselines3.common.cmd_util import make_atari_env, make_vec_env
-from stable_baselines3.common.utils import set_random_seed
@@ -1,5 +1,4 @@
 import os
-import typing
 import warnings
 from abc import ABC, abstractmethod
 from typing import Any, Dict, List, Optional, Union

@@ -7,13 +6,10 @@ from typing import Any, Dict, List, Optional, Union
 import gym
 import numpy as np

-from stable_baselines3.common import logger
+from stable_baselines3.common import base_class, logger  # pytype: disable=pyi-error
 from stable_baselines3.common.evaluation import evaluate_policy
 from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, sync_envs_normalization

-if typing.TYPE_CHECKING:
-    from stable_baselines3.common.base_class import BaseAlgorithm  # pytype: disable=pyi-error
-

 class BaseCallback(ABC):
     """

@@ -25,7 +21,7 @@ class BaseCallback(ABC):
     def __init__(self, verbose: int = 0):
         super(BaseCallback, self).__init__()
         # The RL model
-        self.model = None  # type: Optional[BaseAlgorithm]
+        self.model = None  # type: Optional[base_class.BaseAlgorithm]
         # An alias for self.model.get_env(), the environment used for training
         self.training_env = None  # type: Union[gym.Env, VecEnv, None]
         # Number of times the callback was called

@@ -41,7 +37,7 @@ class BaseCallback(ABC):
         self.parent = None  # type: Optional[BaseCallback]

     # Type hint as string to avoid circular import
-    def init_callback(self, model: "BaseAlgorithm") -> None:
+    def init_callback(self, model: "base_class.BaseAlgorithm") -> None:
         """
         Initialize the callback by saving references to the
         RL model and the training environment for convenience.

@@ -137,7 +133,7 @@ class EventCallback(BaseCallback):
         if callback is not None:
             self.callback.parent = self

-    def init_callback(self, model: "BaseAlgorithm") -> None:
+    def init_callback(self, model: "base_class.BaseAlgorithm") -> None:
         super(EventCallback, self).init_callback(model)
         if self.callback is not None:
             self.callback.init_callback(self.model)
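The pattern applied above — importing the ``base_class`` module rather than the ``BaseAlgorithm`` class, and keeping the annotation as a string — breaks the circular import between ``base_class`` and ``callbacks`` while remaining resolvable by pytype and ``sphinx_autodoc_typehints``. A standalone sketch of the idea (the ``describe`` helper is hypothetical, assuming stable-baselines3 is installed):

.. code-block:: python

    from stable_baselines3.common import base_class  # module import: the class is not needed at import time

    def describe(model: "base_class.BaseAlgorithm") -> str:
        # The string annotation is only resolved by tooling, never executed,
        # so importing this module cannot trigger an import cycle.
        return type(model).__name__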
@@ -6,7 +6,7 @@ from typing import Any, Dict, List, Optional, Tuple
 import gym
 import torch as th
 from gym import spaces
-from torch import nn as nn
+from torch import nn
 from torch.distributions import Bernoulli, Categorical, Normal

 from stable_baselines3.common.preprocessing import get_action_dim
@@ -1,17 +1,14 @@
-import typing
 from typing import Callable, List, Optional, Tuple, Union

 import gym
 import numpy as np

+from stable_baselines3.common import base_class
 from stable_baselines3.common.vec_env import VecEnv

-if typing.TYPE_CHECKING:
-    from stable_baselines3.common.base_class import BaseAlgorithm
-

 def evaluate_policy(
-    model: "BaseAlgorithm",
+    model: "base_class.BaseAlgorithm",
     env: Union[gym.Env, VecEnv],
     n_eval_episodes: int = 10,
     deterministic: bool = True,
@@ -8,7 +8,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
 import gym
 import numpy as np
 import torch as th
-from torch import nn as nn
+from torch import nn

 from stable_baselines3.common.distributions import (
     BernoulliDistribution,

@@ -439,6 +439,16 @@ class ActorCriticPolicy(BasePolicy):
         assert isinstance(self.action_dist, StateDependentNoiseDistribution), "reset_noise() is only available when using gSDE"
         self.action_dist.sample_weights(self.log_std, batch_size=n_envs)

+    def _build_mlp_extractor(self) -> None:
+        """
+        Create the policy and value networks.
+        Part of the layers can be shared.
+        """
+        # Note: If net_arch is None and some features extractor is used,
+        # net_arch here is an empty list and mlp_extractor does not
+        # really contain any layers (acts like an identity module).
+        self.mlp_extractor = MlpExtractor(self.features_dim, net_arch=self.net_arch, activation_fn=self.activation_fn)
+
     def _build(self, lr_schedule: Callable[[float], float]) -> None:
         """
         Create the networks and the optimizer.

@@ -446,10 +456,7 @@ class ActorCriticPolicy(BasePolicy):
         :param lr_schedule: (Callable) Learning rate schedule
             lr_schedule(1) is the initial learning rate
         """
-        # Note: If net_arch is None and some features extractor is used,
-        # net_arch here is an empty list and mlp_extractor does not
-        # really contain any layers (acts like an identity module).
-        self.mlp_extractor = MlpExtractor(self.features_dim, net_arch=self.net_arch, activation_fn=self.activation_fn)
+        self._build_mlp_extractor()

         latent_dim_pi = self.mlp_extractor.latent_dim_pi

@@ -3,7 +3,7 @@ from typing import Dict, List, Tuple, Type, Union

 import gym
 import torch as th
-from torch import nn as nn
+from torch import nn

 from stable_baselines3.common.preprocessing import get_flattened_obs_dim, is_image_space
 from stable_baselines3.common.utils import get_device
@@ -6,7 +6,7 @@ import gym
 import numpy as np
 import torch as th

-from stable_baselines3.common.callbacks import BaseCallback
+from stable_baselines3.common import callbacks
 from stable_baselines3.common.vec_env import VecEnv

 GymEnv = Union[gym.Env, VecEnv]

@@ -14,7 +14,7 @@ GymObs = Union[Tuple, Dict[str, Any], np.ndarray, int]
 GymStepReturn = Tuple[GymObs, float, bool, Dict]
 TensorDict = Dict[str, th.Tensor]
 OptimizerStateDict = Dict[str, Any]
-MaybeCallback = Union[None, Callable, List[BaseCallback], BaseCallback]
+MaybeCallback = Union[None, Callable, List[callbacks.BaseCallback], callbacks.BaseCallback]


 class RolloutBufferSamples(NamedTuple):
@@ -2,7 +2,7 @@ from typing import Any, Callable, Dict, List, Optional, Type

 import gym
 import torch as th
-from torch import nn as nn
+from torch import nn

 from stable_baselines3.common.policies import BasePolicy, register_policy
 from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, FlattenExtractor, NatureCNN, create_mlp
@@ -2,7 +2,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Type

 import gym
 import torch as th
-from torch import nn as nn
+from torch import nn

 from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution
 from stable_baselines3.common.policies import BasePolicy, ContinuousCritic, create_sde_features_extractor, register_policy
@@ -2,7 +2,7 @@ from typing import Any, Callable, Dict, List, Optional, Type

 import gym
 import torch as th
-from torch import nn as nn
+from torch import nn

 from stable_baselines3.common.policies import BasePolicy, ContinuousCritic, register_policy
 from stable_baselines3.common.preprocessing import get_action_dim