Fix stable_baselines3/common/torch_layers.py type hint (#1191)

* Remove torch layers from mypy exclude

* Make torch layers mypy compliant

* Extra type specification

* Update changelog

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
This commit is contained in:
Quentin Gallouédec 2022-11-29 23:46:32 +01:00 committed by GitHub
parent 852d635742
commit 002850f8ac
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 7 additions and 8 deletions

View file

@ -42,6 +42,7 @@ Others:
- Replaced ``CartPole-v0`` by ``CartPole-v1`` is tests
- Fixed ``tests/test_distributions.py`` type hint
- Fixed ``stable_baselines3/common/type_aliases.py`` type hint
- Fixed ``stable_baselines3/common/torch_layers.py`` type hint
- Fixed ``stable_baselines3/common/env_util.py`` type hint
- Exposed modules in ``__init__.py`` with the ``__all__`` attribute (@ZikangXiong)

View file

@ -47,7 +47,6 @@ exclude = (?x)(
| stable_baselines3/common/preprocessing.py$
| stable_baselines3/common/save_util.py$
| stable_baselines3/common/sb2_compat/rmsprop_tf_like.py$
| stable_baselines3/common/torch_layers.py$
| stable_baselines3/common/utils.py$
| stable_baselines3/common/vec_env/__init__.py$
| stable_baselines3/common/vec_env/base_vec_env.py$

View file

@ -28,9 +28,6 @@ class BaseFeaturesExtractor(nn.Module):
def features_dim(self) -> int:
return self._features_dim
def forward(self, observations: th.Tensor) -> th.Tensor:
raise NotImplementedError()
class FlattenExtractor(BaseFeaturesExtractor):
"""
@ -173,9 +170,11 @@ class MlpExtractor(nn.Module):
):
super().__init__()
device = get_device(device)
shared_net, policy_net, value_net = [], [], []
policy_only_layers = [] # Layer sizes of the network that only belongs to the policy network
value_only_layers = [] # Layer sizes of the network that only belongs to the value network
shared_net: List[nn.Module] = []
policy_net: List[nn.Module] = []
value_net: List[nn.Module] = []
policy_only_layers: List[int] = [] # Layer sizes of the network that only belongs to the policy network
value_only_layers: List[int] = [] # Layer sizes of the network that only belongs to the value network
last_layer_dim_shared = feature_dim
# Iterate through the shared layers and build the shared parts of the network
@ -254,7 +253,7 @@ class CombinedExtractor(BaseFeaturesExtractor):
# TODO we do not know features-dim here before going over all the items, so put something there. This is dirty!
super().__init__(observation_space, features_dim=1)
extractors = {}
extractors: Dict[str, nn.Module] = {}
total_concat_size = 0
for key, subspace in observation_space.spaces.items():