mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-31 23:28:05 +00:00
Rename cmd_util to env_util (#197)
* Rename cmd_util to env_util * Fix docs and add missing newline * Address comments
This commit is contained in:
parent
856da19609
commit
19c1a89a3a
12 changed files with 149 additions and 135 deletions
|
|
@ -1,7 +0,0 @@
|
|||
.. _cmd_util:
|
||||
|
||||
Command Utils
|
||||
=========================
|
||||
|
||||
.. automodule:: stable_baselines3.common.cmd_util
|
||||
:members:
|
||||
7
docs/common/env_util.rst
Normal file
7
docs/common/env_util.rst
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
.. _env_util:
|
||||
|
||||
Environments Utils
|
||||
=========================
|
||||
|
||||
.. automodule:: stable_baselines3.common.env_util
|
||||
:members:
|
||||
|
|
@ -106,7 +106,7 @@ Multiprocessing: Unleashing the Power of Vectorized Environments
|
|||
|
||||
from stable_baselines3 import PPO
|
||||
from stable_baselines3.common.vec_env import SubprocVecEnv
|
||||
from stable_baselines3.common.cmd_util import make_vec_env
|
||||
from stable_baselines3.common.env_util import make_vec_env
|
||||
from stable_baselines3.common.utils import set_random_seed
|
||||
|
||||
def make_env(env_id, rank, seed=0):
|
||||
|
|
@ -267,7 +267,7 @@ and multiprocessing for you.
|
|||
|
||||
.. code-block:: python
|
||||
|
||||
from stable_baselines3.common.cmd_util import make_atari_env
|
||||
from stable_baselines3.common.env_util import make_atari_env
|
||||
from stable_baselines3.common.vec_env import VecFrameStack
|
||||
from stable_baselines3 import A2C
|
||||
|
||||
|
|
|
|||
|
|
@ -64,7 +64,7 @@ Utility functions are no longer exported from ``common`` module, you should impo
|
|||
|
||||
.. code-block:: python
|
||||
|
||||
from stable_baselines3.common.cmd_util import make_atari_env, make_vec_env
|
||||
from stable_baselines3.common.env_util import make_atari_env, make_vec_env
|
||||
from stable_baselines3.common.utils import set_random_seed
|
||||
|
||||
instead of ``from stable_baselines3.common import make_atari_env``
|
||||
|
|
|
|||
|
|
@ -66,7 +66,7 @@ Main Features
|
|||
:caption: Common
|
||||
|
||||
common/atari_wrappers
|
||||
common/cmd_util
|
||||
common/env_util
|
||||
common/distributions
|
||||
common/evaluation
|
||||
common/env_checker
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ Pre-Release 0.10.0a0 (WIP)
|
|||
|
||||
Breaking Changes:
|
||||
^^^^^^^^^^^^^^^^^
|
||||
- **Warning:** Renamed ``common.cmd_util`` to ``common.env_util`` for clarity (affects ``make_vec_env`` and ``make_atari_env`` functions)
|
||||
|
||||
New Features:
|
||||
^^^^^^^^^^^^^
|
||||
|
|
|
|||
|
|
@ -54,7 +54,7 @@ Train a A2C agent on ``CartPole-v1`` using 4 environments.
|
|||
|
||||
from stable_baselines3 import A2C
|
||||
from stable_baselines3.a2c import MlpPolicy
|
||||
from stable_baselines3.common.cmd_util import make_vec_env
|
||||
from stable_baselines3.common.env_util import make_vec_env
|
||||
|
||||
# Parallel environments
|
||||
env = make_vec_env('CartPole-v1', n_envs=4)
|
||||
|
|
|
|||
|
|
@ -55,7 +55,7 @@ Train a PPO agent on ``Pendulum-v0`` using 4 environments.
|
|||
|
||||
from stable_baselines3 import PPO
|
||||
from stable_baselines3.ppo import MlpPolicy
|
||||
from stable_baselines3.common.cmd_util import make_vec_env
|
||||
from stable_baselines3.common.env_util import make_vec_env
|
||||
|
||||
# Parallel environments
|
||||
env = make_vec_env('CartPole-v1', n_envs=4)
|
||||
|
|
|
|||
|
|
@ -1,120 +1,7 @@
|
|||
import os
|
||||
from typing import Any, Callable, Dict, Optional, Type, Union
|
||||
|
||||
import gym
|
||||
|
||||
from stable_baselines3.common.atari_wrappers import AtariWrapper
|
||||
from stable_baselines3.common.monitor import Monitor
|
||||
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecEnv
|
||||
|
||||
|
||||
def make_vec_env(
|
||||
env_id: Union[str, Type[gym.Env]],
|
||||
n_envs: int = 1,
|
||||
seed: Optional[int] = None,
|
||||
start_index: int = 0,
|
||||
monitor_dir: Optional[str] = None,
|
||||
wrapper_class: Optional[Callable] = None,
|
||||
env_kwargs: Optional[Dict[str, Any]] = None,
|
||||
vec_env_cls: Optional[Type[Union[DummyVecEnv, SubprocVecEnv]]] = None,
|
||||
vec_env_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> VecEnv:
|
||||
"""
|
||||
Create a wrapped, monitored ``VecEnv``.
|
||||
By default it uses a ``DummyVecEnv`` which is usually faster
|
||||
than a ``SubprocVecEnv``.
|
||||
|
||||
:param env_id: the environment ID or the environment class
|
||||
:param n_envs: the number of environments you wish to have in parallel
|
||||
:param seed: the initial seed for the random number generator
|
||||
:param start_index: start rank index
|
||||
:param monitor_dir: Path to a folder where the monitor files will be saved.
|
||||
If None, no file will be written, however, the env will still be wrapped
|
||||
in a Monitor wrapper to provide additional information about training.
|
||||
:param wrapper_class: Additional wrapper to use on the environment.
|
||||
This can also be a function with single argument that wraps the environment in many things.
|
||||
:param env_kwargs: Optional keyword argument to pass to the env constructor
|
||||
:param vec_env_cls: A custom ``VecEnv`` class constructor. Default: None.
|
||||
:param vec_env_kwargs: Keyword arguments to pass to the ``VecEnv`` class constructor.
|
||||
:return: The wrapped environment
|
||||
"""
|
||||
env_kwargs = {} if env_kwargs is None else env_kwargs
|
||||
vec_env_kwargs = {} if vec_env_kwargs is None else vec_env_kwargs
|
||||
|
||||
def make_env(rank):
|
||||
def _init():
|
||||
if isinstance(env_id, str):
|
||||
env = gym.make(env_id, **env_kwargs)
|
||||
else:
|
||||
env = env_id(**env_kwargs)
|
||||
if seed is not None:
|
||||
env.seed(seed + rank)
|
||||
env.action_space.seed(seed + rank)
|
||||
# Wrap the env in a Monitor wrapper
|
||||
# to have additional training information
|
||||
monitor_path = os.path.join(monitor_dir, str(rank)) if monitor_dir is not None else None
|
||||
# Create the monitor folder if needed
|
||||
if monitor_path is not None:
|
||||
os.makedirs(monitor_dir, exist_ok=True)
|
||||
env = Monitor(env, filename=monitor_path)
|
||||
# Optionally, wrap the environment with the provided wrapper
|
||||
if wrapper_class is not None:
|
||||
env = wrapper_class(env)
|
||||
return env
|
||||
|
||||
return _init
|
||||
|
||||
# No custom VecEnv is passed
|
||||
if vec_env_cls is None:
|
||||
# Default: use a DummyVecEnv
|
||||
vec_env_cls = DummyVecEnv
|
||||
|
||||
return vec_env_cls([make_env(i + start_index) for i in range(n_envs)], **vec_env_kwargs)
|
||||
|
||||
|
||||
def make_atari_env(
|
||||
env_id: Union[str, Type[gym.Env]],
|
||||
n_envs: int = 1,
|
||||
seed: Optional[int] = None,
|
||||
start_index: int = 0,
|
||||
monitor_dir: Optional[str] = None,
|
||||
wrapper_kwargs: Optional[Dict[str, Any]] = None,
|
||||
env_kwargs: Optional[Dict[str, Any]] = None,
|
||||
vec_env_cls: Optional[Union[DummyVecEnv, SubprocVecEnv]] = None,
|
||||
vec_env_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> VecEnv:
|
||||
"""
|
||||
Create a wrapped, monitored VecEnv for Atari.
|
||||
It is a wrapper around ``make_vec_env`` that includes common preprocessing for Atari games.
|
||||
|
||||
:param env_id: the environment ID or the environment class
|
||||
:param n_envs: the number of environments you wish to have in parallel
|
||||
:param seed: the initial seed for the random number generator
|
||||
:param start_index: start rank index
|
||||
:param monitor_dir: Path to a folder where the monitor files will be saved.
|
||||
If None, no file will be written, however, the env will still be wrapped
|
||||
in a Monitor wrapper to provide additional information about training.
|
||||
:param wrapper_kwargs: Optional keyword argument to pass to the ``AtariWrapper``
|
||||
:param env_kwargs: Optional keyword argument to pass to the env constructor
|
||||
:param vec_env_cls: A custom ``VecEnv`` class constructor. Default: None.
|
||||
:param vec_env_kwargs: Keyword arguments to pass to the ``VecEnv`` class constructor.
|
||||
:return: The wrapped environment
|
||||
"""
|
||||
if wrapper_kwargs is None:
|
||||
wrapper_kwargs = {}
|
||||
|
||||
def atari_wrapper(env: gym.Env) -> gym.Env:
|
||||
env = AtariWrapper(env, **wrapper_kwargs)
|
||||
return env
|
||||
|
||||
return make_vec_env(
|
||||
env_id,
|
||||
n_envs=n_envs,
|
||||
seed=seed,
|
||||
start_index=start_index,
|
||||
monitor_dir=monitor_dir,
|
||||
wrapper_class=atari_wrapper,
|
||||
env_kwargs=env_kwargs,
|
||||
vec_env_cls=vec_env_cls,
|
||||
vec_env_kwargs=vec_env_kwargs,
|
||||
)
|
||||
import warnings
|
||||
|
||||
from stable_baselines3.common.env_util import * # noqa: F403
|
||||
|
||||
warnings.warn(
|
||||
"Module ``common.cmd_util`` has been renamed to ``common.env_util`` and will be removed in the future.", FutureWarning
|
||||
)
|
||||
|
|
|
|||
120
stable_baselines3/common/env_util.py
Normal file
120
stable_baselines3/common/env_util.py
Normal file
|
|
@ -0,0 +1,120 @@
|
|||
import os
|
||||
from typing import Any, Callable, Dict, Optional, Type, Union
|
||||
|
||||
import gym
|
||||
|
||||
from stable_baselines3.common.atari_wrappers import AtariWrapper
|
||||
from stable_baselines3.common.monitor import Monitor
|
||||
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecEnv
|
||||
|
||||
|
||||
def make_vec_env(
|
||||
env_id: Union[str, Type[gym.Env]],
|
||||
n_envs: int = 1,
|
||||
seed: Optional[int] = None,
|
||||
start_index: int = 0,
|
||||
monitor_dir: Optional[str] = None,
|
||||
wrapper_class: Optional[Callable] = None,
|
||||
env_kwargs: Optional[Dict[str, Any]] = None,
|
||||
vec_env_cls: Optional[Type[Union[DummyVecEnv, SubprocVecEnv]]] = None,
|
||||
vec_env_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> VecEnv:
|
||||
"""
|
||||
Create a wrapped, monitored ``VecEnv``.
|
||||
By default it uses a ``DummyVecEnv`` which is usually faster
|
||||
than a ``SubprocVecEnv``.
|
||||
|
||||
:param env_id: the environment ID or the environment class
|
||||
:param n_envs: the number of environments you wish to have in parallel
|
||||
:param seed: the initial seed for the random number generator
|
||||
:param start_index: start rank index
|
||||
:param monitor_dir: Path to a folder where the monitor files will be saved.
|
||||
If None, no file will be written, however, the env will still be wrapped
|
||||
in a Monitor wrapper to provide additional information about training.
|
||||
:param wrapper_class: Additional wrapper to use on the environment.
|
||||
This can also be a function with single argument that wraps the environment in many things.
|
||||
:param env_kwargs: Optional keyword argument to pass to the env constructor
|
||||
:param vec_env_cls: A custom ``VecEnv`` class constructor. Default: None.
|
||||
:param vec_env_kwargs: Keyword arguments to pass to the ``VecEnv`` class constructor.
|
||||
:return: The wrapped environment
|
||||
"""
|
||||
env_kwargs = {} if env_kwargs is None else env_kwargs
|
||||
vec_env_kwargs = {} if vec_env_kwargs is None else vec_env_kwargs
|
||||
|
||||
def make_env(rank):
|
||||
def _init():
|
||||
if isinstance(env_id, str):
|
||||
env = gym.make(env_id, **env_kwargs)
|
||||
else:
|
||||
env = env_id(**env_kwargs)
|
||||
if seed is not None:
|
||||
env.seed(seed + rank)
|
||||
env.action_space.seed(seed + rank)
|
||||
# Wrap the env in a Monitor wrapper
|
||||
# to have additional training information
|
||||
monitor_path = os.path.join(monitor_dir, str(rank)) if monitor_dir is not None else None
|
||||
# Create the monitor folder if needed
|
||||
if monitor_path is not None:
|
||||
os.makedirs(monitor_dir, exist_ok=True)
|
||||
env = Monitor(env, filename=monitor_path)
|
||||
# Optionally, wrap the environment with the provided wrapper
|
||||
if wrapper_class is not None:
|
||||
env = wrapper_class(env)
|
||||
return env
|
||||
|
||||
return _init
|
||||
|
||||
# No custom VecEnv is passed
|
||||
if vec_env_cls is None:
|
||||
# Default: use a DummyVecEnv
|
||||
vec_env_cls = DummyVecEnv
|
||||
|
||||
return vec_env_cls([make_env(i + start_index) for i in range(n_envs)], **vec_env_kwargs)
|
||||
|
||||
|
||||
def make_atari_env(
|
||||
env_id: Union[str, Type[gym.Env]],
|
||||
n_envs: int = 1,
|
||||
seed: Optional[int] = None,
|
||||
start_index: int = 0,
|
||||
monitor_dir: Optional[str] = None,
|
||||
wrapper_kwargs: Optional[Dict[str, Any]] = None,
|
||||
env_kwargs: Optional[Dict[str, Any]] = None,
|
||||
vec_env_cls: Optional[Union[DummyVecEnv, SubprocVecEnv]] = None,
|
||||
vec_env_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> VecEnv:
|
||||
"""
|
||||
Create a wrapped, monitored VecEnv for Atari.
|
||||
It is a wrapper around ``make_vec_env`` that includes common preprocessing for Atari games.
|
||||
|
||||
:param env_id: the environment ID or the environment class
|
||||
:param n_envs: the number of environments you wish to have in parallel
|
||||
:param seed: the initial seed for the random number generator
|
||||
:param start_index: start rank index
|
||||
:param monitor_dir: Path to a folder where the monitor files will be saved.
|
||||
If None, no file will be written, however, the env will still be wrapped
|
||||
in a Monitor wrapper to provide additional information about training.
|
||||
:param wrapper_kwargs: Optional keyword argument to pass to the ``AtariWrapper``
|
||||
:param env_kwargs: Optional keyword argument to pass to the env constructor
|
||||
:param vec_env_cls: A custom ``VecEnv`` class constructor. Default: None.
|
||||
:param vec_env_kwargs: Keyword arguments to pass to the ``VecEnv`` class constructor.
|
||||
:return: The wrapped environment
|
||||
"""
|
||||
if wrapper_kwargs is None:
|
||||
wrapper_kwargs = {}
|
||||
|
||||
def atari_wrapper(env: gym.Env) -> gym.Env:
|
||||
env = AtariWrapper(env, **wrapper_kwargs)
|
||||
return env
|
||||
|
||||
return make_vec_env(
|
||||
env_id,
|
||||
n_envs=n_envs,
|
||||
seed=seed,
|
||||
start_index=start_index,
|
||||
monitor_dir=monitor_dir,
|
||||
wrapper_class=atari_wrapper,
|
||||
env_kwargs=env_kwargs,
|
||||
vec_env_cls=vec_env_cls,
|
||||
vec_env_kwargs=vec_env_kwargs,
|
||||
)
|
||||
|
|
@ -13,7 +13,7 @@ from stable_baselines3.common.callbacks import (
|
|||
StopTrainingOnMaxEpisodes,
|
||||
StopTrainingOnRewardThreshold,
|
||||
)
|
||||
from stable_baselines3.common.cmd_util import make_vec_env
|
||||
from stable_baselines3.common.env_util import make_vec_env
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", [A2C, PPO, SAC, TD3, DQN, DDPG])
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ import torch as th
|
|||
|
||||
from stable_baselines3 import A2C
|
||||
from stable_baselines3.common.atari_wrappers import ClipRewardEnv
|
||||
from stable_baselines3.common.cmd_util import make_atari_env, make_vec_env
|
||||
from stable_baselines3.common.env_util import make_atari_env, make_vec_env
|
||||
from stable_baselines3.common.evaluation import evaluate_policy
|
||||
from stable_baselines3.common.monitor import Monitor
|
||||
from stable_baselines3.common.noise import ActionNoise, OrnsteinUhlenbeckActionNoise, VectorizedActionNoise
|
||||
|
|
@ -190,3 +190,9 @@ def test_zip_strict():
|
|||
# same length, should not raise an error
|
||||
for _, _ in zip_strict(list_a, list_b[: len(list_a)]):
|
||||
pass
|
||||
|
||||
|
||||
def test_cmd_util_rename():
|
||||
"""Test that importing cmd_util still works but raises warning"""
|
||||
with pytest.warns(FutureWarning):
|
||||
from stable_baselines3.common.cmd_util import make_vec_env
|
||||
|
|
|
|||
Loading…
Reference in a new issue