Modified ActorCriticPolicy to support non-shared features extractor (#1148)

* Modified ActorCriticPolicy to support non-shared features extractor

* Refactored features extraction with non-shared features extractor in ActorCriticPolicy and updated doc

Doc update: added a warning to the custom policy docs stating that, if the features extractor is non-shared, it is not possible to have shared layers in the mlp_extractor

* Moved attrib share_features_extractor in class

* Updated custom policy doc for non-shared features extractor

* Updated changelog

* Made some if-statements more readable in policies.py

The if-statements are related to the shared/non-shared features extractor in ActorCritic policies

* Simplify implementation and add run test

* Keep order in module gains to keep previous results consistent

* Fix test

* Improved docstring in policies.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Added some tests

* feature extractor -> features extractor

* Fix test

* Fix env_id in test

* Make features extractor parameter explicit

* Remove duplicate

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Antonin Raffin <antonin.raffin@dlr.de>
Alex Pasquali 2022-12-20 15:12:05 +01:00 committed by GitHub
parent 8452106734
commit 2cfcec4f50
11 changed files with 138 additions and 47 deletions

@ -108,17 +108,19 @@ using ``policy_kwargs`` parameter:
Custom Feature Extractor
^^^^^^^^^^^^^^^^^^^^^^^^
If you want to have a custom feature extractor (e.g. custom CNN when using images), you can define a class
If you want to have a custom features extractor (e.g. custom CNN when using images), you can define a class
that derives from ``BaseFeaturesExtractor`` and then pass it to the model when training.
.. note::
By default the feature extractor is shared between the actor and the critic to save computation (when applicable).
However, this can be changed by defining a custom policy for on-policy algorithms
(see `issue #1066 <https://github.com/DLR-RM/stable-baselines3/issues/1066#issuecomment-1246866844>`_
for more information) or setting ``share_features_extractor=False`` in the
``policy_kwargs`` for off-policy algorithms (and when applicable).
By default the features extractor is shared between the actor and the critic to save computation (when applicable).
However, this can be changed by setting ``share_features_extractor=False`` in the
``policy_kwargs`` (both for on-policy and off-policy algorithms).
.. warning::
If the features extractor is **non-shared**, it is **not** possible to have shared layers in the ``mlp_extractor``.
.. code-block:: python
@ -174,7 +176,7 @@ Multiple Inputs and Dictionary Observations
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Stable Baselines3 supports handling of multiple inputs by using ``Dict`` Gym space. This can be done using
``MultiInputPolicy``, which by default uses the ``CombinedExtractor`` feature extractor to turn multiple
``MultiInputPolicy``, which by default uses the ``CombinedExtractor`` features extractor to turn multiple
inputs into a single vector, handled by the ``net_arch`` network.
By default, ``CombinedExtractor`` processes multiple inputs as follows:
@ -184,7 +186,7 @@ By default, ``CombinedExtractor`` processes multiple inputs as follows:
2. If input is not an image, flatten it (no layers).
3. Concatenate all previous vectors into one long vector and pass it to policy.
Much like above, you can define custom feature extractors. The following example assumes the environment has two keys in the
Much like above, you can define custom features extractors. The following example assumes the environment has two keys in the
observation space dictionary: "image" is a (1,H,W) image (channel first), and "vector" is a (D,) dimensional vector. We process "image" with a simple
downsampling and "vector" with a single linear layer.
@ -319,7 +321,7 @@ If your task requires even more granular control over the policy/value architect
class CustomNetwork(nn.Module):
"""
Custom network for policy and value function.
It receives as input the features extracted by the feature extractor.
It receives as input the features extracted by the features extractor.
:param feature_dim: dimension of the features extracted with the features_extractor (e.g. features from a CNN)
:param last_layer_dim_pi: (int) number of units for the last layer of the policy network
@ -411,7 +413,7 @@ you only need to specify ``net_arch=[256, 256]`` (here, two hidden layers of 256
.. note::
Compared to their on-policy counterparts, no shared layers (other than the feature extractor)
Compared to their on-policy counterparts, no shared layers (other than the features extractor)
between the actor and the critic are allowed (to prevent issues with target networks).
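
The warning added to the custom policy docs above can be illustrated with a small standalone check. This is a minimal sketch, not the library's code: `validate_net_arch` is a hypothetical helper that mirrors the condition this commit adds to `ActorCriticPolicy`, under the assumption that shared layers are written as leading integers in `net_arch` and per-network layers as a `dict(pi=..., vf=...)`.

```python
from typing import Dict, List, Union

NetArch = List[Union[int, Dict[str, List[int]]]]


def validate_net_arch(net_arch: NetArch, share_features_extractor: bool) -> None:
    """Hypothetical helper: reject shared MLP layers when the extractor is not shared."""
    if share_features_extractor:
        return
    # A leading non-dict entry (an int) describes a layer shared by actor and critic.
    if len(net_arch) > 0 and not isinstance(net_arch[0], dict):
        raise ValueError(
            "if the features extractor is not shared, "
            "there cannot be shared layers in the mlp_extractor"
        )


# Separate actor/critic layers only: accepted with a non-shared extractor.
ok_kwargs = dict(net_arch=[dict(pi=[64], vf=[64])], share_features_extractor=False)
validate_net_arch(ok_kwargs["net_arch"], ok_kwargs["share_features_extractor"])

# One shared layer of 12 units first: rejected with a non-shared extractor.
bad_kwargs = dict(net_arch=[12, dict(vf=[16], pi=[8])], share_features_extractor=False)
try:
    validate_net_arch(bad_kwargs["net_arch"], bad_kwargs["share_features_extractor"])
    rejected = False
except ValueError:
    rejected = True
```

With Stable Baselines3 itself, a dictionary shaped like `ok_kwargs` would be passed as `policy_kwargs` to an on-policy algorithm such as `PPO` or `A2C`, as the tests in this commit do.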

@ -4,7 +4,7 @@ Changelog
==========
Release 1.7.0a7 (WIP)
Release 1.7.0a8 (WIP)
--------------------------
Breaking Changes:
@ -18,6 +18,7 @@ Breaking Changes:
New Features:
^^^^^^^^^^^^^
- Introduced mypy type checking
- Added option to have non-shared features extractor between actor and critic in on-policy algorithms (@AlexPasqua)
- Added ``with_bias`` argument to ``create_mlp``
- Added support for multidimensional ``spaces.MultiBinary`` observations
- Features extractors now properly support unnormalized image-like observations (3D tensor)
@ -40,6 +41,7 @@ Bug Fixes:
Deprecations:
^^^^^^^^^^^^^
- You should now explicitly pass a ``features_extractor`` parameter when calling ``extract_features()``
Others:
^^^^^^^
@ -685,8 +687,8 @@ Bug Fixes:
- Fix model creation initializing CUDA even when `device="cpu"` is provided
- Fix ``check_env`` not checking if the env has a Dict actionspace before calling ``_check_nan`` (@wmmc88)
- Update the check for spaces unsupported by Stable Baselines 3 to include checks on the action space (@wmmc88)
- Fixed feature extractor bug for target network where the same net was shared instead
of being separate. This bug affects ``SAC``, ``DDPG`` and ``TD3`` when using ``CnnPolicy`` (or custom feature extractor)
- Fixed features extractor bug for target network where the same net was shared instead
of being separate. This bug affects ``SAC``, ``DDPG`` and ``TD3`` when using ``CnnPolicy`` (or custom features extractor)
- Fixed a bug when passing an environment when loading a saved model with a ``CnnPolicy``, the passed env was not wrapped properly
(the bug was introduced when implementing ``HER`` so it should not be present in previous versions)
@ -763,7 +765,7 @@ Others:
Documentation:
^^^^^^^^^^^^^^
- Added ``StopTrainingOnMaxEpisodes`` details and example (@xicocaio)
- Updated custom policy section (added custom feature extractor example)
- Updated custom policy section (added custom features extractor example)
- Re-enable ``sphinx_autodoc_typehints``
- Updated doc style for type hints and remove duplicated type hints
@ -801,7 +803,7 @@ Bug Fixes:
- Use ``cloudpickle.load`` instead of ``pickle.load`` in ``CloudpickleWrapper``. (@shwang)
- Fixed a bug with orthogonal initialization when `bias=False` in custom policy (@rk37)
- Fixed approximate entropy calculation in PPO and A2C. (@andyshih12)
- Fixed DQN target network sharing feature extractor with the main network.
- Fixed DQN target network sharing features extractor with the main network.
- Fixed storing correct ``dones`` in on-policy algorithm rollout collection. (@andyshih12)
- Fixed number of filters in final convolutional layer in NatureCNN to match original implementation.
@ -841,7 +843,7 @@ Breaking Changes:
- ``render()`` method of ``VecEnvs`` now only accept one argument: ``mode``
- Created new file common/torch_layers.py, similar to SB refactoring
- Contains all PyTorch network layer definitions and feature extractors: ``MlpExtractor``, ``create_mlp``, ``NatureCNN``
- Contains all PyTorch network layer definitions and features extractors: ``MlpExtractor``, ``create_mlp``, ``NatureCNN``
- Renamed ``BaseRLModel`` to ``BaseAlgorithm`` (along with offpolicy and onpolicy variants)
- Moved on-policy and off-policy base algorithms to ``common/on_policy_algorithm.py`` and ``common/off_policy_algorithm.py``, respectively.
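
The deprecation entry for ``extract_features()`` in this changelog describes a common pattern: an argument that used to be implicit becomes an optional parameter, and omitting it still works but emits a warning. The following is a minimal sketch of that pattern using only the standard library. `TinyModel` and `double` are hypothetical stand-ins, not SB3 classes, and the warning text is shortened.

```python
import warnings
from typing import Callable, List, Optional

Extractor = Callable[[List[float]], List[float]]


class TinyModel:
    """Stand-in for a model that owns a default features extractor (not an SB3 class)."""

    def __init__(self, features_extractor: Extractor) -> None:
        self.features_extractor = features_extractor

    def extract_features(
        self, obs: List[float], features_extractor: Optional[Extractor] = None
    ) -> List[float]:
        if features_extractor is None:
            # Old call style: still supported, but it warns and falls back to the default.
            warnings.warn(
                "You should explicitly pass a features_extractor as parameter.",
                DeprecationWarning,
            )
            features_extractor = self.features_extractor
        return features_extractor(obs)


def double(obs: List[float]) -> List[float]:
    """Toy extractor that doubles every observation entry."""
    return [2.0 * x for x in obs]


model = TinyModel(double)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")  # make sure the DeprecationWarning is recorded
    implicit = model.extract_features([1.0, 2.0])  # warns
    explicit = model.extract_features([1.0, 2.0], model.features_extractor)  # silent

deprecation_count = sum(issubclass(w.category, DeprecationWarning) for w in caught)
```

Both calls return the same features; only the implicit one is recorded as deprecated, which lets callers migrate gradually.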

@ -50,7 +50,7 @@ def _check_image_input(observation_space: spaces.Box, key: str = "") -> None:
if observation_space.shape[non_channel_idx] < 36 or observation_space.shape[1] < 36:
warnings.warn(
"The minimal resolution for an image is 36x36 for the default `CnnPolicy`. "
"You might need to use a custom feature extractor "
"You might need to use a custom features extractor "
"cf. https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html"
)

@ -2,6 +2,7 @@
import collections
import copy
import warnings
from abc import ABC, abstractmethod
from functools import partial
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
@ -117,16 +118,28 @@ class BaseModel(nn.Module):
"""Helper method to create a features extractor."""
return self.features_extractor_class(self.observation_space, **self.features_extractor_kwargs)
def extract_features(self, obs: th.Tensor) -> th.Tensor:
def extract_features(self, obs: th.Tensor, features_extractor: Optional[BaseFeaturesExtractor] = None) -> th.Tensor:
"""
Preprocess the observation if needed and extract features.
:param obs:
:return:
:param obs: The observation
:param features_extractor: The features extractor to use. If it is set to None,
the features extractor of the policy is used.
:return: The features
"""
assert self.features_extractor is not None, "No features extractor was set"
if features_extractor is None:
warnings.warn(
(
"When calling extract_features(), you should explicitly pass a features_extractor as parameter. "
"This will be mandatory in Stable-Baselines v1.8.0"
),
DeprecationWarning,
)
features_extractor = features_extractor or self.features_extractor
assert features_extractor is not None, "No features extractor was set"
preprocessed_obs = preprocess_obs(obs, self.observation_space, normalize_images=self.normalize_images)
return self.features_extractor(preprocessed_obs)
return features_extractor(preprocessed_obs)
def _get_constructor_parameters(self) -> Dict[str, Any]:
"""
@ -391,6 +404,7 @@ class ActorCriticPolicy(BasePolicy):
:param features_extractor_class: Features extractor to use.
:param features_extractor_kwargs: Keyword arguments
to pass to the features extractor.
:param share_features_extractor: If True, the features extractor is shared between the policy and value networks.
:param normalize_images: Whether to normalize images or not,
dividing by 255.0 (True by default)
:param optimizer_class: The optimizer to use,
@ -414,6 +428,7 @@ class ActorCriticPolicy(BasePolicy):
squash_output: bool = False,
features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
share_features_extractor: bool = True,
normalize_images: bool = True,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
@ -447,8 +462,20 @@ class ActorCriticPolicy(BasePolicy):
self.activation_fn = activation_fn
self.ortho_init = ortho_init
self.features_extractor = features_extractor_class(self.observation_space, **self.features_extractor_kwargs)
self.share_features_extractor = share_features_extractor
self.features_extractor = self.make_features_extractor()
self.features_dim = self.features_extractor.features_dim
if self.share_features_extractor:
self.pi_features_extractor = self.features_extractor
self.vf_features_extractor = self.features_extractor
else:
self.pi_features_extractor = self.features_extractor
self.vf_features_extractor = self.make_features_extractor()
# if the features extractor is not shared, there cannot be shared layers in the mlp_extractor
if len(net_arch) > 0 and not isinstance(net_arch[0], dict):
raise ValueError(
"Error: if the features extractor is not shared, there cannot be shared layers in the mlp_extractor"
)
self.log_std_init = log_std_init
dist_kwargs = None
@ -555,6 +582,13 @@ class ActorCriticPolicy(BasePolicy):
self.action_net: 0.01,
self.value_net: 1,
}
if not self.share_features_extractor:
# Note(antonin): this is to keep SB3 results
# consistent, see GH#1148
del module_gains[self.features_extractor]
module_gains[self.pi_features_extractor] = np.sqrt(2)
module_gains[self.vf_features_extractor] = np.sqrt(2)
for module, gain in module_gains.items():
module.apply(partial(self.init_weights, gain=gain))
@ -571,7 +605,12 @@ class ActorCriticPolicy(BasePolicy):
"""
# Preprocess the observation if needed
features = self.extract_features(obs)
latent_pi, latent_vf = self.mlp_extractor(features)
if self.share_features_extractor:
latent_pi, latent_vf = self.mlp_extractor(features)
else:
pi_features, vf_features = features
latent_pi = self.mlp_extractor.forward_actor(pi_features)
latent_vf = self.mlp_extractor.forward_critic(vf_features)
# Evaluate the values for the given observations
values = self.value_net(latent_vf)
distribution = self._get_action_dist_from_latent(latent_pi)
@ -580,6 +619,20 @@ class ActorCriticPolicy(BasePolicy):
actions = actions.reshape((-1,) + self.action_space.shape)
return actions, values, log_prob
def extract_features(self, obs: th.Tensor) -> Union[th.Tensor, Tuple[th.Tensor, th.Tensor]]:
"""
Preprocess the observation if needed and extract features.
:param obs: Observation
:return: the output of the features extractor(s)
"""
if self.share_features_extractor:
return super().extract_features(obs, self.features_extractor)
else:
pi_features = super().extract_features(obs, self.pi_features_extractor)
vf_features = super().extract_features(obs, self.vf_features_extractor)
return pi_features, vf_features
def _get_action_dist_from_latent(self, latent_pi: th.Tensor) -> Distribution:
"""
Retrieve action distribution given the latent codes.
@ -620,14 +673,19 @@ class ActorCriticPolicy(BasePolicy):
Evaluate actions according to the current policy,
given the observations.
:param obs:
:param actions:
:param obs: Observation
:param actions: Actions
:return: estimated value, log likelihood of taking those actions
and entropy of the action distribution.
"""
# Preprocess the observation if needed
features = self.extract_features(obs)
latent_pi, latent_vf = self.mlp_extractor(features)
if self.share_features_extractor:
latent_pi, latent_vf = self.mlp_extractor(features)
else:
pi_features, vf_features = features
latent_pi = self.mlp_extractor.forward_actor(pi_features)
latent_vf = self.mlp_extractor.forward_critic(vf_features)
distribution = self._get_action_dist_from_latent(latent_pi)
log_prob = distribution.log_prob(actions)
values = self.value_net(latent_vf)
@ -641,7 +699,7 @@ class ActorCriticPolicy(BasePolicy):
:param obs:
:return: the action distribution.
"""
features = self.extract_features(obs)
features = super().extract_features(obs, self.pi_features_extractor)
latent_pi = self.mlp_extractor.forward_actor(features)
return self._get_action_dist_from_latent(latent_pi)
@ -649,10 +707,10 @@ class ActorCriticPolicy(BasePolicy):
"""
Get the estimated values according to the current policy given the observations.
:param obs:
:param obs: Observation
:return: the estimated values.
"""
features = self.extract_features(obs)
features = super().extract_features(obs, self.vf_features_extractor)
latent_vf = self.mlp_extractor.forward_critic(features)
return self.value_net(latent_vf)
@ -680,6 +738,7 @@ class ActorCriticCnnPolicy(ActorCriticPolicy):
:param features_extractor_class: Features extractor to use.
:param features_extractor_kwargs: Keyword arguments
to pass to the features extractor.
:param share_features_extractor: If True, the features extractor is shared between the policy and value networks.
:param normalize_images: Whether to normalize images or not,
dividing by 255.0 (True by default)
:param optimizer_class: The optimizer to use,
@ -703,6 +762,7 @@ class ActorCriticCnnPolicy(ActorCriticPolicy):
squash_output: bool = False,
features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
share_features_extractor: bool = True,
normalize_images: bool = True,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
@ -721,6 +781,7 @@ class ActorCriticCnnPolicy(ActorCriticPolicy):
squash_output,
features_extractor_class,
features_extractor_kwargs,
share_features_extractor,
normalize_images,
optimizer_class,
optimizer_kwargs,
@ -749,7 +810,8 @@ class MultiInputActorCriticPolicy(ActorCriticPolicy):
this allows to ensure boundaries when using gSDE.
:param features_extractor_class: Uses the CombinedExtractor
:param features_extractor_kwargs: Keyword arguments
to pass to the feature extractor.
to pass to the features extractor.
:param share_features_extractor: If True, the features extractor is shared between the policy and value networks.
:param normalize_images: Whether to normalize images or not,
dividing by 255.0 (True by default)
:param optimizer_class: The optimizer to use,
@ -773,6 +835,7 @@ class MultiInputActorCriticPolicy(ActorCriticPolicy):
squash_output: bool = False,
features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor,
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
share_features_extractor: bool = True,
normalize_images: bool = True,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
@ -791,6 +854,7 @@ class MultiInputActorCriticPolicy(ActorCriticPolicy):
squash_output,
features_extractor_class,
features_extractor_kwargs,
share_features_extractor,
normalize_images,
optimizer_class,
optimizer_kwargs,
@ -858,7 +922,7 @@ class ContinuousCritic(BaseModel):
# Learn the features extractor using the policy loss only
# when the features_extractor is shared with the actor
with th.set_grad_enabled(not self.share_features_extractor):
features = self.extract_features(obs)
features = self.extract_features(obs, self.features_extractor)
qvalue_input = th.cat([features, actions], dim=1)
return tuple(q_net(qvalue_input) for q_net in self.q_networks)
@ -869,5 +933,5 @@ class ContinuousCritic(BaseModel):
(e.g. when updating the policy in TD3).
"""
with th.no_grad():
features = self.extract_features(obs)
features = self.extract_features(obs, self.features_extractor)
return self.q_networks[0](th.cat([features, actions], dim=1))
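
The shared versus non-shared control flow added to `ActorCriticPolicy` in this file can be summarised with a framework-free sketch. It is an illustration under simplifying assumptions, not the actual implementation: `TinyActorCritic` and `make_identity` are hypothetical, plain callables replace the neural networks, and `sum` stands in for the actor and critic heads.

```python
from typing import Callable, List, Tuple, Union

Features = List[float]
Extractor = Callable[[List[float]], Features]


class TinyActorCritic:
    """Stand-in for an actor-critic policy; plain callables replace the networks."""

    def __init__(
        self, make_extractor: Callable[[], Extractor], share_features_extractor: bool = True
    ) -> None:
        self.share_features_extractor = share_features_extractor
        self.features_extractor = make_extractor()
        # The actor always uses the first extractor; the critic reuses it or gets its own.
        self.pi_features_extractor = self.features_extractor
        if share_features_extractor:
            self.vf_features_extractor = self.features_extractor
        else:
            self.vf_features_extractor = make_extractor()

    def extract_features(self, obs: List[float]) -> Union[Features, Tuple[Features, Features]]:
        # Shared: one feature vector. Non-shared: a (pi_features, vf_features) pair.
        if self.share_features_extractor:
            return self.features_extractor(obs)
        return self.pi_features_extractor(obs), self.vf_features_extractor(obs)

    def forward(self, obs: List[float]) -> Tuple[float, float]:
        features = self.extract_features(obs)
        if self.share_features_extractor:
            pi_features = vf_features = features
        else:
            pi_features, vf_features = features
        # Placeholders for the actor and critic heads.
        return sum(pi_features), sum(vf_features)


def make_identity() -> Extractor:
    """Return a fresh extractor object on every call (here: a copy of the input)."""
    return lambda obs: list(obs)


shared = TinyActorCritic(make_identity, share_features_extractor=True)
separate = TinyActorCritic(make_identity, share_features_extractor=False)

shared_out = shared.forward([1.0, 2.0])
separate_out = separate.forward([1.0, 2.0])
```

The point of the sketch is the object identity: in the shared case both heads read from the same extractor, while in the non-shared case the critic owns a second, independently constructed one.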

@ -146,8 +146,8 @@ def create_mlp(
class MlpExtractor(nn.Module):
"""
Constructs an MLP that receives the output from a previous feature extractor (i.e. a CNN) or directly
the observations (if no feature extractor is applied) as an input and outputs a latent representation
Constructs an MLP that receives the output from a previous features extractor (i.e. a CNN) or directly
the observations (if no features extractor is applied) as an input and outputs a latent representation
for the policy and a value network.
The ``net_arch`` parameter allows to specify the amount and size of the hidden layers and how many
of them are shared between the policy network and the value network. It is assumed to be a list with the following
@ -251,8 +251,8 @@ class MlpExtractor(nn.Module):
class CombinedExtractor(BaseFeaturesExtractor):
"""
Combined feature extractor for Dict observation spaces.
Builds a feature extractor for each key of the space. Input from each space
Combined features extractor for Dict observation spaces.
Builds a features extractor for each key of the space. Input from each space
is fed through a separate submodule (CNN or MLP, depending on input shape),
the output features are concatenated and fed through additional MLP network ("combined").

@ -62,7 +62,7 @@ class QNetwork(BasePolicy):
:param obs: Observation
:return: The estimated Q-Value for each action.
"""
return self.q_net(self.extract_features(obs))
return self.q_net(self.extract_features(obs, self.features_extractor))
def _predict(self, observation: th.Tensor, deterministic: bool = True) -> th.Tensor:
q_values = self(observation)

@ -150,7 +150,7 @@ class Actor(BasePolicy):
:return:
Mean, standard deviation and optional keyword arguments.
"""
features = self.extract_features(obs)
features = self.extract_features(obs, self.features_extractor)
latent_pi = self.latent_pi(features)
mean_actions = self.mu(latent_pi)

@ -74,7 +74,7 @@ class Actor(BasePolicy):
def forward(self, obs: th.Tensor) -> th.Tensor:
# assert deterministic, 'The TD3 actor only outputs deterministic actions'
features = self.extract_features(obs)
features = self.extract_features(obs, self.features_extractor)
return self.mu(features)
def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:

@ -1 +1 @@
1.7.0a7
1.7.0a8

@ -14,15 +14,25 @@ from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack, VecNorm
@pytest.mark.parametrize("model_class", [A2C, PPO, SAC, TD3, DQN])
def test_cnn(tmp_path, model_class):
@pytest.mark.parametrize("share_features_extractor", [True, False])
def test_cnn(tmp_path, model_class, share_features_extractor):
SAVE_NAME = "cnn_model.zip"
# Fake grayscale with frameskip
# Atari after preprocessing: 84x84x1, here we are using lower resolution
# to check that the network handle it automatically
env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=1, discrete=model_class not in {SAC, TD3})
if model_class in {A2C, PPO}:
kwargs = dict(n_steps=64)
kwargs = dict(
n_steps=64,
policy_kwargs=dict(
share_features_extractor=share_features_extractor,
),
)
else:
# share_features_extractor is checked later for offpolicy algorithms
if share_features_extractor:
return
# Avoid memory error when using replay buffer
# Reduce the size of the features
kwargs = dict(
@ -139,10 +149,10 @@ def test_features_extractor_target_net(model_class, share_features_extractor):
assert id(model.policy.actor.features_extractor) == id(model.policy.critic.features_extractor)
if model_class == TD3:
assert id(model.policy.actor_target.features_extractor) == id(model.policy.critic_target.features_extractor)
# Actor and critic feature extractor should be the same
# Actor and critic features extractor should be the same
td3_features_extractor_check = check_td3_feature_extractor_match
else:
# Actor and critic feature extractor should differ
# Actor and critic features extractor should differ
td3_features_extractor_check = check_td3_feature_extractor_differ
# Check that the object differ
if model_class != DQN:

@ -28,9 +28,15 @@ def test_custom_offpolicy(model_class, net_arch):
_ = model_class("MlpPolicy", "Pendulum-v1", policy_kwargs=dict(net_arch=net_arch), learning_starts=100).learn(300)
@pytest.mark.parametrize("model_class", [A2C, PPO, SAC, TD3])
@pytest.mark.parametrize("model_class", [A2C, DQN, PPO, SAC, TD3])
@pytest.mark.parametrize("optimizer_kwargs", [None, dict(weight_decay=0.0)])
def test_custom_optimizer(model_class, optimizer_kwargs):
# Use different environment for DQN
if model_class is DQN:
env_id = "CartPole-v1"
else:
env_id = "Pendulum-v1"
kwargs = {}
if model_class in {DQN, SAC, TD3}:
kwargs = dict(learning_starts=100)
@ -38,7 +44,7 @@ def test_custom_optimizer(model_class, optimizer_kwargs):
kwargs = dict(n_steps=64)
policy_kwargs = dict(optimizer_class=th.optim.AdamW, optimizer_kwargs=optimizer_kwargs, net_arch=[32])
_ = model_class("MlpPolicy", "Pendulum-v1", policy_kwargs=policy_kwargs, **kwargs).learn(300)
_ = model_class("MlpPolicy", env_id, policy_kwargs=policy_kwargs, **kwargs).learn(300)
def test_tf_like_rmsprop_optimizer():
@ -49,3 +55,10 @@ def test_tf_like_rmsprop_optimizer():
def test_dqn_custom_policy():
policy_kwargs = dict(optimizer_class=RMSpropTFLike, net_arch=[32])
_ = DQN("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs, learning_starts=100).learn(300)
@pytest.mark.parametrize("model_class", [A2C, PPO])
def test_not_shared_features_extractor(model_class):
policy_kwargs = dict(net_arch=[12, dict(vf=[16], pi=[8])], share_features_extractor=False)
with pytest.raises(ValueError):
model_class("MlpPolicy", "Pendulum-v1", policy_kwargs=policy_kwargs)