Fix unwrap_vec_normalize type hint

This commit is contained in:
Quentin Gallouédec 2022-11-29 15:16:02 +01:00
parent e071d1a382
commit 61d4e2b759

View file

@ -1,7 +1,7 @@
# flake8: noqa F401
import typing
from copy import deepcopy
from typing import Optional, Type, Union
from typing import Optional, Type, TypeVar, Union
from stable_baselines3.common.vec_env.base_vec_env import CloudpickleWrapper, VecEnv, VecEnvWrapper
from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
@ -19,8 +19,10 @@ from stable_baselines3.common.vec_env.vec_video_recorder import VecVideoRecorder
if typing.TYPE_CHECKING:
from stable_baselines3.common.type_aliases import GymEnv
VecEnvWrapperT = TypeVar("VecEnvWrapperT", bound=VecEnvWrapper)
def unwrap_vec_wrapper(env: Union["GymEnv", VecEnv], vec_wrapper_class: Type[VecEnvWrapper]) -> Optional[VecEnvWrapper]:
def unwrap_vec_wrapper(env: Union["GymEnv", VecEnv], vec_wrapper_class: Type[VecEnvWrapperT]) -> Optional[VecEnvWrapperT]:
"""
Retrieve a ``VecEnvWrapper`` object by recursively searching.
@ -41,7 +43,7 @@ def unwrap_vec_normalize(env: Union["GymEnv", VecEnv]) -> Optional[VecNormalize]
:param env:
:return:
"""
return unwrap_vec_wrapper(env, VecNormalize) # pytype:disable=bad-return-type
return unwrap_vec_wrapper(env, VecNormalize)
def is_vecenv_wrapped(env: Union["GymEnv", VecEnv], vec_wrapper_class: Type[VecEnvWrapper]) -> bool: