mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-29 23:07:07 +00:00
Fix stable_baselines3/common/torch_layers.py type hint (#1191)
* Remove torch layers from mypy exclude * Make torch layers mypy compliant * Extra type specification * Update changelog Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
This commit is contained in:
parent
852d635742
commit
002850f8ac
3 changed files with 7 additions and 8 deletions
|
|
@ -42,6 +42,7 @@ Others:
|
|||
- Replaced ``CartPole-v0`` by ``CartPole-v1`` is tests
|
||||
- Fixed ``tests/test_distributions.py`` type hint
|
||||
- Fixed ``stable_baselines3/common/type_aliases.py`` type hint
|
||||
- Fixed ``stable_baselines3/common/torch_layers.py`` type hint
|
||||
- Fixed ``stable_baselines3/common/env_util.py`` type hint
|
||||
- Exposed modules in ``__init__.py`` with the ``__all__`` attribute (@ZikangXiong)
|
||||
|
||||
|
|
|
|||
|
|
@ -47,7 +47,6 @@ exclude = (?x)(
|
|||
| stable_baselines3/common/preprocessing.py$
|
||||
| stable_baselines3/common/save_util.py$
|
||||
| stable_baselines3/common/sb2_compat/rmsprop_tf_like.py$
|
||||
| stable_baselines3/common/torch_layers.py$
|
||||
| stable_baselines3/common/utils.py$
|
||||
| stable_baselines3/common/vec_env/__init__.py$
|
||||
| stable_baselines3/common/vec_env/base_vec_env.py$
|
||||
|
|
|
|||
|
|
@ -28,9 +28,6 @@ class BaseFeaturesExtractor(nn.Module):
|
|||
def features_dim(self) -> int:
|
||||
return self._features_dim
|
||||
|
||||
def forward(self, observations: th.Tensor) -> th.Tensor:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class FlattenExtractor(BaseFeaturesExtractor):
|
||||
"""
|
||||
|
|
@ -173,9 +170,11 @@ class MlpExtractor(nn.Module):
|
|||
):
|
||||
super().__init__()
|
||||
device = get_device(device)
|
||||
shared_net, policy_net, value_net = [], [], []
|
||||
policy_only_layers = [] # Layer sizes of the network that only belongs to the policy network
|
||||
value_only_layers = [] # Layer sizes of the network that only belongs to the value network
|
||||
shared_net: List[nn.Module] = []
|
||||
policy_net: List[nn.Module] = []
|
||||
value_net: List[nn.Module] = []
|
||||
policy_only_layers: List[int] = [] # Layer sizes of the network that only belongs to the policy network
|
||||
value_only_layers: List[int] = [] # Layer sizes of the network that only belongs to the value network
|
||||
last_layer_dim_shared = feature_dim
|
||||
|
||||
# Iterate through the shared layers and build the shared parts of the network
|
||||
|
|
@ -254,7 +253,7 @@ class CombinedExtractor(BaseFeaturesExtractor):
|
|||
# TODO we do not know features-dim here before going over all the items, so put something there. This is dirty!
|
||||
super().__init__(observation_space, features_dim=1)
|
||||
|
||||
extractors = {}
|
||||
extractors: Dict[str, nn.Module] = {}
|
||||
|
||||
total_concat_size = 0
|
||||
for key, subspace in observation_space.spaces.items():
|
||||
|
|
|
|||
Loading…
Reference in a new issue