diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index fbed2d5..f113ea5 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -42,6 +42,7 @@ Others: - Replaced ``CartPole-v0`` by ``CartPole-v1`` is tests - Fixed ``tests/test_distributions.py`` type hint - Fixed ``stable_baselines3/common/type_aliases.py`` type hint +- Fixed ``stable_baselines3/common/env_util.py`` type hint - Fixed ``stable_baselines3/common/vec_env/__init__.py`` type hint Documentation: diff --git a/setup.cfg b/setup.cfg index 1714887..36c451a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -36,7 +36,6 @@ exclude = (?x)( | stable_baselines3/common/buffers.py$ | stable_baselines3/common/callbacks.py$ | stable_baselines3/common/distributions.py$ - | stable_baselines3/common/env_util.py$ | stable_baselines3/common/envs/bit_flipping_env.py$ | stable_baselines3/common/envs/identity_env.py$ | stable_baselines3/common/envs/multi_input_envs.py$ diff --git a/stable_baselines3/common/env_util.py b/stable_baselines3/common/env_util.py index eb893c6..c85d147 100644 --- a/stable_baselines3/common/env_util.py +++ b/stable_baselines3/common/env_util.py @@ -116,7 +116,7 @@ def make_atari_env( monitor_dir: Optional[str] = None, wrapper_kwargs: Optional[Dict[str, Any]] = None, env_kwargs: Optional[Dict[str, Any]] = None, - vec_env_cls: Optional[Union[DummyVecEnv, SubprocVecEnv]] = None, + vec_env_cls: Optional[Union[Type[DummyVecEnv], Type[SubprocVecEnv]]] = None, vec_env_kwargs: Optional[Dict[str, Any]] = None, monitor_kwargs: Optional[Dict[str, Any]] = None, ) -> VecEnv: @@ -138,22 +138,16 @@ def make_atari_env( :param monitor_kwargs: Keyword arguments to pass to the ``Monitor`` class constructor. :return: The wrapped environment """ - if wrapper_kwargs is None: - wrapper_kwargs = {} - - def atari_wrapper(env: gym.Env) -> gym.Env: - env = AtariWrapper(env, **wrapper_kwargs) - return env - return make_vec_env( env_id, n_envs=n_envs, seed=seed, start_index=start_index, monitor_dir=monitor_dir, - wrapper_class=atari_wrapper, + wrapper_class=AtariWrapper, env_kwargs=env_kwargs, vec_env_cls=vec_env_cls, vec_env_kwargs=vec_env_kwargs, monitor_kwargs=monitor_kwargs, + wrapper_kwargs=wrapper_kwargs, )