From 21e9994ff99db306e14bfa19ca36f133c7153df4 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Wed, 5 Aug 2020 12:12:02 +0200 Subject: [PATCH] Fix double reset and improve typing coverage (#136) * Fix double reset and improve typing coverage * Revert minor edit * Add doc about types --- docs/misc/changelog.rst | 25 ++++ stable_baselines3/common/env_checker.py | 2 +- stable_baselines3/common/evaluation.py | 34 +++-- stable_baselines3/common/vec_env/__init__.py | 33 +++-- .../common/vec_env/base_vec_env.py | 127 +++++++++--------- .../common/vec_env/dummy_vec_env.py | 12 +- .../common/vec_env/vec_frame_stack.py | 12 +- .../common/vec_env/vec_transpose.py | 9 +- stable_baselines3/version.txt | 2 +- 9 files changed, 145 insertions(+), 111 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 4d66fa3..8e78485 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,6 +3,31 @@ Changelog ========== +Pre-Release 0.9.0a0 (WIP) +------------------------------ + +Breaking Changes: +^^^^^^^^^^^^^^^^^ + +New Features: +^^^^^^^^^^^^^ +- Added ``unwrap_vec_wrapper()`` to ``common.vec_env`` to extract ``VecEnvWrapper`` if needed + +Bug Fixes: +^^^^^^^^^^ +- Fixed a bug where the environment was reset twice when using ``evaluate_policy`` + +Deprecations: +^^^^^^^^^^^^^ + +Others: +^^^^^^^ +- Improve typing coverage of the ``VecEnv`` +- Removed ``AlreadySteppingError`` and ``NotSteppingError`` that were not used + +Documentation: +^^^^^^^^^^^^^^ + Pre-Release 0.8.0 (2020-08-03) ------------------------------ diff --git a/stable_baselines3/common/env_checker.py b/stable_baselines3/common/env_checker.py index 558a9fe..3dc3580 100644 --- a/stable_baselines3/common/env_checker.py +++ b/stable_baselines3/common/env_checker.py @@ -70,7 +70,7 @@ def _check_nan(env: gym.Env) -> None: """Check for Inf and NaN using the VecWrapper.""" vec_env = VecCheckNan(DummyVecEnv([lambda: env])) for _ in range(10): - action = [env.action_space.sample()] + action = np.array([env.action_space.sample()]) _, _, _, _ = vec_env.step(action) diff --git a/stable_baselines3/common/evaluation.py b/stable_baselines3/common/evaluation.py index 6dac4d5..0822c1c 100644 --- a/stable_baselines3/common/evaluation.py +++ b/stable_baselines3/common/evaluation.py @@ -1,19 +1,25 @@ -# Copied from stable_baselines +import typing +from typing import Callable, List, Optional, Tuple, Union + +import gym import numpy as np from stable_baselines3.common.vec_env import VecEnv +if typing.TYPE_CHECKING: + from stable_baselines3.common.base_class import BaseAlgorithm + def evaluate_policy( - model, - env, - n_eval_episodes=10, - deterministic=True, - render=False, - callback=None, - reward_threshold=None, - return_episode_rewards=False, -): + model: "BaseAlgorithm", + env: Union[gym.Env, VecEnv], + n_eval_episodes: int = 10, + deterministic: bool = True, + render: bool = False, + callback: Optional[Callable] = None, + reward_threshold: Optional[float] = None, + return_episode_rewards: bool = False, +) -> Union[Tuple[float, float], Tuple[List[float], List[int]]]: """ Runs policy for ``n_eval_episodes`` episodes and returns average reward. This is made to work only with one env. @@ -28,7 +34,7 @@ def evaluate_policy( called after each step. :param reward_threshold: (float) Minimum expected reward per episode, this will raise an error if the performance is not met - :param return_episode_rewards: (bool) If True, a list of reward per episode + :param return_episode_rewards: (Optional[float]) If True, a list of reward per episode will be returned instead of the mean. :return: (float, float) Mean reward per episode, std of reward per episode returns ([float], [int]) when ``return_episode_rewards`` is True @@ -37,8 +43,10 @@ def evaluate_policy( assert env.num_envs == 1, "You must pass only one environment when using this function" episode_rewards, episode_lengths = [], [] - for _ in range(n_eval_episodes): - obs = env.reset() + for i in range(n_eval_episodes): + # Avoid double reset, as VecEnv are reset automatically + if not isinstance(env, VecEnv) or i == 0: + obs = env.reset() done, state = False, None episode_reward = 0.0 episode_length = 0 diff --git a/stable_baselines3/common/vec_env/__init__.py b/stable_baselines3/common/vec_env/__init__.py index 9944130..d1dfeb1 100644 --- a/stable_baselines3/common/vec_env/__init__.py +++ b/stable_baselines3/common/vec_env/__init__.py @@ -1,15 +1,9 @@ # flake8: noqa F401 import typing from copy import deepcopy -from typing import Optional, Union +from typing import Optional, Type, Union -from stable_baselines3.common.vec_env.base_vec_env import ( - AlreadySteppingError, - CloudpickleWrapper, - NotSteppingError, - VecEnv, - VecEnvWrapper, -) +from stable_baselines3.common.vec_env.base_vec_env import CloudpickleWrapper, VecEnv, VecEnvWrapper from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv from stable_baselines3.common.vec_env.vec_check_nan import VecCheckNan @@ -23,17 +17,28 @@ if typing.TYPE_CHECKING: from stable_baselines3.common.type_aliases import GymEnv +def unwrap_vec_wrapper(env: Union["GymEnv", VecEnv], vec_wrapper_class: Type[VecEnvWrapper]) -> Optional[VecEnvWrapper]: + """ + Retrieve a ``VecEnvWrapper`` object by recursively searching. + + :param env: (gym.Env) + :param vec_wrapper_class: (VecEnvWrapper) + :return: (VecEnvWrapper) + """ + env_tmp = env + while isinstance(env_tmp, VecEnvWrapper): + if isinstance(env_tmp, vec_wrapper_class): + return env_tmp + env_tmp = env_tmp.venv + return None + + def unwrap_vec_normalize(env: Union["GymEnv", VecEnv]) -> Optional[VecNormalize]: """ :param env: (gym.Env) :return: (VecNormalize) """ - env_tmp = env - while isinstance(env_tmp, VecEnvWrapper): - if isinstance(env_tmp, VecNormalize): - return env_tmp - env_tmp = env_tmp.venv - return None + return unwrap_vec_wrapper(env, VecNormalize) # pytype:disable=bad-return-type # Define here to avoid circular import diff --git a/stable_baselines3/common/vec_env/base_vec_env.py b/stable_baselines3/common/vec_env/base_vec_env.py index 6338a30..36b4f7b 100644 --- a/stable_baselines3/common/vec_env/base_vec_env.py +++ b/stable_baselines3/common/vec_env/base_vec_env.py @@ -1,12 +1,23 @@ import inspect from abc import ABC, abstractmethod -from typing import List, Optional, Sequence, Union +from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union import cloudpickle +import gym import numpy as np from stable_baselines3.common import logger +# Define type aliases here to avoid circular import +# Used when we want to access one or more VecEnv +VecEnvIndices = Union[None, int, Iterable[int]] +# VecEnvObs is what is returned by the reset() method +# it contains the observation for each env +VecEnvObs = Union[np.ndarray, Dict[str, Any]] +# VecEnvStepReturn is what is returned by the step() method +# it contains the observation, reward, done, info for each env +VecEnvStepReturn = Tuple[VecEnvObs, np.ndarray, np.ndarray, List[Dict]] + def tile_images(img_nhwc: Sequence[np.ndarray]) -> np.ndarray: # pragma: no cover """ @@ -34,46 +45,24 @@ def tile_images(img_nhwc: Sequence[np.ndarray]) -> np.ndarray: # pragma: no cov return out_image -class AlreadySteppingError(Exception): - """ - Raised when an asynchronous step is running while - step_async() is called again. - """ - - def __init__(self): - msg = "already running an async step" - Exception.__init__(self, msg) - - -class NotSteppingError(Exception): - """ - Raised when an asynchronous step is not running but - step_wait() is called. - """ - - def __init__(self): - msg = "not running an async step" - Exception.__init__(self, msg) - - class VecEnv(ABC): """ An abstract asynchronous, vectorized environment. :param num_envs: (int) the number of environments - :param observation_space: (Gym Space) the observation space - :param action_space: (Gym Space) the action space + :param observation_space: (gym.spaces.Space) the observation space + :param action_space: (gym.spaces.Space) the action space """ metadata = {"render.modes": ["human", "rgb_array"]} - def __init__(self, num_envs, observation_space, action_space): + def __init__(self, num_envs: int, observation_space: gym.spaces.Space, action_space: gym.spaces.Space): self.num_envs = num_envs self.observation_space = observation_space self.action_space = action_space @abstractmethod - def reset(self): + def reset(self) -> VecEnvObs: """ Reset all the environments and return an array of observations, or a tuple of observation arrays. @@ -82,12 +71,12 @@ class VecEnv(ABC): be cancelled and step_wait() should not be called until step_async() is invoked again. - :return: ([int] or [float]) observation + :return: (VecEnvObs) observation """ raise NotImplementedError() @abstractmethod - def step_async(self, actions): + def step_async(self, actions: np.ndarray): """ Tell all the environments to start taking a step with the given actions. @@ -99,23 +88,23 @@ class VecEnv(ABC): raise NotImplementedError() @abstractmethod - def step_wait(self): + def step_wait(self) -> VecEnvStepReturn: """ Wait for the step taken with step_async(). - :return: ([int] or [float], [float], [bool], dict) observation, reward, done, information + :return: observation, reward, done, information """ raise NotImplementedError() @abstractmethod - def close(self): + def close(self) -> None: """ Clean up the environment's resources. """ raise NotImplementedError() @abstractmethod - def get_attr(self, attr_name, indices=None): + def get_attr(self, attr_name: str, indices: "VecEnvIndices" = None) -> List[Any]: """ Return attribute from vectorized environment. @@ -126,7 +115,7 @@ class VecEnv(ABC): raise NotImplementedError() @abstractmethod - def set_attr(self, attr_name, value, indices=None): + def set_attr(self, attr_name: str, value: Any, indices: "VecEnvIndices" = None) -> None: """ Set attribute inside vectorized environments. @@ -138,7 +127,7 @@ class VecEnv(ABC): raise NotImplementedError() @abstractmethod - def env_method(self, method_name, *method_args, indices=None, **method_kwargs): + def env_method(self, method_name: str, *method_args, indices: "VecEnvIndices" = None, **method_kwargs) -> List[Any]: """ Call instance methods of vectorized environments. @@ -150,12 +139,12 @@ class VecEnv(ABC): """ raise NotImplementedError() - def step(self, actions): + def step(self, actions: np.ndarray) -> VecEnvStepReturn: """ Step the environments with the given action - :param actions: ([int] or [float]) the action - :return: ([int] or [float], [float], [bool], dict) observation, reward, done, information + :param actions: (np.ndarray) the action + :return: (VecEnvStepReturn) observation, reward, done, information """ self.step_async(actions) return self.step_wait() @@ -166,7 +155,7 @@ class VecEnv(ABC): """ raise NotImplementedError - def render(self, mode: str = "human"): + def render(self, mode: str = "human") -> Optional[np.ndarray]: """ Gym environment rendering @@ -203,25 +192,25 @@ class VecEnv(ABC): pass @property - def unwrapped(self): + def unwrapped(self) -> "VecEnv": if isinstance(self, VecEnvWrapper): return self.venv.unwrapped else: return self - def getattr_depth_check(self, name, already_found): + def getattr_depth_check(self, name: str, already_found: bool) -> Optional[str]: """Check if an attribute reference is being hidden in a recursive call to __getattr__ :param name: (str) name of attribute to check for :param already_found: (bool) whether this attribute has already been found in a wrapper - :return: (str or None) name of module whose attribute is being shadowed, if any. + :return: (Optional[str]) name of module whose attribute is being shadowed, if any. """ if hasattr(self, name) and already_found: return f"{type(self).__module__}.{type(self).__name__}" else: return None - def _get_indices(self, indices): + def _get_indices(self, indices: "VecEnvIndices") -> Iterable[int]: """ Convert a flexibly-typed reference to environment indices to an implied list of indices. @@ -240,11 +229,16 @@ class VecEnvWrapper(VecEnv): Vectorized environment base class :param venv: (VecEnv) the vectorized environment to wrap - :param observation_space: (Gym Space) the observation space (can be None to load from venv) - :param action_space: (Gym Space) the action space (can be None to load from venv) + :param observation_space: (Optional[gym.spaces.Space]) the observation space (can be None to load from venv) + :param action_space: (Optional[gym.spaces.Space]) the action space (can be None to load from venv) """ - def __init__(self, venv, observation_space=None, action_space=None): + def __init__( + self, + venv: VecEnv, + observation_space: Optional[gym.spaces.Space] = None, + action_space: Optional[gym.spaces.Space] = None, + ): self.venv = venv VecEnv.__init__( self, @@ -254,27 +248,27 @@ class VecEnvWrapper(VecEnv): ) self.class_attributes = dict(inspect.getmembers(self.__class__)) - def step_async(self, actions): + def step_async(self, actions: np.ndarray): self.venv.step_async(actions) @abstractmethod - def reset(self): + def reset(self) -> VecEnvObs: pass @abstractmethod - def step_wait(self): + def step_wait(self) -> VecEnvStepReturn: pass - def seed(self, seed=None): + def seed(self, seed: Optional[int] = None): return self.venv.seed(seed) - def close(self): + def close(self) -> None: return self.venv.close() - def render(self, mode: str = "human"): + def render(self, mode: str = "human") -> Optional[np.ndarray]: return self.venv.render(mode=mode) - def get_images(self): + def get_images(self) -> Sequence[np.ndarray]: return self.venv.get_images() def get_attr(self, attr_name, indices=None): @@ -286,7 +280,7 @@ class VecEnvWrapper(VecEnv): def env_method(self, method_name, *method_args, indices=None, **method_kwargs): return self.venv.env_method(method_name, *method_args, indices=indices, **method_kwargs) - def __getattr__(self, name): + def __getattr__(self, name: str) -> Any: """Find attribute from wrapped venv(s) if this wrapper does not have it. Useful for accessing attributes from venvs which are wrapped with multiple wrappers which have unique attributes of interest. @@ -302,16 +296,16 @@ class VecEnvWrapper(VecEnv): return self.getattr_recursive(name) - def _get_all_attributes(self): + def _get_all_attributes(self) -> Dict[str, Any]: """Get all (inherited) instance and class attributes - :return: (dict) all_attributes + :return: (Dict[str, Any]) all_attributes """ all_attributes = self.__dict__.copy() all_attributes.update(self.class_attributes) return all_attributes - def getattr_recursive(self, name): + def getattr_recursive(self, name: str): """Recursively check wrappers to find attribute. :param name (str) name of attribute to look for @@ -329,7 +323,7 @@ class VecEnvWrapper(VecEnv): return attr - def getattr_depth_check(self, name, already_found): + def getattr_depth_check(self, name: str, already_found: bool): """See base class. :return: (str or None) name of module whose attribute is being shadowed, if any. @@ -349,16 +343,17 @@ class VecEnvWrapper(VecEnv): class CloudpickleWrapper: - def __init__(self, var): - """ - Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle) + """ + Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle) - :param var: (Any) the variable you wish to wrap for pickling with cloudpickle - """ + :param var: (Any) the variable you wish to wrap for pickling with cloudpickle + """ + + def __init__(self, var: Any): self.var = var - def __getstate__(self): + def __getstate__(self) -> Any: return cloudpickle.dumps(self.var) - def __setstate__(self, obs): - self.var = cloudpickle.loads(obs) + def __setstate__(self, var: Any) -> None: + self.var = cloudpickle.loads(var) diff --git a/stable_baselines3/common/vec_env/dummy_vec_env.py b/stable_baselines3/common/vec_env/dummy_vec_env.py index c577bd8..95bdc1a 100644 --- a/stable_baselines3/common/vec_env/dummy_vec_env.py +++ b/stable_baselines3/common/vec_env/dummy_vec_env.py @@ -1,7 +1,8 @@ from collections import OrderedDict from copy import deepcopy -from typing import Sequence +from typing import Callable, List, Optional, Sequence +import gym import numpy as np from stable_baselines3.common.vec_env.base_vec_env import VecEnv @@ -16,10 +17,11 @@ class DummyVecEnv(VecEnv): This can also be used for RL methods that require a vectorized environment, but that you want a single environments to train with. - :param env_fns: ([Gym Environment]) the list of environments to vectorize + :param env_fns: (List[Callable[[], gym.Env]]) a list of functions + that return environments to vectorize """ - def __init__(self, env_fns): + def __init__(self, env_fns: List[Callable[[], gym.Env]]): self.envs = [fn() for fn in env_fns] env = self.envs[0] VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space) @@ -33,7 +35,7 @@ class DummyVecEnv(VecEnv): self.actions = None self.metadata = env.metadata - def step_async(self, actions): + def step_async(self, actions: np.ndarray): self.actions = actions def step_wait(self): @@ -48,7 +50,7 @@ class DummyVecEnv(VecEnv): self._save_obs(env_idx, obs) return (self._obs_from_buf(), np.copy(self.buf_rews), np.copy(self.buf_dones), deepcopy(self.buf_infos)) - def seed(self, seed=None): + def seed(self, seed: Optional[int] = None) -> List[int]: seeds = list() for idx, env in enumerate(self.envs): seeds.append(env.seed(seed + idx)) diff --git a/stable_baselines3/common/vec_env/vec_frame_stack.py b/stable_baselines3/common/vec_env/vec_frame_stack.py index 5e6b7f2..94199ff 100644 --- a/stable_baselines3/common/vec_env/vec_frame_stack.py +++ b/stable_baselines3/common/vec_env/vec_frame_stack.py @@ -1,4 +1,5 @@ import warnings +from typing import Any, Dict, List, Tuple import numpy as np from gym import spaces @@ -18,14 +19,17 @@ class VecFrameStack(VecEnvWrapper): self.venv = venv self.n_stack = n_stack wrapped_obs_space = venv.observation_space + assert isinstance(wrapped_obs_space, spaces.Box), "VecFrameStack only work with gym.spaces.Box observation space" low = np.repeat(wrapped_obs_space.low, self.n_stack, axis=-1) high = np.repeat(wrapped_obs_space.high, self.n_stack, axis=-1) self.stackedobs = np.zeros((venv.num_envs,) + low.shape, low.dtype) observation_space = spaces.Box(low=low, high=high, dtype=venv.observation_space.dtype) VecEnvWrapper.__init__(self, venv, observation_space=observation_space) - def step_wait(self): + def step_wait(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray, List[Dict[str, Any]]]: observations, rewards, dones, infos = self.venv.step_wait() + # Let pytype know that observation is not a dict + assert isinstance(observations, np.ndarray) last_ax_size = observations.shape[-1] self.stackedobs = np.roll(self.stackedobs, shift=-last_ax_size, axis=-1) for i, done in enumerate(dones): @@ -40,14 +44,14 @@ class VecFrameStack(VecEnvWrapper): self.stackedobs[..., -observations.shape[-1] :] = observations return self.stackedobs, rewards, dones, infos - def reset(self): + def reset(self) -> np.ndarray: """ Reset all environments """ - obs = self.venv.reset() + obs: np.ndarray = self.venv.reset() # pytype:disable=annotation-type-mismatch self.stackedobs[...] = 0 self.stackedobs[..., -obs.shape[-1] :] = obs return self.stackedobs - def close(self): + def close(self) -> None: self.venv.close() diff --git a/stable_baselines3/common/vec_env/vec_transpose.py b/stable_baselines3/common/vec_env/vec_transpose.py index 5dc0364..64b92eb 100644 --- a/stable_baselines3/common/vec_env/vec_transpose.py +++ b/stable_baselines3/common/vec_env/vec_transpose.py @@ -1,13 +1,8 @@ -import typing - import numpy as np from gym import spaces from stable_baselines3.common.preprocessing import is_image_space -from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvWrapper - -if typing.TYPE_CHECKING: - from stable_baselines3.common.type_aliases import GymStepReturn # noqa: F401 +from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvStepReturn, VecEnvWrapper class VecTransposeImage(VecEnvWrapper): @@ -49,7 +44,7 @@ class VecTransposeImage(VecEnvWrapper): return np.transpose(image, (2, 0, 1)) return np.transpose(image, (0, 3, 1, 2)) - def step_wait(self) -> "GymStepReturn": + def step_wait(self) -> VecEnvStepReturn: observations, rewards, dones, infos = self.venv.step_wait() return self.transpose_image(observations), rewards, dones, infos diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index a3df0a6..657e7c0 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -0.8.0 +0.9.0a0