Remove deprecated usage of feature extractor (#1296)

* Remove deprecated usage of feature extractor

* Update changelog and version

* Update changelog.rst
This commit is contained in:
Antonin RAFFIN 2023-02-19 12:53:10 +01:00 committed by GitHub
parent 12e9917c24
commit 085bdd5a68
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 6 additions and 17 deletions

View file

@ -4,7 +4,7 @@ Changelog
==========
Release 1.8.0a5 (WIP)
Release 1.8.0a6 (WIP)
--------------------------
@ -12,6 +12,7 @@ Breaking Changes:
^^^^^^^^^^^^^^^^^
- Removed shared layers in ``mlp_extractor`` (@AlexPasqua)
- Refactored ``StackedObservations`` (it now handles dict obs, ``StackedDictObservations`` was removed)
- You must now explicitely pass a ``features_extractor`` parameter when calling ``extract_features()``
New Features:
^^^^^^^^^^^^^

View file

@ -118,26 +118,14 @@ class BaseModel(nn.Module):
"""Helper method to create a features extractor."""
return self.features_extractor_class(self.observation_space, **self.features_extractor_kwargs)
def extract_features(self, obs: th.Tensor, features_extractor: Optional[BaseFeaturesExtractor] = None) -> th.Tensor:
def extract_features(self, obs: th.Tensor, features_extractor: BaseFeaturesExtractor) -> th.Tensor:
"""
Preprocess the observation if needed and extract features.
:param obs: The observation
:param features_extractor: The features extractor to use. If it is set to None,
the features extractor of the policy is used.
:return: The features
:param features_extractor: The features extractor to use.
:return: The extracted features
"""
if features_extractor is None:
warnings.warn(
(
"When calling extract_features(), you should explicitely pass a features_extractor as parameter. "
"This will be mandatory in Stable-Baselines v1.8.0"
),
DeprecationWarning,
)
features_extractor = features_extractor or self.features_extractor
assert features_extractor is not None, "No features extractor was set"
preprocessed_obs = preprocess_obs(obs, self.observation_space, normalize_images=self.normalize_images)
return features_extractor(preprocessed_obs)

View file

@ -1 +1 @@
1.8.0a5
1.8.0a6