diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index c3c6629..c2449f0 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -4,7 +4,7 @@ Changelog ========== -Release 1.8.0a5 (WIP) +Release 1.8.0a6 (WIP) -------------------------- @@ -12,6 +12,7 @@ Breaking Changes: ^^^^^^^^^^^^^^^^^ - Removed shared layers in ``mlp_extractor`` (@AlexPasqua) - Refactored ``StackedObservations`` (it now handles dict obs, ``StackedDictObservations`` was removed) +- You must now explicitely pass a ``features_extractor`` parameter when calling ``extract_features()`` New Features: ^^^^^^^^^^^^^ diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py index 457274a..795ed47 100644 --- a/stable_baselines3/common/policies.py +++ b/stable_baselines3/common/policies.py @@ -118,26 +118,14 @@ class BaseModel(nn.Module): """Helper method to create a features extractor.""" return self.features_extractor_class(self.observation_space, **self.features_extractor_kwargs) - def extract_features(self, obs: th.Tensor, features_extractor: Optional[BaseFeaturesExtractor] = None) -> th.Tensor: + def extract_features(self, obs: th.Tensor, features_extractor: BaseFeaturesExtractor) -> th.Tensor: """ Preprocess the observation if needed and extract features. :param obs: The observation - :param features_extractor: The features extractor to use. If it is set to None, - the features extractor of the policy is used. - :return: The features + :param features_extractor: The features extractor to use. + :return: The extracted features """ - if features_extractor is None: - warnings.warn( - ( - "When calling extract_features(), you should explicitely pass a features_extractor as parameter. " - "This will be mandatory in Stable-Baselines v1.8.0" - ), - DeprecationWarning, - ) - - features_extractor = features_extractor or self.features_extractor - assert features_extractor is not None, "No features extractor was set" preprocessed_obs = preprocess_obs(obs, self.observation_space, normalize_images=self.normalize_images) return features_extractor(preprocessed_obs) diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 3d6a0b8..838407c 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -1.8.0a5 +1.8.0a6