Fix double reset and improve typing coverage (#136)

* Fix double reset and improve typing coverage

* Revert minor edit

* Add doc about types
This commit is contained in:
Antonin RAFFIN 2020-08-05 12:12:02 +02:00 committed by GitHub
parent cceffd5ab2
commit 21e9994ff9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 145 additions and 111 deletions

View file

@ -3,6 +3,31 @@
Changelog
==========
Pre-Release 0.9.0a0 (WIP)
------------------------------
Breaking Changes:
^^^^^^^^^^^^^^^^^
New Features:
^^^^^^^^^^^^^
- Added ``unwrap_vec_wrapper()`` to ``common.vec_env`` to extract ``VecEnvWrapper`` if needed
Bug Fixes:
^^^^^^^^^^
- Fixed a bug where the environment was reset twice when using ``evaluate_policy``
Deprecations:
^^^^^^^^^^^^^
Others:
^^^^^^^
- Improved typing coverage of the ``VecEnv``
- Removed ``AlreadySteppingError`` and ``NotSteppingError`` that were not used
Documentation:
^^^^^^^^^^^^^^
Pre-Release 0.8.0 (2020-08-03)
------------------------------

View file

@ -70,7 +70,7 @@ def _check_nan(env: gym.Env) -> None:
"""Check for Inf and NaN using the VecWrapper."""
vec_env = VecCheckNan(DummyVecEnv([lambda: env]))
for _ in range(10):
action = [env.action_space.sample()]
action = np.array([env.action_space.sample()])
_, _, _, _ = vec_env.step(action)

View file

@ -1,19 +1,25 @@
# Copied from stable_baselines
import typing
from typing import Callable, List, Optional, Tuple, Union
import gym
import numpy as np
from stable_baselines3.common.vec_env import VecEnv
if typing.TYPE_CHECKING:
from stable_baselines3.common.base_class import BaseAlgorithm
def evaluate_policy(
model,
env,
n_eval_episodes=10,
deterministic=True,
render=False,
callback=None,
reward_threshold=None,
return_episode_rewards=False,
):
model: "BaseAlgorithm",
env: Union[gym.Env, VecEnv],
n_eval_episodes: int = 10,
deterministic: bool = True,
render: bool = False,
callback: Optional[Callable] = None,
reward_threshold: Optional[float] = None,
return_episode_rewards: bool = False,
) -> Union[Tuple[float, float], Tuple[List[float], List[int]]]:
"""
Runs policy for ``n_eval_episodes`` episodes and returns average reward.
This is made to work only with one env.
@ -28,7 +34,7 @@ def evaluate_policy(
called after each step.
:param reward_threshold: (Optional[float]) Minimum expected reward per episode,
this will raise an error if the performance is not met
:param return_episode_rewards: (bool) If True, a list of reward per episode
:param return_episode_rewards: (bool) If True, a list of reward per episode
will be returned instead of the mean.
:return: (float, float) Mean reward per episode, std of reward per episode
returns ([float], [int]) when ``return_episode_rewards`` is True
@ -37,8 +43,10 @@ def evaluate_policy(
assert env.num_envs == 1, "You must pass only one environment when using this function"
episode_rewards, episode_lengths = [], []
for _ in range(n_eval_episodes):
obs = env.reset()
for i in range(n_eval_episodes):
# Avoid double reset, as VecEnv are reset automatically
if not isinstance(env, VecEnv) or i == 0:
obs = env.reset()
done, state = False, None
episode_reward = 0.0
episode_length = 0

View file

@ -1,15 +1,9 @@
# flake8: noqa F401
import typing
from copy import deepcopy
from typing import Optional, Union
from typing import Optional, Type, Union
from stable_baselines3.common.vec_env.base_vec_env import (
AlreadySteppingError,
CloudpickleWrapper,
NotSteppingError,
VecEnv,
VecEnvWrapper,
)
from stable_baselines3.common.vec_env.base_vec_env import CloudpickleWrapper, VecEnv, VecEnvWrapper
from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv
from stable_baselines3.common.vec_env.vec_check_nan import VecCheckNan
@ -23,17 +17,28 @@ if typing.TYPE_CHECKING:
from stable_baselines3.common.type_aliases import GymEnv
def unwrap_vec_wrapper(env: Union["GymEnv", VecEnv], vec_wrapper_class: Type[VecEnvWrapper]) -> Optional[VecEnvWrapper]:
    """
    Walk down a stack of ``VecEnvWrapper`` objects and return the first one
    that is an instance of ``vec_wrapper_class``.

    :param env: (Union[GymEnv, VecEnv]) the (possibly wrapped) environment to inspect
    :param vec_wrapper_class: (Type[VecEnvWrapper]) the wrapper class to look for
    :return: (Optional[VecEnvWrapper]) the outermost matching wrapper,
        or ``None`` when no wrapper of that class is found
    """
    layer = env
    while True:
        # Anything that is not a VecEnvWrapper ends the search:
        # there is no further ``venv`` attribute to follow.
        if not isinstance(layer, VecEnvWrapper):
            return None
        if isinstance(layer, vec_wrapper_class):
            return layer
        # Step one level deeper into the wrapped environment.
        layer = layer.venv
def unwrap_vec_normalize(env: Union["GymEnv", VecEnv]) -> Optional[VecNormalize]:
"""
:param env: (gym.Env)
:return: (VecNormalize)
"""
env_tmp = env
while isinstance(env_tmp, VecEnvWrapper):
if isinstance(env_tmp, VecNormalize):
return env_tmp
env_tmp = env_tmp.venv
return None
return unwrap_vec_wrapper(env, VecNormalize) # pytype:disable=bad-return-type
# Define here to avoid circular import

View file

@ -1,12 +1,23 @@
import inspect
from abc import ABC, abstractmethod
from typing import List, Optional, Sequence, Union
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
import cloudpickle
import gym
import numpy as np
from stable_baselines3.common import logger
# Define type aliases here to avoid circular import
# Used when we want to access one or more VecEnv
VecEnvIndices = Union[None, int, Iterable[int]]
# VecEnvObs is what is returned by the reset() method
# it contains the observation for each env
VecEnvObs = Union[np.ndarray, Dict[str, Any]]
# VecEnvStepReturn is what is returned by the step() method
# it contains the observation, reward, done, info for each env
VecEnvStepReturn = Tuple[VecEnvObs, np.ndarray, np.ndarray, List[Dict]]
def tile_images(img_nhwc: Sequence[np.ndarray]) -> np.ndarray: # pragma: no cover
"""
@ -34,46 +45,24 @@ def tile_images(img_nhwc: Sequence[np.ndarray]) -> np.ndarray: # pragma: no cov
return out_image
class AlreadySteppingError(Exception):
"""
Raised when an asynchronous step is running while
step_async() is called again.
"""
def __init__(self):
msg = "already running an async step"
Exception.__init__(self, msg)
class NotSteppingError(Exception):
"""
Raised when an asynchronous step is not running but
step_wait() is called.
"""
def __init__(self):
msg = "not running an async step"
Exception.__init__(self, msg)
class VecEnv(ABC):
"""
An abstract asynchronous, vectorized environment.
:param num_envs: (int) the number of environments
:param observation_space: (Gym Space) the observation space
:param action_space: (Gym Space) the action space
:param observation_space: (gym.spaces.Space) the observation space
:param action_space: (gym.spaces.Space) the action space
"""
metadata = {"render.modes": ["human", "rgb_array"]}
def __init__(self, num_envs, observation_space, action_space):
def __init__(self, num_envs: int, observation_space: gym.spaces.Space, action_space: gym.spaces.Space):
self.num_envs = num_envs
self.observation_space = observation_space
self.action_space = action_space
@abstractmethod
def reset(self):
def reset(self) -> VecEnvObs:
"""
Reset all the environments and return an array of
observations, or a tuple of observation arrays.
@ -82,12 +71,12 @@ class VecEnv(ABC):
be cancelled and step_wait() should not be called
until step_async() is invoked again.
:return: ([int] or [float]) observation
:return: (VecEnvObs) observation
"""
raise NotImplementedError()
@abstractmethod
def step_async(self, actions):
def step_async(self, actions: np.ndarray):
"""
Tell all the environments to start taking a step
with the given actions.
@ -99,23 +88,23 @@ class VecEnv(ABC):
raise NotImplementedError()
@abstractmethod
def step_wait(self):
def step_wait(self) -> VecEnvStepReturn:
"""
Wait for the step taken with step_async().
:return: ([int] or [float], [float], [bool], dict) observation, reward, done, information
:return: observation, reward, done, information
"""
raise NotImplementedError()
@abstractmethod
def close(self):
def close(self) -> None:
"""
Clean up the environment's resources.
"""
raise NotImplementedError()
@abstractmethod
def get_attr(self, attr_name, indices=None):
def get_attr(self, attr_name: str, indices: "VecEnvIndices" = None) -> List[Any]:
"""
Return attribute from vectorized environment.
@ -126,7 +115,7 @@ class VecEnv(ABC):
raise NotImplementedError()
@abstractmethod
def set_attr(self, attr_name, value, indices=None):
def set_attr(self, attr_name: str, value: Any, indices: "VecEnvIndices" = None) -> None:
"""
Set attribute inside vectorized environments.
@ -138,7 +127,7 @@ class VecEnv(ABC):
raise NotImplementedError()
@abstractmethod
def env_method(self, method_name, *method_args, indices=None, **method_kwargs):
def env_method(self, method_name: str, *method_args, indices: "VecEnvIndices" = None, **method_kwargs) -> List[Any]:
"""
Call instance methods of vectorized environments.
@ -150,12 +139,12 @@ class VecEnv(ABC):
"""
raise NotImplementedError()
def step(self, actions):
def step(self, actions: np.ndarray) -> VecEnvStepReturn:
"""
Step the environments with the given action
:param actions: ([int] or [float]) the action
:return: ([int] or [float], [float], [bool], dict) observation, reward, done, information
:param actions: (np.ndarray) the action
:return: (VecEnvStepReturn) observation, reward, done, information
"""
self.step_async(actions)
return self.step_wait()
@ -166,7 +155,7 @@ class VecEnv(ABC):
"""
raise NotImplementedError
def render(self, mode: str = "human"):
def render(self, mode: str = "human") -> Optional[np.ndarray]:
"""
Gym environment rendering
@ -203,25 +192,25 @@ class VecEnv(ABC):
pass
@property
def unwrapped(self):
def unwrapped(self) -> "VecEnv":
if isinstance(self, VecEnvWrapper):
return self.venv.unwrapped
else:
return self
def getattr_depth_check(self, name, already_found):
def getattr_depth_check(self, name: str, already_found: bool) -> Optional[str]:
"""Check if an attribute reference is being hidden in a recursive call to __getattr__
:param name: (str) name of attribute to check for
:param already_found: (bool) whether this attribute has already been found in a wrapper
:return: (str or None) name of module whose attribute is being shadowed, if any.
:return: (Optional[str]) name of module whose attribute is being shadowed, if any.
"""
if hasattr(self, name) and already_found:
return f"{type(self).__module__}.{type(self).__name__}"
else:
return None
def _get_indices(self, indices):
def _get_indices(self, indices: "VecEnvIndices") -> Iterable[int]:
"""
Convert a flexibly-typed reference to environment indices to an implied list of indices.
@ -240,11 +229,16 @@ class VecEnvWrapper(VecEnv):
Vectorized environment base class
:param venv: (VecEnv) the vectorized environment to wrap
:param observation_space: (Gym Space) the observation space (can be None to load from venv)
:param action_space: (Gym Space) the action space (can be None to load from venv)
:param observation_space: (Optional[gym.spaces.Space]) the observation space (can be None to load from venv)
:param action_space: (Optional[gym.spaces.Space]) the action space (can be None to load from venv)
"""
def __init__(self, venv, observation_space=None, action_space=None):
def __init__(
self,
venv: VecEnv,
observation_space: Optional[gym.spaces.Space] = None,
action_space: Optional[gym.spaces.Space] = None,
):
self.venv = venv
VecEnv.__init__(
self,
@ -254,27 +248,27 @@ class VecEnvWrapper(VecEnv):
)
self.class_attributes = dict(inspect.getmembers(self.__class__))
def step_async(self, actions):
def step_async(self, actions: np.ndarray):
self.venv.step_async(actions)
@abstractmethod
def reset(self):
def reset(self) -> VecEnvObs:
pass
@abstractmethod
def step_wait(self):
def step_wait(self) -> VecEnvStepReturn:
pass
def seed(self, seed=None):
def seed(self, seed: Optional[int] = None):
return self.venv.seed(seed)
def close(self):
def close(self) -> None:
return self.venv.close()
def render(self, mode: str = "human"):
def render(self, mode: str = "human") -> Optional[np.ndarray]:
return self.venv.render(mode=mode)
def get_images(self):
def get_images(self) -> Sequence[np.ndarray]:
return self.venv.get_images()
def get_attr(self, attr_name, indices=None):
@ -286,7 +280,7 @@ class VecEnvWrapper(VecEnv):
def env_method(self, method_name, *method_args, indices=None, **method_kwargs):
return self.venv.env_method(method_name, *method_args, indices=indices, **method_kwargs)
def __getattr__(self, name):
def __getattr__(self, name: str) -> Any:
"""Find attribute from wrapped venv(s) if this wrapper does not have it.
Useful for accessing attributes from venvs which are wrapped with multiple wrappers
which have unique attributes of interest.
@ -302,16 +296,16 @@ class VecEnvWrapper(VecEnv):
return self.getattr_recursive(name)
def _get_all_attributes(self):
def _get_all_attributes(self) -> Dict[str, Any]:
"""Get all (inherited) instance and class attributes
:return: (dict<str, object>) all_attributes
:return: (Dict[str, Any]) all_attributes
"""
all_attributes = self.__dict__.copy()
all_attributes.update(self.class_attributes)
return all_attributes
def getattr_recursive(self, name):
def getattr_recursive(self, name: str):
"""Recursively check wrappers to find attribute.
:param name: (str) name of attribute to look for
@ -329,7 +323,7 @@ class VecEnvWrapper(VecEnv):
return attr
def getattr_depth_check(self, name, already_found):
def getattr_depth_check(self, name: str, already_found: bool):
"""See base class.
:return: (str or None) name of module whose attribute is being shadowed, if any.
@ -349,16 +343,17 @@ class VecEnvWrapper(VecEnv):
class CloudpickleWrapper:
def __init__(self, var):
"""
Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle)
"""
Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle)
:param var: (Any) the variable you wish to wrap for pickling with cloudpickle
"""
:param var: (Any) the variable you wish to wrap for pickling with cloudpickle
"""
def __init__(self, var: Any):
self.var = var
def __getstate__(self):
def __getstate__(self) -> Any:
return cloudpickle.dumps(self.var)
def __setstate__(self, obs):
self.var = cloudpickle.loads(obs)
def __setstate__(self, var: Any) -> None:
self.var = cloudpickle.loads(var)

