diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 01fc719..67b5591 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -42,6 +42,7 @@ Others: - Replaced ``CartPole-v0`` by ``CartPole-v1`` is tests - Fixed ``tests/test_distributions.py`` type hint - Fixed ``stable_baselines3/common/type_aliases.py`` type hint +- Fixed ``stable_baselines3/common/torch_layers.py`` type hint - Fixed ``stable_baselines3/common/env_util.py`` type hint - Exposed modules in ``__init__.py`` with the ``__all__`` attribute (@ZikangXiong) diff --git a/setup.cfg b/setup.cfg index 733b5c3..045638c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -47,7 +47,6 @@ exclude = (?x)( | stable_baselines3/common/preprocessing.py$ | stable_baselines3/common/save_util.py$ | stable_baselines3/common/sb2_compat/rmsprop_tf_like.py$ - | stable_baselines3/common/torch_layers.py$ | stable_baselines3/common/utils.py$ | stable_baselines3/common/vec_env/__init__.py$ | stable_baselines3/common/vec_env/base_vec_env.py$ diff --git a/stable_baselines3/common/torch_layers.py b/stable_baselines3/common/torch_layers.py index 2ce0cc1..5de2af1 100644 --- a/stable_baselines3/common/torch_layers.py +++ b/stable_baselines3/common/torch_layers.py @@ -28,9 +28,6 @@ class BaseFeaturesExtractor(nn.Module): def features_dim(self) -> int: return self._features_dim - def forward(self, observations: th.Tensor) -> th.Tensor: - raise NotImplementedError() - class FlattenExtractor(BaseFeaturesExtractor): """ @@ -173,9 +170,11 @@ class MlpExtractor(nn.Module): ): super().__init__() device = get_device(device) - shared_net, policy_net, value_net = [], [], [] - policy_only_layers = [] # Layer sizes of the network that only belongs to the policy network - value_only_layers = [] # Layer sizes of the network that only belongs to the value network + shared_net: List[nn.Module] = [] + policy_net: List[nn.Module] = [] + value_net: List[nn.Module] = [] + policy_only_layers: List[int] = [] # Layer sizes of the network that only belongs to the policy network + value_only_layers: List[int] = [] # Layer sizes of the network that only belongs to the value network last_layer_dim_shared = feature_dim # Iterate through the shared layers and build the shared parts of the network @@ -254,7 +253,7 @@ class CombinedExtractor(BaseFeaturesExtractor): # TODO we do not know features-dim here before going over all the items, so put something there. This is dirty! super().__init__(observation_space, features_dim=1) - extractors = {} + extractors: Dict[str, nn.Module] = {} total_concat_size = 0 for key, subspace in observation_space.spaces.items():