From 61d4e2b759a536e48639e8967707662151c84fb8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Tue, 29 Nov 2022 15:16:02 +0100 Subject: [PATCH] Fix unwrap_vec_normalize type hint --- stable_baselines3/common/vec_env/__init__.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/stable_baselines3/common/vec_env/__init__.py b/stable_baselines3/common/vec_env/__init__.py index 3880fbd..be4326e 100644 --- a/stable_baselines3/common/vec_env/__init__.py +++ b/stable_baselines3/common/vec_env/__init__.py @@ -1,7 +1,7 @@ # flake8: noqa F401 import typing from copy import deepcopy -from typing import Optional, Type, Union +from typing import Optional, Type, TypeVar, Union from stable_baselines3.common.vec_env.base_vec_env import CloudpickleWrapper, VecEnv, VecEnvWrapper from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv @@ -19,8 +19,10 @@ from stable_baselines3.common.vec_env.vec_video_recorder import VecVideoRecorder if typing.TYPE_CHECKING: from stable_baselines3.common.type_aliases import GymEnv +VecEnvWrapperT = TypeVar("VecEnvWrapperT", bound=VecEnvWrapper) -def unwrap_vec_wrapper(env: Union["GymEnv", VecEnv], vec_wrapper_class: Type[VecEnvWrapper]) -> Optional[VecEnvWrapper]: + +def unwrap_vec_wrapper(env: Union["GymEnv", VecEnv], vec_wrapper_class: Type[VecEnvWrapperT]) -> Optional[VecEnvWrapperT]: """ Retrieve a ``VecEnvWrapper`` object by recursively searching. @@ -41,7 +43,7 @@ def unwrap_vec_normalize(env: Union["GymEnv", VecEnv]) -> Optional[VecNormalize] :param env: :return: """ - return unwrap_vec_wrapper(env, VecNormalize) # pytype:disable=bad-return-type + return unwrap_vec_wrapper(env, VecNormalize) def is_vecenv_wrapped(env: Union["GymEnv", VecEnv], vec_wrapper_class: Type[VecEnvWrapper]) -> bool: