From d672008a32941ef515107fd2b27cc0f72d90eaa0 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Mon, 23 Oct 2023 20:12:51 +0200 Subject: [PATCH] Update dependencies (remove sphinx type hint plugin), protect type aliases --- docs/conf.py | 3 ++- docs/misc/changelog.rst | 3 ++- stable_baselines3/common/callbacks.py | 7 +++++-- stable_baselines3/common/type_aliases.py | 12 +++++++----- stable_baselines3/version.txt | 2 +- 5 files changed, 17 insertions(+), 10 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 8aeae8b..bd63657 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -64,7 +64,6 @@ release = __version__ # ones. extensions = [ "sphinx.ext.autodoc", - "sphinx_autodoc_typehints", "sphinx.ext.autosummary", "sphinx.ext.mathjax", "sphinx.ext.ifconfig", @@ -73,6 +72,8 @@ extensions = [ # 'sphinx.ext.doctest' ] +autodoc_typehints = "description" + if enable_spell_check: extensions.append("sphinxcontrib.spelling") diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 616c859..ffe1d7e 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,7 +3,7 @@ Changelog ========== -Release 2.2.0a7 (WIP) +Release 2.2.0a8 (WIP) -------------------------- Breaking Changes: @@ -55,6 +55,7 @@ Others: - Fixed ``stable_baselines3/her/her_replay_buffer.py`` type hints - Buffers do no call an additional ``.copy()`` when storing new transitions - Fixed ``ActorCriticPolicy.extract_features()`` signature by adding an optional ``features_extractor`` argument +- Update dependencies (accept newer Shimmy/Sphinx version and remove ``sphinx_autodoc_typehints``) Documentation: ^^^^^^^^^^^^^^ diff --git a/stable_baselines3/common/callbacks.py b/stable_baselines3/common/callbacks.py index 5089bba..2898df8 100644 --- a/stable_baselines3/common/callbacks.py +++ b/stable_baselines3/common/callbacks.py @@ -1,7 +1,7 @@ import os import warnings from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union import gymnasium as gym import numpy as np @@ -19,10 +19,13 @@ except ImportError: # if the progress bar is used tqdm = None -from stable_baselines3.common import base_class + from stable_baselines3.common.evaluation import evaluate_policy from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, sync_envs_normalization +if TYPE_CHECKING: + from stable_baselines3.common import base_class + class BaseCallback(ABC): """ diff --git a/stable_baselines3/common/type_aliases.py b/stable_baselines3/common/type_aliases.py index 2f98ee1..4a0a878 100644 --- a/stable_baselines3/common/type_aliases.py +++ b/stable_baselines3/common/type_aliases.py @@ -1,15 +1,17 @@ """Common aliases for type hints""" - from enum import Enum -from typing import Any, Callable, Dict, List, NamedTuple, Optional, Protocol, SupportsFloat, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple, Optional, Protocol, SupportsFloat, Tuple, Union import gymnasium as gym import numpy as np import torch as th -from stable_baselines3.common import callbacks, vec_env +# Avoid circular imports, we use type hint as string to avoid it too +if TYPE_CHECKING: + from stable_baselines3.common.callbacks import BaseCallback + from stable_baselines3.common.vec_env import VecEnv -GymEnv = Union[gym.Env, vec_env.VecEnv] +GymEnv = Union[gym.Env, "VecEnv"] GymObs = Union[Tuple, Dict[str, Any], np.ndarray, int] GymResetReturn = Tuple[GymObs, Dict] AtariResetReturn = Tuple[np.ndarray, Dict[str, Any]] @@ -17,7 +19,7 @@ GymStepReturn = Tuple[GymObs, float, bool, bool, Dict] AtariStepReturn = Tuple[np.ndarray, SupportsFloat, bool, bool, Dict[str, Any]] TensorDict = Dict[str, th.Tensor] OptimizerStateDict = Dict[str, Any] -MaybeCallback = Union[None, Callable, List[callbacks.BaseCallback], callbacks.BaseCallback] +MaybeCallback = Union[None, Callable, List["BaseCallback"], "BaseCallback"] # A schedule takes the remaining progress as input # and ouputs a scalar (e.g. learning rate, clip range, ...) diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 9b9407a..f1f23b3 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -2.2.0a7 +2.2.0a8