mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-06-25 02:50:59 +00:00
Modified ActorCriticPolicy to support non-shared features extractor (#1148)
* Modified ActorCriticPolicy to support non-shared features extractor * Refactored features extraction with non-shared features extractor in ActorCriticPolicy and updated doc Doc update: added 'warning' on custom policy docs that says that, if the features extractor is non-shared, it's not possible to have shared layers in the mlp_extractor * Moved attrib share_features_extractor in class * Updated custom policy doc for non-shared features extractor * Updated changelog * Made some if-statements more readable if policies.py The if-statements are related to the shared/non-shared features extractor in ActorCritic policies * Simplify implementation and add run test * Keep order in module gain to keep previous results consistents * Fix test * Improved docstring in policies.py Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> * Added some tests * feature extractor -> features extractor * Fix test * Fix env_id in test * Make features extractor parameter explicit * Remove duplicate Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org> Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> Co-authored-by: Antonin Raffin <antonin.raffin@dlr.de>
This commit is contained in:
parent
8452106734
commit
2cfcec4f50
11 changed files with 138 additions and 47 deletions
|
|
@ -108,17 +108,19 @@ using ``policy_kwargs`` parameter:
|
|||
Custom Feature Extractor
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
If you want to have a custom feature extractor (e.g. custom CNN when using images), you can define class
|
||||
If you want to have a custom features extractor (e.g. custom CNN when using images), you can define class
|
||||
that derives from ``BaseFeaturesExtractor`` and then pass it to the model when training.
|
||||
|
||||
|
||||
.. note::
|
||||
|
||||
By default the feature extractor is shared between the actor and the critic to save computation (when applicable).
|
||||
However, this can be changed by defining a custom policy for on-policy algorithms
|
||||
(see `issue #1066 <https://github.com/DLR-RM/stable-baselines3/issues/1066#issuecomment-1246866844>`_
|
||||
for more information) or setting ``share_features_extractor=False`` in the
|
||||
``policy_kwargs`` for off-policy algorithms (and when applicable).
|
||||
By default the features extractor is shared between the actor and the critic to save computation (when applicable).
|
||||
However, this can be changed setting ``share_features_extractor=False`` in the
|
||||
``policy_kwargs`` (both for on-policy and off-policy algorithms).
|
||||
|
||||
|
||||
.. warning::
|
||||
If the features extractor is **non-shared**, it is **not** possible to have shared layers in the ``mlp_extractor``.
|
||||
|
||||
|
||||
.. code-block:: python
|
||||
|
|
@ -174,7 +176,7 @@ Multiple Inputs and Dictionary Observations
|
|||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Stable Baselines3 supports handling of multiple inputs by using ``Dict`` Gym space. This can be done using
|
||||
``MultiInputPolicy``, which by default uses the ``CombinedExtractor`` feature extractor to turn multiple
|
||||
``MultiInputPolicy``, which by default uses the ``CombinedExtractor`` features extractor to turn multiple
|
||||
inputs into a single vector, handled by the ``net_arch`` network.
|
||||
|
||||
By default, ``CombinedExtractor`` processes multiple inputs as follows:
|
||||
|
|
@ -184,7 +186,7 @@ By default, ``CombinedExtractor`` processes multiple inputs as follows:
|
|||
2. If input is not an image, flatten it (no layers).
|
||||
3. Concatenate all previous vectors into one long vector and pass it to policy.
|
||||
|
||||
Much like above, you can define custom feature extractors. The following example assumes the environment has two keys in the
|
||||
Much like above, you can define custom features extractors. The following example assumes the environment has two keys in the
|
||||
observation space dictionary: "image" is a (1,H,W) image (channel first), and "vector" is a (D,) dimensional vector. We process "image" with a simple
|
||||
downsampling and "vector" with a single linear layer.
|
||||
|
||||
|
|
@ -319,7 +321,7 @@ If your task requires even more granular control over the policy/value architect
|
|||
class CustomNetwork(nn.Module):
|
||||
"""
|
||||
Custom network for policy and value function.
|
||||
It receives as input the features extracted by the feature extractor.
|
||||
It receives as input the features extracted by the features extractor.
|
||||
|
||||
:param feature_dim: dimension of the features extracted with the features_extractor (e.g. features from a CNN)
|
||||
:param last_layer_dim_pi: (int) number of units for the last layer of the policy network
|
||||
|
|
@ -411,7 +413,7 @@ you only need to specify ``net_arch=[256, 256]`` (here, two hidden layers of 256
|
|||
|
||||
|
||||
.. note::
|
||||
Compared to their on-policy counterparts, no shared layers (other than the feature extractor)
|
||||
Compared to their on-policy counterparts, no shared layers (other than the features extractor)
|
||||
between the actor and the critic are allowed (to prevent issues with target networks).
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ Changelog
|
|||
==========
|
||||
|
||||
|
||||
Release 1.7.0a7 (WIP)
|
||||
Release 1.7.0a8 (WIP)
|
||||
--------------------------
|
||||
|
||||
Breaking Changes:
|
||||
|
|
@ -18,6 +18,7 @@ Breaking Changes:
|
|||
New Features:
|
||||
^^^^^^^^^^^^^
|
||||
- Introduced mypy type checking
|
||||
- Added option to have non-shared features extractor between actor and critic in on-policy algorithms (@AlexPasqua)
|
||||
- Added ``with_bias`` argument to ``create_mlp``
|
||||
- Added support for multidimensional ``spaces.MultiBinary`` observations
|
||||
- Features extractors now properly support unnormalized image-like observations (3D tensor)
|
||||
|
|
@ -40,6 +41,7 @@ Bug Fixes:
|
|||
|
||||
Deprecations:
|
||||
^^^^^^^^^^^^^
|
||||
- You should now explicitely pass a ``features_extractor`` parameter when calling ``extract_features()``
|
||||
|
||||
Others:
|
||||
^^^^^^^
|
||||
|
|
@ -685,8 +687,8 @@ Bug Fixes:
|
|||
- Fix model creation initializing CUDA even when `device="cpu"` is provided
|
||||
- Fix ``check_env`` not checking if the env has a Dict actionspace before calling ``_check_nan`` (@wmmc88)
|
||||
- Update the check for spaces unsupported by Stable Baselines 3 to include checks on the action space (@wmmc88)
|
||||
- Fixed feature extractor bug for target network where the same net was shared instead
|
||||
of being separate. This bug affects ``SAC``, ``DDPG`` and ``TD3`` when using ``CnnPolicy`` (or custom feature extractor)
|
||||
- Fixed features extractor bug for target network where the same net was shared instead
|
||||
of being separate. This bug affects ``SAC``, ``DDPG`` and ``TD3`` when using ``CnnPolicy`` (or custom features extractor)
|
||||
- Fixed a bug when passing an environment when loading a saved model with a ``CnnPolicy``, the passed env was not wrapped properly
|
||||
(the bug was introduced when implementing ``HER`` so it should not be present in previous versions)
|
||||
|
||||
|
|
@ -763,7 +765,7 @@ Others:
|
|||
Documentation:
|
||||
^^^^^^^^^^^^^^
|
||||
- Added ``StopTrainingOnMaxEpisodes`` details and example (@xicocaio)
|
||||
- Updated custom policy section (added custom feature extractor example)
|
||||
- Updated custom policy section (added custom features extractor example)
|
||||
- Re-enable ``sphinx_autodoc_typehints``
|
||||
- Updated doc style for type hints and remove duplicated type hints
|
||||
|
||||
|
|
@ -801,7 +803,7 @@ Bug Fixes:
|
|||
- Use ``cloudpickle.load`` instead of ``pickle.load`` in ``CloudpickleWrapper``. (@shwang)
|
||||
- Fixed a bug with orthogonal initialization when `bias=False` in custom policy (@rk37)
|
||||
- Fixed approximate entropy calculation in PPO and A2C. (@andyshih12)
|
||||
- Fixed DQN target network sharing feature extractor with the main network.
|
||||
- Fixed DQN target network sharing features extractor with the main network.
|
||||
- Fixed storing correct ``dones`` in on-policy algorithm rollout collection. (@andyshih12)
|
||||
- Fixed number of filters in final convolutional layer in NatureCNN to match original implementation.
|
||||
|
||||
|
|
@ -841,7 +843,7 @@ Breaking Changes:
|
|||
- ``render()`` method of ``VecEnvs`` now only accept one argument: ``mode``
|
||||
- Created new file common/torch_layers.py, similar to SB refactoring
|
||||
|
||||
- Contains all PyTorch network layer definitions and feature extractors: ``MlpExtractor``, ``create_mlp``, ``NatureCNN``
|
||||
- Contains all PyTorch network layer definitions and features extractors: ``MlpExtractor``, ``create_mlp``, ``NatureCNN``
|
||||
|
||||
- Renamed ``BaseRLModel`` to ``BaseAlgorithm`` (along with offpolicy and onpolicy variants)
|
||||
- Moved on-policy and off-policy base algorithms to ``common/on_policy_algorithm.py`` and ``common/off_policy_algorithm.py``, respectively.
|
||||
|
|
|
|||
|
|
@ -50,7 +50,7 @@ def _check_image_input(observation_space: spaces.Box, key: str = "") -> None:
|
|||
if observation_space.shape[non_channel_idx] < 36 or observation_space.shape[1] < 36:
|
||||
warnings.warn(
|
||||
"The minimal resolution for an image is 36x36 for the default `CnnPolicy`. "
|
||||
"You might need to use a custom feature extractor "
|
||||
"You might need to use a custom features extractor "
|
||||
"cf. https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html"
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
import collections
|
||||
import copy
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from functools import partial
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
|
||||
|
|
@ -117,16 +118,28 @@ class BaseModel(nn.Module):
|
|||
"""Helper method to create a features extractor."""
|
||||
return self.features_extractor_class(self.observation_space, **self.features_extractor_kwargs)
|
||||
|
||||
def extract_features(self, obs: th.Tensor) -> th.Tensor:
|
||||
def extract_features(self, obs: th.Tensor, features_extractor: Optional[BaseFeaturesExtractor] = None) -> th.Tensor:
|
||||
"""
|
||||
Preprocess the observation if needed and extract features.
|
||||
|
||||
:param obs:
|
||||
:return:
|
||||
:param obs: The observation
|
||||
:param features_extractor: The features extractor to use. If it is set to None,
|
||||
the features extractor of the policy is used.
|
||||
:return: The features
|
||||
"""
|
||||
assert self.features_extractor is not None, "No features extractor was set"
|
||||
if features_extractor is None:
|
||||
warnings.warn(
|
||||
(
|
||||
"When calling extract_features(), you should explicitely pass a features_extractor as parameter. "
|
||||
"This will be mandatory in Stable-Baselines v1.8.0"
|
||||
),
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
features_extractor = features_extractor or self.features_extractor
|
||||
assert features_extractor is not None, "No features extractor was set"
|
||||
preprocessed_obs = preprocess_obs(obs, self.observation_space, normalize_images=self.normalize_images)
|
||||
return self.features_extractor(preprocessed_obs)
|
||||
return features_extractor(preprocessed_obs)
|
||||
|
||||
def _get_constructor_parameters(self) -> Dict[str, Any]:
|
||||
"""
|
||||
|
|
@ -391,6 +404,7 @@ class ActorCriticPolicy(BasePolicy):
|
|||
:param features_extractor_class: Features extractor to use.
|
||||
:param features_extractor_kwargs: Keyword arguments
|
||||
to pass to the features extractor.
|
||||
:param share_features_extractor: If True, the features extractor is shared between the policy and value networks.
|
||||
:param normalize_images: Whether to normalize images or not,
|
||||
dividing by 255.0 (True by default)
|
||||
:param optimizer_class: The optimizer to use,
|
||||
|
|
@ -414,6 +428,7 @@ class ActorCriticPolicy(BasePolicy):
|
|||
squash_output: bool = False,
|
||||
features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
|
||||
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
|
||||
share_features_extractor: bool = True,
|
||||
normalize_images: bool = True,
|
||||
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
|
||||
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
||||
|
|
@ -447,8 +462,20 @@ class ActorCriticPolicy(BasePolicy):
|
|||
self.activation_fn = activation_fn
|
||||
self.ortho_init = ortho_init
|
||||
|
||||
self.features_extractor = features_extractor_class(self.observation_space, **self.features_extractor_kwargs)
|
||||
self.share_features_extractor = share_features_extractor
|
||||
self.features_extractor = self.make_features_extractor()
|
||||
self.features_dim = self.features_extractor.features_dim
|
||||
if self.share_features_extractor:
|
||||
self.pi_features_extractor = self.features_extractor
|
||||
self.vf_features_extractor = self.features_extractor
|
||||
else:
|
||||
self.pi_features_extractor = self.features_extractor
|
||||
self.vf_features_extractor = self.make_features_extractor()
|
||||
# if the features extractor is not shared, there cannot be shared layers in the mlp_extractor
|
||||
if len(net_arch) > 0 and not isinstance(net_arch[0], dict):
|
||||
raise ValueError(
|
||||
"Error: if the features extractor is not shared, there cannot be shared layers in the mlp_extractor"
|
||||
)
|
||||
|
||||
self.log_std_init = log_std_init
|
||||
dist_kwargs = None
|
||||
|
|
@ -555,6 +582,13 @@ class ActorCriticPolicy(BasePolicy):
|
|||
self.action_net: 0.01,
|
||||
self.value_net: 1,
|
||||
}
|
||||
if not self.share_features_extractor:
|
||||
# Note(antonin): this is to keep SB3 results
|
||||
# consistent, see GH#1148
|
||||
del module_gains[self.features_extractor]
|
||||
module_gains[self.pi_features_extractor] = np.sqrt(2)
|
||||
module_gains[self.vf_features_extractor] = np.sqrt(2)
|
||||
|
||||
for module, gain in module_gains.items():
|
||||
module.apply(partial(self.init_weights, gain=gain))
|
||||
|
||||
|
|
@ -571,7 +605,12 @@ class ActorCriticPolicy(BasePolicy):
|
|||
"""
|
||||
# Preprocess the observation if needed
|
||||
features = self.extract_features(obs)
|
||||
latent_pi, latent_vf = self.mlp_extractor(features)
|
||||
if self.share_features_extractor:
|
||||
latent_pi, latent_vf = self.mlp_extractor(features)
|
||||
else:
|
||||
pi_features, vf_features = features
|
||||
latent_pi = self.mlp_extractor.forward_actor(pi_features)
|
||||
latent_vf = self.mlp_extractor.forward_critic(vf_features)
|
||||
# Evaluate the values for the given observations
|
||||
values = self.value_net(latent_vf)
|
||||
distribution = self._get_action_dist_from_latent(latent_pi)
|
||||
|
|
@ -580,6 +619,20 @@ class ActorCriticPolicy(BasePolicy):
|
|||
actions = actions.reshape((-1,) + self.action_space.shape)
|
||||
return actions, values, log_prob
|
||||
|
||||
def extract_features(self, obs: th.Tensor) -> Union[th.Tensor, Tuple[th.Tensor, th.Tensor]]:
|
||||
"""
|
||||
Preprocess the observation if needed and extract features.
|
||||
|
||||
:param obs: Observation
|
||||
:return: the output of the features extractor(s)
|
||||
"""
|
||||
if self.share_features_extractor:
|
||||
return super().extract_features(obs, self.features_extractor)
|
||||
else:
|
||||
pi_features = super().extract_features(obs, self.pi_features_extractor)
|
||||
vf_features = super().extract_features(obs, self.vf_features_extractor)
|
||||
return pi_features, vf_features
|
||||
|
||||
def _get_action_dist_from_latent(self, latent_pi: th.Tensor) -> Distribution:
|
||||
"""
|
||||
Retrieve action distribution given the latent codes.
|
||||
|
|
@ -620,14 +673,19 @@ class ActorCriticPolicy(BasePolicy):
|
|||
Evaluate actions according to the current policy,
|
||||
given the observations.
|
||||
|
||||
:param obs:
|
||||
:param actions:
|
||||
:param obs: Observation
|
||||
:param actions: Actions
|
||||
:return: estimated value, log likelihood of taking those actions
|
||||
and entropy of the action distribution.
|
||||
"""
|
||||
# Preprocess the observation if needed
|
||||
features = self.extract_features(obs)
|
||||
latent_pi, latent_vf = self.mlp_extractor(features)
|
||||
if self.share_features_extractor:
|
||||
latent_pi, latent_vf = self.mlp_extractor(features)
|
||||
else:
|
||||
pi_features, vf_features = features
|
||||
latent_pi = self.mlp_extractor.forward_actor(pi_features)
|
||||
latent_vf = self.mlp_extractor.forward_critic(vf_features)
|
||||
distribution = self._get_action_dist_from_latent(latent_pi)
|
||||
log_prob = distribution.log_prob(actions)
|
||||
values = self.value_net(latent_vf)
|
||||
|
|
@ -641,7 +699,7 @@ class ActorCriticPolicy(BasePolicy):
|
|||
:param obs:
|
||||
:return: the action distribution.
|
||||
"""
|
||||
features = self.extract_features(obs)
|
||||
features = super().extract_features(obs, self.pi_features_extractor)
|
||||
latent_pi = self.mlp_extractor.forward_actor(features)
|
||||
return self._get_action_dist_from_latent(latent_pi)
|
||||
|
||||
|
|
@ -649,10 +707,10 @@ class ActorCriticPolicy(BasePolicy):
|
|||
"""
|
||||
Get the estimated values according to the current policy given the observations.
|
||||
|
||||
:param obs:
|
||||
:param obs: Observation
|
||||
:return: the estimated values.
|
||||
"""
|
||||
features = self.extract_features(obs)
|
||||
features = super().extract_features(obs, self.vf_features_extractor)
|
||||
latent_vf = self.mlp_extractor.forward_critic(features)
|
||||
return self.value_net(latent_vf)
|
||||
|
||||
|
|
@ -680,6 +738,7 @@ class ActorCriticCnnPolicy(ActorCriticPolicy):
|
|||
:param features_extractor_class: Features extractor to use.
|
||||
:param features_extractor_kwargs: Keyword arguments
|
||||
to pass to the features extractor.
|
||||
:param share_features_extractor: If True, the features extractor is shared between the policy and value networks.
|
||||
:param normalize_images: Whether to normalize images or not,
|
||||
dividing by 255.0 (True by default)
|
||||
:param optimizer_class: The optimizer to use,
|
||||
|
|
@ -703,6 +762,7 @@ class ActorCriticCnnPolicy(ActorCriticPolicy):
|
|||
squash_output: bool = False,
|
||||
features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
|
||||
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
|
||||
share_features_extractor: bool = True,
|
||||
normalize_images: bool = True,
|
||||
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
|
||||
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
||||
|
|
@ -721,6 +781,7 @@ class ActorCriticCnnPolicy(ActorCriticPolicy):
|
|||
squash_output,
|
||||
features_extractor_class,
|
||||
features_extractor_kwargs,
|
||||
share_features_extractor,
|
||||
normalize_images,
|
||||
optimizer_class,
|
||||
optimizer_kwargs,
|
||||
|
|
@ -749,7 +810,8 @@ class MultiInputActorCriticPolicy(ActorCriticPolicy):
|
|||
this allows to ensure boundaries when using gSDE.
|
||||
:param features_extractor_class: Uses the CombinedExtractor
|
||||
:param features_extractor_kwargs: Keyword arguments
|
||||
to pass to the feature extractor.
|
||||
to pass to the features extractor.
|
||||
:param share_features_extractor: If True, the features extractor is shared between the policy and value networks.
|
||||
:param normalize_images: Whether to normalize images or not,
|
||||
dividing by 255.0 (True by default)
|
||||
:param optimizer_class: The optimizer to use,
|
||||
|
|
@ -773,6 +835,7 @@ class MultiInputActorCriticPolicy(ActorCriticPolicy):
|
|||
squash_output: bool = False,
|
||||
features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor,
|
||||
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
|
||||
share_features_extractor: bool = True,
|
||||
normalize_images: bool = True,
|
||||
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
|
||||
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
||||
|
|
@ -791,6 +854,7 @@ class MultiInputActorCriticPolicy(ActorCriticPolicy):
|
|||
squash_output,
|
||||
features_extractor_class,
|
||||
features_extractor_kwargs,
|
||||
share_features_extractor,
|
||||
normalize_images,
|
||||
optimizer_class,
|
||||
optimizer_kwargs,
|
||||
|
|
@ -858,7 +922,7 @@ class ContinuousCritic(BaseModel):
|
|||
# Learn the features extractor using the policy loss only
|
||||
# when the features_extractor is shared with the actor
|
||||
with th.set_grad_enabled(not self.share_features_extractor):
|
||||
features = self.extract_features(obs)
|
||||
features = self.extract_features(obs, self.features_extractor)
|
||||
qvalue_input = th.cat([features, actions], dim=1)
|
||||
return tuple(q_net(qvalue_input) for q_net in self.q_networks)
|
||||
|
||||
|
|
@ -869,5 +933,5 @@ class ContinuousCritic(BaseModel):
|
|||
(e.g. when updating the policy in TD3).
|
||||
"""
|
||||
with th.no_grad():
|
||||
features = self.extract_features(obs)
|
||||
features = self.extract_features(obs, self.features_extractor)
|
||||
return self.q_networks[0](th.cat([features, actions], dim=1))
|
||||
|
|
|
|||
|
|
@ -146,8 +146,8 @@ def create_mlp(
|
|||
|
||||
class MlpExtractor(nn.Module):
|
||||
"""
|
||||
Constructs an MLP that receives the output from a previous feature extractor (i.e. a CNN) or directly
|
||||
the observations (if no feature extractor is applied) as an input and outputs a latent representation
|
||||
Constructs an MLP that receives the output from a previous features extractor (i.e. a CNN) or directly
|
||||
the observations (if no features extractor is applied) as an input and outputs a latent representation
|
||||
for the policy and a value network.
|
||||
The ``net_arch`` parameter allows to specify the amount and size of the hidden layers and how many
|
||||
of them are shared between the policy network and the value network. It is assumed to be a list with the following
|
||||
|
|
@ -251,8 +251,8 @@ class MlpExtractor(nn.Module):
|
|||
|
||||
class CombinedExtractor(BaseFeaturesExtractor):
|
||||
"""
|
||||
Combined feature extractor for Dict observation spaces.
|
||||
Builds a feature extractor for each key of the space. Input from each space
|
||||
Combined features extractor for Dict observation spaces.
|
||||
Builds a features extractor for each key of the space. Input from each space
|
||||
is fed through a separate submodule (CNN or MLP, depending on input shape),
|
||||
the output features are concatenated and fed through additional MLP network ("combined").
|
||||
|
||||
|
|
|
|||
|
|
@ -62,7 +62,7 @@ class QNetwork(BasePolicy):
|
|||
:param obs: Observation
|
||||
:return: The estimated Q-Value for each action.
|
||||
"""
|
||||
return self.q_net(self.extract_features(obs))
|
||||
return self.q_net(self.extract_features(obs, self.features_extractor))
|
||||
|
||||
def _predict(self, observation: th.Tensor, deterministic: bool = True) -> th.Tensor:
|
||||
q_values = self(observation)
|
||||
|
|
|
|||
|
|
@ -150,7 +150,7 @@ class Actor(BasePolicy):
|
|||
:return:
|
||||
Mean, standard deviation and optional keyword arguments.
|
||||
"""
|
||||
features = self.extract_features(obs)
|
||||
features = self.extract_features(obs, self.features_extractor)
|
||||
latent_pi = self.latent_pi(features)
|
||||
mean_actions = self.mu(latent_pi)
|
||||
|
||||
|
|
|
|||
|
|
@ -74,7 +74,7 @@ class Actor(BasePolicy):
|
|||
|
||||
def forward(self, obs: th.Tensor) -> th.Tensor:
|
||||
# assert deterministic, 'The TD3 actor only outputs deterministic actions'
|
||||
features = self.extract_features(obs)
|
||||
features = self.extract_features(obs, self.features_extractor)
|
||||
return self.mu(features)
|
||||
|
||||
def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
1.7.0a7
|
||||
1.7.0a8
|
||||
|
|
|
|||
|
|
@ -14,15 +14,25 @@ from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack, VecNorm
|
|||
|
||||
|
||||
@pytest.mark.parametrize("model_class", [A2C, PPO, SAC, TD3, DQN])
|
||||
def test_cnn(tmp_path, model_class):
|
||||
@pytest.mark.parametrize("share_features_extractor", [True, False])
|
||||
def test_cnn(tmp_path, model_class, share_features_extractor):
|
||||
SAVE_NAME = "cnn_model.zip"
|
||||
# Fake grayscale with frameskip
|
||||
# Atari after preprocessing: 84x84x1, here we are using lower resolution
|
||||
# to check that the network handle it automatically
|
||||
env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=1, discrete=model_class not in {SAC, TD3})
|
||||
if model_class in {A2C, PPO}:
|
||||
kwargs = dict(n_steps=64)
|
||||
kwargs = dict(
|
||||
n_steps=64,
|
||||
policy_kwargs=dict(
|
||||
share_features_extractor=share_features_extractor,
|
||||
),
|
||||
)
|
||||
else:
|
||||
# share_features_extractor is checked later for offpolicy algorithms
|
||||
if share_features_extractor:
|
||||
return
|
||||
|
||||
# Avoid memory error when using replay buffer
|
||||
# Reduce the size of the features
|
||||
kwargs = dict(
|
||||
|
|
@ -139,10 +149,10 @@ def test_features_extractor_target_net(model_class, share_features_extractor):
|
|||
assert id(model.policy.actor.features_extractor) == id(model.policy.critic.features_extractor)
|
||||
if model_class == TD3:
|
||||
assert id(model.policy.actor_target.features_extractor) == id(model.policy.critic_target.features_extractor)
|
||||
# Actor and critic feature extractor should be the same
|
||||
# Actor and critic features extractor should be the same
|
||||
td3_features_extractor_check = check_td3_feature_extractor_match
|
||||
else:
|
||||
# Actor and critic feature extractor should differ same
|
||||
# Actor and critic features extractor should differ same
|
||||
td3_features_extractor_check = check_td3_feature_extractor_differ
|
||||
# Check that the object differ
|
||||
if model_class != DQN:
|
||||
|
|
|
|||
|
|
@ -28,9 +28,15 @@ def test_custom_offpolicy(model_class, net_arch):
|
|||
_ = model_class("MlpPolicy", "Pendulum-v1", policy_kwargs=dict(net_arch=net_arch), learning_starts=100).learn(300)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", [A2C, PPO, SAC, TD3])
|
||||
@pytest.mark.parametrize("model_class", [A2C, DQN, PPO, SAC, TD3])
|
||||
@pytest.mark.parametrize("optimizer_kwargs", [None, dict(weight_decay=0.0)])
|
||||
def test_custom_optimizer(model_class, optimizer_kwargs):
|
||||
# Use different environment for DQN
|
||||
if model_class is DQN:
|
||||
env_id = "CartPole-v1"
|
||||
else:
|
||||
env_id = "Pendulum-v1"
|
||||
|
||||
kwargs = {}
|
||||
if model_class in {DQN, SAC, TD3}:
|
||||
kwargs = dict(learning_starts=100)
|
||||
|
|
@ -38,7 +44,7 @@ def test_custom_optimizer(model_class, optimizer_kwargs):
|
|||
kwargs = dict(n_steps=64)
|
||||
|
||||
policy_kwargs = dict(optimizer_class=th.optim.AdamW, optimizer_kwargs=optimizer_kwargs, net_arch=[32])
|
||||
_ = model_class("MlpPolicy", "Pendulum-v1", policy_kwargs=policy_kwargs, **kwargs).learn(300)
|
||||
_ = model_class("MlpPolicy", env_id, policy_kwargs=policy_kwargs, **kwargs).learn(300)
|
||||
|
||||
|
||||
def test_tf_like_rmsprop_optimizer():
|
||||
|
|
@ -49,3 +55,10 @@ def test_tf_like_rmsprop_optimizer():
|
|||
def test_dqn_custom_policy():
|
||||
policy_kwargs = dict(optimizer_class=RMSpropTFLike, net_arch=[32])
|
||||
_ = DQN("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs, learning_starts=100).learn(300)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", [A2C, PPO])
|
||||
def test_not_shared_features_extractor(model_class):
|
||||
policy_kwargs = dict(net_arch=[12, dict(vf=[16], pi=[8])], share_features_extractor=False)
|
||||
with pytest.raises(ValueError):
|
||||
model_class("MlpPolicy", "Pendulum-v1", policy_kwargs=policy_kwargs)
|
||||
|
|
|
|||
Loading…
Reference in a new issue