mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-06-03 23:49:57 +00:00
Fix unwrap_vec_normalize type hint
This commit is contained in:
parent
e071d1a382
commit
61d4e2b759
1 changed files with 5 additions and 3 deletions
|
|
@ -1,7 +1,7 @@
|
|||
# flake8: noqa F401
|
||||
import typing
|
||||
from copy import deepcopy
|
||||
from typing import Optional, Type, Union
|
||||
from typing import Optional, Type, TypeVar, Union
|
||||
|
||||
from stable_baselines3.common.vec_env.base_vec_env import CloudpickleWrapper, VecEnv, VecEnvWrapper
|
||||
from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
|
||||
|
|
@ -19,8 +19,10 @@ from stable_baselines3.common.vec_env.vec_video_recorder import VecVideoRecorder
|
|||
if typing.TYPE_CHECKING:
|
||||
from stable_baselines3.common.type_aliases import GymEnv
|
||||
|
||||
VecEnvWrapperT = TypeVar("VecEnvWrapperT", bound=VecEnvWrapper)
|
||||
|
||||
def unwrap_vec_wrapper(env: Union["GymEnv", VecEnv], vec_wrapper_class: Type[VecEnvWrapper]) -> Optional[VecEnvWrapper]:
|
||||
|
||||
def unwrap_vec_wrapper(env: Union["GymEnv", VecEnv], vec_wrapper_class: Type[VecEnvWrapperT]) -> Optional[VecEnvWrapperT]:
|
||||
"""
|
||||
Retrieve a ``VecEnvWrapper`` object by recursively searching.
|
||||
|
||||
|
|
@ -41,7 +43,7 @@ def unwrap_vec_normalize(env: Union["GymEnv", VecEnv]) -> Optional[VecNormalize]
|
|||
:param env:
|
||||
:return:
|
||||
"""
|
||||
return unwrap_vec_wrapper(env, VecNormalize) # pytype:disable=bad-return-type
|
||||
return unwrap_vec_wrapper(env, VecNormalize)
|
||||
|
||||
|
||||
def is_vecenv_wrapped(env: Union["GymEnv", VecEnv], vec_wrapper_class: Type[VecEnvWrapper]) -> bool:
|
||||
|
|
|
|||
Loading…
Reference in a new issue