Rename cmd_util to env_util (#197)

* Rename cmd_util to env_util

* Fix docs and add missing newline

* Address comments
This commit is contained in:
Anssi 2020-10-22 12:05:52 +03:00 committed by GitHub
parent 856da19609
commit 19c1a89a3a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 149 additions and 135 deletions

View file

@ -1,7 +0,0 @@
.. _cmd_util:
Command Utils
=========================
.. automodule:: stable_baselines3.common.cmd_util
:members:

7
docs/common/env_util.rst Normal file
View file

@ -0,0 +1,7 @@
.. _env_util:
Environments Utils
=========================
.. automodule:: stable_baselines3.common.env_util
:members:

View file

@ -106,7 +106,7 @@ Multiprocessing: Unleashing the Power of Vectorized Environments
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import SubprocVecEnv
from stable_baselines3.common.cmd_util import make_vec_env
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.utils import set_random_seed
def make_env(env_id, rank, seed=0):
@ -267,7 +267,7 @@ and multiprocessing for you.
.. code-block:: python
from stable_baselines3.common.cmd_util import make_atari_env
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack
from stable_baselines3 import A2C

View file

@ -64,7 +64,7 @@ Utility functions are no longer exported from ``common`` module, you should impo
.. code-block:: python
from stable_baselines3.common.cmd_util import make_atari_env, make_vec_env
from stable_baselines3.common.env_util import make_atari_env, make_vec_env
from stable_baselines3.common.utils import set_random_seed
instead of ``from stable_baselines3.common import make_atari_env``

View file

@ -66,7 +66,7 @@ Main Features
:caption: Common
common/atari_wrappers
common/cmd_util
common/env_util
common/distributions
common/evaluation
common/env_checker

View file

@ -9,6 +9,7 @@ Pre-Release 0.10.0a0 (WIP)
Breaking Changes:
^^^^^^^^^^^^^^^^^
- **Warning:** Renamed ``common.cmd_util`` to ``common.env_util`` for clarity (affects ``make_vec_env`` and ``make_atari_env`` functions)
New Features:
^^^^^^^^^^^^^

View file

@ -54,7 +54,7 @@ Train a A2C agent on ``CartPole-v1`` using 4 environments.
from stable_baselines3 import A2C
from stable_baselines3.a2c import MlpPolicy
from stable_baselines3.common.cmd_util import make_vec_env
from stable_baselines3.common.env_util import make_vec_env
# Parallel environments
env = make_vec_env('CartPole-v1', n_envs=4)

View file

@ -55,7 +55,7 @@ Train a PPO agent on ``Pendulum-v0`` using 4 environments.
from stable_baselines3 import PPO
from stable_baselines3.ppo import MlpPolicy
from stable_baselines3.common.cmd_util import make_vec_env
from stable_baselines3.common.env_util import make_vec_env
# Parallel environments
env = make_vec_env('CartPole-v1', n_envs=4)

View file

@ -1,120 +1,7 @@
import os
from typing import Any, Callable, Dict, Optional, Type, Union
import gym
from stable_baselines3.common.atari_wrappers import AtariWrapper
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecEnv
def make_vec_env(
env_id: Union[str, Type[gym.Env]],
n_envs: int = 1,
seed: Optional[int] = None,
start_index: int = 0,
monitor_dir: Optional[str] = None,
wrapper_class: Optional[Callable] = None,
env_kwargs: Optional[Dict[str, Any]] = None,
vec_env_cls: Optional[Type[Union[DummyVecEnv, SubprocVecEnv]]] = None,
vec_env_kwargs: Optional[Dict[str, Any]] = None,
) -> VecEnv:
"""
Create a wrapped, monitored ``VecEnv``.
By default it uses a ``DummyVecEnv`` which is usually faster
than a ``SubprocVecEnv``.
:param env_id: the environment ID or the environment class
:param n_envs: the number of environments you wish to have in parallel
:param seed: the initial seed for the random number generator
:param start_index: start rank index
:param monitor_dir: Path to a folder where the monitor files will be saved.
If None, no file will be written, however, the env will still be wrapped
in a Monitor wrapper to provide additional information about training.
:param wrapper_class: Additional wrapper to use on the environment.
This can also be a function with single argument that wraps the environment in many things.
:param env_kwargs: Optional keyword argument to pass to the env constructor
:param vec_env_cls: A custom ``VecEnv`` class constructor. Default: None.
:param vec_env_kwargs: Keyword arguments to pass to the ``VecEnv`` class constructor.
:return: The wrapped environment
"""
env_kwargs = {} if env_kwargs is None else env_kwargs
vec_env_kwargs = {} if vec_env_kwargs is None else vec_env_kwargs
def make_env(rank):
def _init():
if isinstance(env_id, str):
env = gym.make(env_id, **env_kwargs)
else:
env = env_id(**env_kwargs)
if seed is not None:
env.seed(seed + rank)
env.action_space.seed(seed + rank)
# Wrap the env in a Monitor wrapper
# to have additional training information
monitor_path = os.path.join(monitor_dir, str(rank)) if monitor_dir is not None else None
# Create the monitor folder if needed
if monitor_path is not None:
os.makedirs(monitor_dir, exist_ok=True)
env = Monitor(env, filename=monitor_path)
# Optionally, wrap the environment with the provided wrapper
if wrapper_class is not None:
env = wrapper_class(env)
return env
return _init
# No custom VecEnv is passed
if vec_env_cls is None:
# Default: use a DummyVecEnv
vec_env_cls = DummyVecEnv
return vec_env_cls([make_env(i + start_index) for i in range(n_envs)], **vec_env_kwargs)
def make_atari_env(
env_id: Union[str, Type[gym.Env]],
n_envs: int = 1,
seed: Optional[int] = None,
start_index: int = 0,
monitor_dir: Optional[str] = None,
wrapper_kwargs: Optional[Dict[str, Any]] = None,
env_kwargs: Optional[Dict[str, Any]] = None,
vec_env_cls: Optional[Union[DummyVecEnv, SubprocVecEnv]] = None,
vec_env_kwargs: Optional[Dict[str, Any]] = None,
) -> VecEnv:
"""
Create a wrapped, monitored VecEnv for Atari.
It is a wrapper around ``make_vec_env`` that includes common preprocessing for Atari games.
:param env_id: the environment ID or the environment class
:param n_envs: the number of environments you wish to have in parallel
:param seed: the initial seed for the random number generator
:param start_index: start rank index
:param monitor_dir: Path to a folder where the monitor files will be saved.
If None, no file will be written, however, the env will still be wrapped
in a Monitor wrapper to provide additional information about training.
:param wrapper_kwargs: Optional keyword argument to pass to the ``AtariWrapper``
:param env_kwargs: Optional keyword argument to pass to the env constructor
:param vec_env_cls: A custom ``VecEnv`` class constructor. Default: None.
:param vec_env_kwargs: Keyword arguments to pass to the ``VecEnv`` class constructor.
:return: The wrapped environment
"""
if wrapper_kwargs is None:
wrapper_kwargs = {}
def atari_wrapper(env: gym.Env) -> gym.Env:
env = AtariWrapper(env, **wrapper_kwargs)
return env
return make_vec_env(
env_id,
n_envs=n_envs,
seed=seed,
start_index=start_index,
monitor_dir=monitor_dir,
wrapper_class=atari_wrapper,
env_kwargs=env_kwargs,
vec_env_cls=vec_env_cls,
vec_env_kwargs=vec_env_kwargs,
)
import warnings
from stable_baselines3.common.env_util import * # noqa: F403
warnings.warn(
"Module ``common.cmd_util`` has been renamed to ``common.env_util`` and will be removed in the future.", FutureWarning
)

View file

@ -0,0 +1,120 @@
import os
from typing import Any, Callable, Dict, Optional, Type, Union
import gym
from stable_baselines3.common.atari_wrappers import AtariWrapper
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecEnv
def make_vec_env(
env_id: Union[str, Type[gym.Env]],
n_envs: int = 1,
seed: Optional[int] = None,
start_index: int = 0,
monitor_dir: Optional[str] = None,
wrapper_class: Optional[Callable] = None,
env_kwargs: Optional[Dict[str, Any]] = None,
vec_env_cls: Optional[Type[Union[DummyVecEnv, SubprocVecEnv]]] = None,
vec_env_kwargs: Optional[Dict[str, Any]] = None,
) -> VecEnv:
"""
Create a wrapped, monitored ``VecEnv``.
By default it uses a ``DummyVecEnv`` which is usually faster
than a ``SubprocVecEnv``.
:param env_id: the environment ID or the environment class
:param n_envs: the number of environments you wish to have in parallel
:param seed: the initial seed for the random number generator
:param start_index: start rank index
:param monitor_dir: Path to a folder where the monitor files will be saved.
If None, no file will be written, however, the env will still be wrapped
in a Monitor wrapper to provide additional information about training.
:param wrapper_class: Additional wrapper to use on the environment.
This can also be a function with single argument that wraps the environment in many things.
:param env_kwargs: Optional keyword argument to pass to the env constructor
:param vec_env_cls: A custom ``VecEnv`` class constructor. Default: None.
:param vec_env_kwargs: Keyword arguments to pass to the ``VecEnv`` class constructor.
:return: The wrapped environment
"""
env_kwargs = {} if env_kwargs is None else env_kwargs
vec_env_kwargs = {} if vec_env_kwargs is None else vec_env_kwargs
def make_env(rank):
def _init():
if isinstance(env_id, str):
env = gym.make(env_id, **env_kwargs)
else:
env = env_id(**env_kwargs)
if seed is not None:
env.seed(seed + rank)
env.action_space.seed(seed + rank)
# Wrap the env in a Monitor wrapper
# to have additional training information
monitor_path = os.path.join(monitor_dir, str(rank)) if monitor_dir is not None else None
# Create the monitor folder if needed
if monitor_path is not None:
os.makedirs(monitor_dir, exist_ok=True)
env = Monitor(env, filename=monitor_path)
# Optionally, wrap the environment with the provided wrapper
if wrapper_class is not None:
env = wrapper_class(env)
return env
return _init
# No custom VecEnv is passed
if vec_env_cls is None:
# Default: use a DummyVecEnv
vec_env_cls = DummyVecEnv
return vec_env_cls([make_env(i + start_index) for i in range(n_envs)], **vec_env_kwargs)
def make_atari_env(
env_id: Union[str, Type[gym.Env]],
n_envs: int = 1,
seed: Optional[int] = None,
start_index: int = 0,
monitor_dir: Optional[str] = None,
wrapper_kwargs: Optional[Dict[str, Any]] = None,
env_kwargs: Optional[Dict[str, Any]] = None,
vec_env_cls: Optional[Union[DummyVecEnv, SubprocVecEnv]] = None,
vec_env_kwargs: Optional[Dict[str, Any]] = None,
) -> VecEnv:
"""
Create a wrapped, monitored VecEnv for Atari.
It is a wrapper around ``make_vec_env`` that includes common preprocessing for Atari games.
:param env_id: the environment ID or the environment class
:param n_envs: the number of environments you wish to have in parallel
:param seed: the initial seed for the random number generator
:param start_index: start rank index
:param monitor_dir: Path to a folder where the monitor files will be saved.
If None, no file will be written, however, the env will still be wrapped
in a Monitor wrapper to provide additional information about training.
:param wrapper_kwargs: Optional keyword argument to pass to the ``AtariWrapper``
:param env_kwargs: Optional keyword argument to pass to the env constructor
:param vec_env_cls: A custom ``VecEnv`` class constructor. Default: None.
:param vec_env_kwargs: Keyword arguments to pass to the ``VecEnv`` class constructor.
:return: The wrapped environment
"""
if wrapper_kwargs is None:
wrapper_kwargs = {}
def atari_wrapper(env: gym.Env) -> gym.Env:
env = AtariWrapper(env, **wrapper_kwargs)
return env
return make_vec_env(
env_id,
n_envs=n_envs,
seed=seed,
start_index=start_index,
monitor_dir=monitor_dir,
wrapper_class=atari_wrapper,
env_kwargs=env_kwargs,
vec_env_cls=vec_env_cls,
vec_env_kwargs=vec_env_kwargs,
)

View file

@ -13,7 +13,7 @@ from stable_baselines3.common.callbacks import (
StopTrainingOnMaxEpisodes,
StopTrainingOnRewardThreshold,
)
from stable_baselines3.common.cmd_util import make_vec_env
from stable_baselines3.common.env_util import make_vec_env
@pytest.mark.parametrize("model_class", [A2C, PPO, SAC, TD3, DQN, DDPG])

View file

@ -8,7 +8,7 @@ import torch as th
from stable_baselines3 import A2C
from stable_baselines3.common.atari_wrappers import ClipRewardEnv
from stable_baselines3.common.cmd_util import make_atari_env, make_vec_env
from stable_baselines3.common.env_util import make_atari_env, make_vec_env
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.noise import ActionNoise, OrnsteinUhlenbeckActionNoise, VectorizedActionNoise
@ -190,3 +190,9 @@ def test_zip_strict():
# same length, should not raise an error
for _, _ in zip_strict(list_a, list_b[: len(list_a)]):
pass
def test_cmd_util_rename():
"""Test that importing cmd_util still works but raises warning"""
with pytest.warns(FutureWarning):
from stable_baselines3.common.cmd_util import make_vec_env