mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-06-26 03:01:19 +00:00
Remove deprecated usage of feature extractor (#1296)
* Remove deprecated usage of feature extractor * Update changelog and version * Update changelog.rst
This commit is contained in:
parent
12e9917c24
commit
085bdd5a68
3 changed files with 6 additions and 17 deletions
|
|
@ -4,7 +4,7 @@ Changelog
|
|||
==========
|
||||
|
||||
|
||||
Release 1.8.0a5 (WIP)
|
||||
Release 1.8.0a6 (WIP)
|
||||
--------------------------
|
||||
|
||||
|
||||
|
|
@ -12,6 +12,7 @@ Breaking Changes:
|
|||
^^^^^^^^^^^^^^^^^
|
||||
- Removed shared layers in ``mlp_extractor`` (@AlexPasqua)
|
||||
- Refactored ``StackedObservations`` (it now handles dict obs, ``StackedDictObservations`` was removed)
|
||||
- You must now explicitely pass a ``features_extractor`` parameter when calling ``extract_features()``
|
||||
|
||||
New Features:
|
||||
^^^^^^^^^^^^^
|
||||
|
|
|
|||
|
|
@ -118,26 +118,14 @@ class BaseModel(nn.Module):
|
|||
"""Helper method to create a features extractor."""
|
||||
return self.features_extractor_class(self.observation_space, **self.features_extractor_kwargs)
|
||||
|
||||
def extract_features(self, obs: th.Tensor, features_extractor: Optional[BaseFeaturesExtractor] = None) -> th.Tensor:
|
||||
def extract_features(self, obs: th.Tensor, features_extractor: BaseFeaturesExtractor) -> th.Tensor:
|
||||
"""
|
||||
Preprocess the observation if needed and extract features.
|
||||
|
||||
:param obs: The observation
|
||||
:param features_extractor: The features extractor to use. If it is set to None,
|
||||
the features extractor of the policy is used.
|
||||
:return: The features
|
||||
:param features_extractor: The features extractor to use.
|
||||
:return: The extracted features
|
||||
"""
|
||||
if features_extractor is None:
|
||||
warnings.warn(
|
||||
(
|
||||
"When calling extract_features(), you should explicitely pass a features_extractor as parameter. "
|
||||
"This will be mandatory in Stable-Baselines v1.8.0"
|
||||
),
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
features_extractor = features_extractor or self.features_extractor
|
||||
assert features_extractor is not None, "No features extractor was set"
|
||||
preprocessed_obs = preprocess_obs(obs, self.observation_space, normalize_images=self.normalize_images)
|
||||
return features_extractor(preprocessed_obs)
|
||||
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
1.8.0a5
|
||||
1.8.0a6
|
||||
|
|
|
|||
Loading…
Reference in a new issue