From 19c1a89a3add1bef4eb45ce79625fc3b3e41eee5 Mon Sep 17 00:00:00 2001 From: Anssi Date: Thu, 22 Oct 2020 12:05:52 +0300 Subject: [PATCH] Rename cmd_util to env_util (#197) * Rename cmd_util to env_util * Fix docs and add missing newline * Address comments --- docs/common/cmd_util.rst | 7 -- docs/common/env_util.rst | 7 ++ docs/guide/examples.rst | 4 +- docs/guide/migration.rst | 2 +- docs/index.rst | 2 +- docs/misc/changelog.rst | 1 + docs/modules/a2c.rst | 2 +- docs/modules/ppo.rst | 2 +- stable_baselines3/common/cmd_util.py | 127 ++------------------------- stable_baselines3/common/env_util.py | 120 +++++++++++++++++++++++++ tests/test_callbacks.py | 2 +- tests/test_utils.py | 8 +- 12 files changed, 149 insertions(+), 135 deletions(-) delete mode 100644 docs/common/cmd_util.rst create mode 100644 docs/common/env_util.rst create mode 100644 stable_baselines3/common/env_util.py diff --git a/docs/common/cmd_util.rst b/docs/common/cmd_util.rst deleted file mode 100644 index 0d7945e..0000000 --- a/docs/common/cmd_util.rst +++ /dev/null @@ -1,7 +0,0 @@ -.. _cmd_util: - -Command Utils -========================= - -.. automodule:: stable_baselines3.common.cmd_util - :members: diff --git a/docs/common/env_util.rst b/docs/common/env_util.rst new file mode 100644 index 0000000..742df9c --- /dev/null +++ b/docs/common/env_util.rst @@ -0,0 +1,7 @@ +.. _env_util: + +Environments Utils +========================= + +.. automodule:: stable_baselines3.common.env_util + :members: diff --git a/docs/guide/examples.rst b/docs/guide/examples.rst index a6b8040..6c44e38 100644 --- a/docs/guide/examples.rst +++ b/docs/guide/examples.rst @@ -106,7 +106,7 @@ Multiprocessing: Unleashing the Power of Vectorized Environments from stable_baselines3 import PPO from stable_baselines3.common.vec_env import SubprocVecEnv - from stable_baselines3.common.cmd_util import make_vec_env + from stable_baselines3.common.env_util import make_vec_env from stable_baselines3.common.utils import set_random_seed def make_env(env_id, rank, seed=0): @@ -267,7 +267,7 @@ and multiprocessing for you. .. code-block:: python - from stable_baselines3.common.cmd_util import make_atari_env + from stable_baselines3.common.env_util import make_atari_env from stable_baselines3.common.vec_env import VecFrameStack from stable_baselines3 import A2C diff --git a/docs/guide/migration.rst b/docs/guide/migration.rst index 82fbc69..53bc816 100644 --- a/docs/guide/migration.rst +++ b/docs/guide/migration.rst @@ -64,7 +64,7 @@ Utility functions are no longer exported from ``common`` module, you should impo .. code-block:: python - from stable_baselines3.common.cmd_util import make_atari_env, make_vec_env + from stable_baselines3.common.env_util import make_atari_env, make_vec_env from stable_baselines3.common.utils import set_random_seed instead of ``from stable_baselines3.common import make_atari_env`` diff --git a/docs/index.rst b/docs/index.rst index 939655a..76e4816 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -66,7 +66,7 @@ Main Features :caption: Common common/atari_wrappers - common/cmd_util + common/env_util common/distributions common/evaluation common/env_checker diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index e86c8e2..f635712 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -9,6 +9,7 @@ Pre-Release 0.10.0a0 (WIP) Breaking Changes: ^^^^^^^^^^^^^^^^^ +- **Warning:** Renamed ``common.cmd_util`` to ``common.env_util`` for clarity (affects ``make_vec_env`` and ``make_atari_env`` functions) New Features: ^^^^^^^^^^^^^ diff --git a/docs/modules/a2c.rst b/docs/modules/a2c.rst index 460d1a6..bf223cb 100644 --- a/docs/modules/a2c.rst +++ b/docs/modules/a2c.rst @@ -54,7 +54,7 @@ Train a A2C agent on ``CartPole-v1`` using 4 environments. from stable_baselines3 import A2C from stable_baselines3.a2c import MlpPolicy - from stable_baselines3.common.cmd_util import make_vec_env + from stable_baselines3.common.env_util import make_vec_env # Parallel environments env = make_vec_env('CartPole-v1', n_envs=4) diff --git a/docs/modules/ppo.rst b/docs/modules/ppo.rst index 038149d..d54b970 100644 --- a/docs/modules/ppo.rst +++ b/docs/modules/ppo.rst @@ -55,7 +55,7 @@ Train a PPO agent on ``Pendulum-v0`` using 4 environments. from stable_baselines3 import PPO from stable_baselines3.ppo import MlpPolicy - from stable_baselines3.common.cmd_util import make_vec_env + from stable_baselines3.common.env_util import make_vec_env # Parallel environments env = make_vec_env('CartPole-v1', n_envs=4) diff --git a/stable_baselines3/common/cmd_util.py b/stable_baselines3/common/cmd_util.py index 8a9985e..e5a30e5 100644 --- a/stable_baselines3/common/cmd_util.py +++ b/stable_baselines3/common/cmd_util.py @@ -1,120 +1,7 @@ -import os -from typing import Any, Callable, Dict, Optional, Type, Union - -import gym - -from stable_baselines3.common.atari_wrappers import AtariWrapper -from stable_baselines3.common.monitor import Monitor -from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecEnv - - -def make_vec_env( - env_id: Union[str, Type[gym.Env]], - n_envs: int = 1, - seed: Optional[int] = None, - start_index: int = 0, - monitor_dir: Optional[str] = None, - wrapper_class: Optional[Callable] = None, - env_kwargs: Optional[Dict[str, Any]] = None, - vec_env_cls: Optional[Type[Union[DummyVecEnv, SubprocVecEnv]]] = None, - vec_env_kwargs: Optional[Dict[str, Any]] = None, -) -> VecEnv: - """ - Create a wrapped, monitored ``VecEnv``. - By default it uses a ``DummyVecEnv`` which is usually faster - than a ``SubprocVecEnv``. - - :param env_id: the environment ID or the environment class - :param n_envs: the number of environments you wish to have in parallel - :param seed: the initial seed for the random number generator - :param start_index: start rank index - :param monitor_dir: Path to a folder where the monitor files will be saved. - If None, no file will be written, however, the env will still be wrapped - in a Monitor wrapper to provide additional information about training. - :param wrapper_class: Additional wrapper to use on the environment. - This can also be a function with single argument that wraps the environment in many things. - :param env_kwargs: Optional keyword argument to pass to the env constructor - :param vec_env_cls: A custom ``VecEnv`` class constructor. Default: None. - :param vec_env_kwargs: Keyword arguments to pass to the ``VecEnv`` class constructor. - :return: The wrapped environment - """ - env_kwargs = {} if env_kwargs is None else env_kwargs - vec_env_kwargs = {} if vec_env_kwargs is None else vec_env_kwargs - - def make_env(rank): - def _init(): - if isinstance(env_id, str): - env = gym.make(env_id, **env_kwargs) - else: - env = env_id(**env_kwargs) - if seed is not None: - env.seed(seed + rank) - env.action_space.seed(seed + rank) - # Wrap the env in a Monitor wrapper - # to have additional training information - monitor_path = os.path.join(monitor_dir, str(rank)) if monitor_dir is not None else None - # Create the monitor folder if needed - if monitor_path is not None: - os.makedirs(monitor_dir, exist_ok=True) - env = Monitor(env, filename=monitor_path) - # Optionally, wrap the environment with the provided wrapper - if wrapper_class is not None: - env = wrapper_class(env) - return env - - return _init - - # No custom VecEnv is passed - if vec_env_cls is None: - # Default: use a DummyVecEnv - vec_env_cls = DummyVecEnv - - return vec_env_cls([make_env(i + start_index) for i in range(n_envs)], **vec_env_kwargs) - - -def make_atari_env( - env_id: Union[str, Type[gym.Env]], - n_envs: int = 1, - seed: Optional[int] = None, - start_index: int = 0, - monitor_dir: Optional[str] = None, - wrapper_kwargs: Optional[Dict[str, Any]] = None, - env_kwargs: Optional[Dict[str, Any]] = None, - vec_env_cls: Optional[Union[DummyVecEnv, SubprocVecEnv]] = None, - vec_env_kwargs: Optional[Dict[str, Any]] = None, -) -> VecEnv: - """ - Create a wrapped, monitored VecEnv for Atari. - It is a wrapper around ``make_vec_env`` that includes common preprocessing for Atari games. - - :param env_id: the environment ID or the environment class - :param n_envs: the number of environments you wish to have in parallel - :param seed: the initial seed for the random number generator - :param start_index: start rank index - :param monitor_dir: Path to a folder where the monitor files will be saved. - If None, no file will be written, however, the env will still be wrapped - in a Monitor wrapper to provide additional information about training. - :param wrapper_kwargs: Optional keyword argument to pass to the ``AtariWrapper`` - :param env_kwargs: Optional keyword argument to pass to the env constructor - :param vec_env_cls: A custom ``VecEnv`` class constructor. Default: None. - :param vec_env_kwargs: Keyword arguments to pass to the ``VecEnv`` class constructor. - :return: The wrapped environment - """ - if wrapper_kwargs is None: - wrapper_kwargs = {} - - def atari_wrapper(env: gym.Env) -> gym.Env: - env = AtariWrapper(env, **wrapper_kwargs) - return env - - return make_vec_env( - env_id, - n_envs=n_envs, - seed=seed, - start_index=start_index, - monitor_dir=monitor_dir, - wrapper_class=atari_wrapper, - env_kwargs=env_kwargs, - vec_env_cls=vec_env_cls, - vec_env_kwargs=vec_env_kwargs, - ) +import warnings + +from stable_baselines3.common.env_util import * # noqa: F403 + +warnings.warn( + "Module ``common.cmd_util`` has been renamed to ``common.env_util`` and will be removed in the future.", FutureWarning +) diff --git a/stable_baselines3/common/env_util.py b/stable_baselines3/common/env_util.py new file mode 100644 index 0000000..8a9985e --- /dev/null +++ b/stable_baselines3/common/env_util.py @@ -0,0 +1,120 @@ +import os +from typing import Any, Callable, Dict, Optional, Type, Union + +import gym + +from stable_baselines3.common.atari_wrappers import AtariWrapper +from stable_baselines3.common.monitor import Monitor +from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecEnv + + +def make_vec_env( + env_id: Union[str, Type[gym.Env]], + n_envs: int = 1, + seed: Optional[int] = None, + start_index: int = 0, + monitor_dir: Optional[str] = None, + wrapper_class: Optional[Callable] = None, + env_kwargs: Optional[Dict[str, Any]] = None, + vec_env_cls: Optional[Type[Union[DummyVecEnv, SubprocVecEnv]]] = None, + vec_env_kwargs: Optional[Dict[str, Any]] = None, +) -> VecEnv: + """ + Create a wrapped, monitored ``VecEnv``. + By default it uses a ``DummyVecEnv`` which is usually faster + than a ``SubprocVecEnv``. + + :param env_id: the environment ID or the environment class + :param n_envs: the number of environments you wish to have in parallel + :param seed: the initial seed for the random number generator + :param start_index: start rank index + :param monitor_dir: Path to a folder where the monitor files will be saved. + If None, no file will be written, however, the env will still be wrapped + in a Monitor wrapper to provide additional information about training. + :param wrapper_class: Additional wrapper to use on the environment. + This can also be a function with single argument that wraps the environment in many things. + :param env_kwargs: Optional keyword argument to pass to the env constructor + :param vec_env_cls: A custom ``VecEnv`` class constructor. Default: None. + :param vec_env_kwargs: Keyword arguments to pass to the ``VecEnv`` class constructor. + :return: The wrapped environment + """ + env_kwargs = {} if env_kwargs is None else env_kwargs + vec_env_kwargs = {} if vec_env_kwargs is None else vec_env_kwargs + + def make_env(rank): + def _init(): + if isinstance(env_id, str): + env = gym.make(env_id, **env_kwargs) + else: + env = env_id(**env_kwargs) + if seed is not None: + env.seed(seed + rank) + env.action_space.seed(seed + rank) + # Wrap the env in a Monitor wrapper + # to have additional training information + monitor_path = os.path.join(monitor_dir, str(rank)) if monitor_dir is not None else None + # Create the monitor folder if needed + if monitor_path is not None: + os.makedirs(monitor_dir, exist_ok=True) + env = Monitor(env, filename=monitor_path) + # Optionally, wrap the environment with the provided wrapper + if wrapper_class is not None: + env = wrapper_class(env) + return env + + return _init + + # No custom VecEnv is passed + if vec_env_cls is None: + # Default: use a DummyVecEnv + vec_env_cls = DummyVecEnv + + return vec_env_cls([make_env(i + start_index) for i in range(n_envs)], **vec_env_kwargs) + + +def make_atari_env( + env_id: Union[str, Type[gym.Env]], + n_envs: int = 1, + seed: Optional[int] = None, + start_index: int = 0, + monitor_dir: Optional[str] = None, + wrapper_kwargs: Optional[Dict[str, Any]] = None, + env_kwargs: Optional[Dict[str, Any]] = None, + vec_env_cls: Optional[Union[DummyVecEnv, SubprocVecEnv]] = None, + vec_env_kwargs: Optional[Dict[str, Any]] = None, +) -> VecEnv: + """ + Create a wrapped, monitored VecEnv for Atari. + It is a wrapper around ``make_vec_env`` that includes common preprocessing for Atari games. + + :param env_id: the environment ID or the environment class + :param n_envs: the number of environments you wish to have in parallel + :param seed: the initial seed for the random number generator + :param start_index: start rank index + :param monitor_dir: Path to a folder where the monitor files will be saved. + If None, no file will be written, however, the env will still be wrapped + in a Monitor wrapper to provide additional information about training. + :param wrapper_kwargs: Optional keyword argument to pass to the ``AtariWrapper`` + :param env_kwargs: Optional keyword argument to pass to the env constructor + :param vec_env_cls: A custom ``VecEnv`` class constructor. Default: None. + :param vec_env_kwargs: Keyword arguments to pass to the ``VecEnv`` class constructor. + :return: The wrapped environment + """ + if wrapper_kwargs is None: + wrapper_kwargs = {} + + def atari_wrapper(env: gym.Env) -> gym.Env: + env = AtariWrapper(env, **wrapper_kwargs) + return env + + return make_vec_env( + env_id, + n_envs=n_envs, + seed=seed, + start_index=start_index, + monitor_dir=monitor_dir, + wrapper_class=atari_wrapper, + env_kwargs=env_kwargs, + vec_env_cls=vec_env_cls, + vec_env_kwargs=vec_env_kwargs, + ) diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py index 110983b..2f5259b 100644 --- a/tests/test_callbacks.py +++ b/tests/test_callbacks.py @@ -13,7 +13,7 @@ from stable_baselines3.common.callbacks import ( StopTrainingOnMaxEpisodes, StopTrainingOnRewardThreshold, ) -from stable_baselines3.common.cmd_util import make_vec_env +from stable_baselines3.common.env_util import make_vec_env @pytest.mark.parametrize("model_class", [A2C, PPO, SAC, TD3, DQN, DDPG]) diff --git a/tests/test_utils.py b/tests/test_utils.py index 4f8b400..b3555d8 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -8,7 +8,7 @@ import torch as th from stable_baselines3 import A2C from stable_baselines3.common.atari_wrappers import ClipRewardEnv -from stable_baselines3.common.cmd_util import make_atari_env, make_vec_env +from stable_baselines3.common.env_util import make_atari_env, make_vec_env from stable_baselines3.common.evaluation import evaluate_policy from stable_baselines3.common.monitor import Monitor from stable_baselines3.common.noise import ActionNoise, OrnsteinUhlenbeckActionNoise, VectorizedActionNoise @@ -190,3 +190,9 @@ def test_zip_strict(): # same length, should not raise an error for _, _ in zip_strict(list_a, list_b[: len(list_a)]): pass + + +def test_cmd_util_rename(): + """Test that importing cmd_util still works but raises warning""" + with pytest.warns(FutureWarning): + from stable_baselines3.common.cmd_util import make_vec_env