View file

@ -1,7 +1,8 @@
from collections import OrderedDict
from copy import deepcopy
from typing import Sequence
from typing import Callable, List, Optional, Sequence
import gym
import numpy as np
from stable_baselines3.common.vec_env.base_vec_env import VecEnv
@ -16,10 +17,11 @@ class DummyVecEnv(VecEnv):
This can also be used for RL methods that
require a vectorized environment when you want to train with a single environment.
:param env_fns: ([Gym Environment]) the list of environments to vectorize
:param env_fns: (List[Callable[[], gym.Env]]) a list of functions
that return environments to vectorize
"""
def __init__(self, env_fns):
def __init__(self, env_fns: List[Callable[[], gym.Env]]):
self.envs = [fn() for fn in env_fns]
env = self.envs[0]
VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space)
@ -33,7 +35,7 @@ class DummyVecEnv(VecEnv):
self.actions = None
self.metadata = env.metadata
def step_async(self, actions):
def step_async(self, actions: np.ndarray):
self.actions = actions
def step_wait(self):
@ -48,7 +50,7 @@ class DummyVecEnv(VecEnv):
self._save_obs(env_idx, obs)
return (self._obs_from_buf(), np.copy(self.buf_rews), np.copy(self.buf_dones), deepcopy(self.buf_infos))
def seed(self, seed=None):
def seed(self, seed: Optional[int] = None) -> List[int]:
seeds = list()
for idx, env in enumerate(self.envs):
seeds.append(env.seed(seed + idx))

View file

@ -1,4 +1,5 @@
import warnings
from typing import Any, Dict, List, Tuple
import numpy as np
from gym import spaces
@ -18,14 +19,17 @@ class VecFrameStack(VecEnvWrapper):
self.venv = venv
self.n_stack = n_stack
wrapped_obs_space = venv.observation_space
assert isinstance(wrapped_obs_space, spaces.Box), "VecFrameStack only work with gym.spaces.Box observation space"
low = np.repeat(wrapped_obs_space.low, self.n_stack, axis=-1)
high = np.repeat(wrapped_obs_space.high, self.n_stack, axis=-1)
self.stackedobs = np.zeros((venv.num_envs,) + low.shape, low.dtype)
observation_space = spaces.Box(low=low, high=high, dtype=venv.observation_space.dtype)
VecEnvWrapper.__init__(self, venv, observation_space=observation_space)
def step_wait(self):
def step_wait(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray, List[Dict[str, Any]]]:
observations, rewards, dones, infos = self.venv.step_wait()
# Let pytype know that observation is not a dict
assert isinstance(observations, np.ndarray)
last_ax_size = observations.shape[-1]
self.stackedobs = np.roll(self.stackedobs, shift=-last_ax_size, axis=-1)
for i, done in enumerate(dones):
@ -40,14 +44,14 @@ class VecFrameStack(VecEnvWrapper):
self.stackedobs[..., -observations.shape[-1] :] = observations
return self.stackedobs, rewards, dones, infos
def reset(self):
def reset(self) -> np.ndarray:
"""
Reset all environments
"""
obs = self.venv.reset()
obs: np.ndarray = self.venv.reset() # pytype:disable=annotation-type-mismatch
self.stackedobs[...] = 0
self.stackedobs[..., -obs.shape[-1] :] = obs
return self.stackedobs
def close(self):
def close(self) -> None:
self.venv.close()

View file

@ -1,13 +1,8 @@
import typing
import numpy as np
from gym import spaces
from stable_baselines3.common.preprocessing import is_image_space
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvWrapper
if typing.TYPE_CHECKING:
from stable_baselines3.common.type_aliases import GymStepReturn # noqa: F401
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvStepReturn, VecEnvWrapper
class VecTransposeImage(VecEnvWrapper):
@ -49,7 +44,7 @@ class VecTransposeImage(VecEnvWrapper):
return np.transpose(image, (2, 0, 1))
return np.transpose(image, (0, 3, 1, 2))
def step_wait(self) -> "GymStepReturn":
def step_wait(self) -> VecEnvStepReturn:
observations, rewards, dones, infos = self.venv.step_wait()
return self.transpose_image(observations), rewards, dones, infos

View file

@ -1 +1 @@
0.8.0
0.9.0a0