mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-27 22:55:17 +00:00
Fix double reset and improve typing coverage (#136)
* Fix double reset and improve typing coverage * Revert minor edit * Add doc about types
This commit is contained in:
parent
cceffd5ab2
commit
21e9994ff9
9 changed files with 145 additions and 111 deletions
|
|
@ -3,6 +3,31 @@
|
|||
Changelog
|
||||
==========
|
||||
|
||||
Pre-Release 0.9.0a0 (WIP)
|
||||
------------------------------
|
||||
|
||||
Breaking Changes:
|
||||
^^^^^^^^^^^^^^^^^
|
||||
|
||||
New Features:
|
||||
^^^^^^^^^^^^^
|
||||
- Added ``unwrap_vec_wrapper()`` to ``common.vec_env`` to extract ``VecEnvWrapper`` if needed
|
||||
|
||||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
- Fixed a bug where the environment was reset twice when using ``evaluate_policy``
|
||||
|
||||
Deprecations:
|
||||
^^^^^^^^^^^^^
|
||||
|
||||
Others:
|
||||
^^^^^^^
|
||||
- Improve typing coverage of the ``VecEnv``
|
||||
- Removed ``AlreadySteppingError`` and ``NotSteppingError`` that were not used
|
||||
|
||||
Documentation:
|
||||
^^^^^^^^^^^^^^
|
||||
|
||||
Pre-Release 0.8.0 (2020-08-03)
|
||||
------------------------------
|
||||
|
||||
|
|
|
|||
|
|
@ -70,7 +70,7 @@ def _check_nan(env: gym.Env) -> None:
|
|||
"""Check for Inf and NaN using the VecWrapper."""
|
||||
vec_env = VecCheckNan(DummyVecEnv([lambda: env]))
|
||||
for _ in range(10):
|
||||
action = [env.action_space.sample()]
|
||||
action = np.array([env.action_space.sample()])
|
||||
_, _, _, _ = vec_env.step(action)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,19 +1,25 @@
|
|||
# Copied from stable_baselines
|
||||
import typing
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
|
||||
from stable_baselines3.common.vec_env import VecEnv
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from stable_baselines3.common.base_class import BaseAlgorithm
|
||||
|
||||
|
||||
def evaluate_policy(
|
||||
model,
|
||||
env,
|
||||
n_eval_episodes=10,
|
||||
deterministic=True,
|
||||
render=False,
|
||||
callback=None,
|
||||
reward_threshold=None,
|
||||
return_episode_rewards=False,
|
||||
):
|
||||
model: "BaseAlgorithm",
|
||||
env: Union[gym.Env, VecEnv],
|
||||
n_eval_episodes: int = 10,
|
||||
deterministic: bool = True,
|
||||
render: bool = False,
|
||||
callback: Optional[Callable] = None,
|
||||
reward_threshold: Optional[float] = None,
|
||||
return_episode_rewards: bool = False,
|
||||
) -> Union[Tuple[float, float], Tuple[List[float], List[int]]]:
|
||||
"""
|
||||
Runs policy for ``n_eval_episodes`` episodes and returns average reward.
|
||||
This is made to work only with one env.
|
||||
|
|
@ -28,7 +34,7 @@ def evaluate_policy(
|
|||
called after each step.
|
||||
:param reward_threshold: (float) Minimum expected reward per episode,
|
||||
this will raise an error if the performance is not met
|
||||
:param return_episode_rewards: (bool) If True, a list of reward per episode
|
||||
:param return_episode_rewards: (Optional[float]) If True, a list of reward per episode
|
||||
will be returned instead of the mean.
|
||||
:return: (float, float) Mean reward per episode, std of reward per episode
|
||||
returns ([float], [int]) when ``return_episode_rewards`` is True
|
||||
|
|
@ -37,8 +43,10 @@ def evaluate_policy(
|
|||
assert env.num_envs == 1, "You must pass only one environment when using this function"
|
||||
|
||||
episode_rewards, episode_lengths = [], []
|
||||
for _ in range(n_eval_episodes):
|
||||
obs = env.reset()
|
||||
for i in range(n_eval_episodes):
|
||||
# Avoid double reset, as VecEnv are reset automatically
|
||||
if not isinstance(env, VecEnv) or i == 0:
|
||||
obs = env.reset()
|
||||
done, state = False, None
|
||||
episode_reward = 0.0
|
||||
episode_length = 0
|
||||
|
|
|
|||
|
|
@ -1,15 +1,9 @@
|
|||
# flake8: noqa F401
|
||||
import typing
|
||||
from copy import deepcopy
|
||||
from typing import Optional, Union
|
||||
from typing import Optional, Type, Union
|
||||
|
||||
from stable_baselines3.common.vec_env.base_vec_env import (
|
||||
AlreadySteppingError,
|
||||
CloudpickleWrapper,
|
||||
NotSteppingError,
|
||||
VecEnv,
|
||||
VecEnvWrapper,
|
||||
)
|
||||
from stable_baselines3.common.vec_env.base_vec_env import CloudpickleWrapper, VecEnv, VecEnvWrapper
|
||||
from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
|
||||
from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv
|
||||
from stable_baselines3.common.vec_env.vec_check_nan import VecCheckNan
|
||||
|
|
@ -23,17 +17,28 @@ if typing.TYPE_CHECKING:
|
|||
from stable_baselines3.common.type_aliases import GymEnv
|
||||
|
||||
|
||||
def unwrap_vec_wrapper(env: Union["GymEnv", VecEnv], vec_wrapper_class: Type[VecEnvWrapper]) -> Optional[VecEnvWrapper]:
|
||||
"""
|
||||
Retrieve a ``VecEnvWrapper`` object by recursively searching.
|
||||
|
||||
:param env: (gym.Env)
|
||||
:param vec_wrapper_class: (VecEnvWrapper)
|
||||
:return: (VecEnvWrapper)
|
||||
"""
|
||||
env_tmp = env
|
||||
while isinstance(env_tmp, VecEnvWrapper):
|
||||
if isinstance(env_tmp, vec_wrapper_class):
|
||||
return env_tmp
|
||||
env_tmp = env_tmp.venv
|
||||
return None
|
||||
|
||||
|
||||
def unwrap_vec_normalize(env: Union["GymEnv", VecEnv]) -> Optional[VecNormalize]:
|
||||
"""
|
||||
:param env: (gym.Env)
|
||||
:return: (VecNormalize)
|
||||
"""
|
||||
env_tmp = env
|
||||
while isinstance(env_tmp, VecEnvWrapper):
|
||||
if isinstance(env_tmp, VecNormalize):
|
||||
return env_tmp
|
||||
env_tmp = env_tmp.venv
|
||||
return None
|
||||
return unwrap_vec_wrapper(env, VecNormalize) # pytype:disable=bad-return-type
|
||||
|
||||
|
||||
# Define here to avoid circular import
|
||||
|
|
|
|||
|
|
@ -1,12 +1,23 @@
|
|||
import inspect
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional, Sequence, Union
|
||||
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import cloudpickle
|
||||
import gym
|
||||
import numpy as np
|
||||
|
||||
from stable_baselines3.common import logger
|
||||
|
||||
# Define type aliases here to avoid circular import
|
||||
# Used when we want to access one or more VecEnv
|
||||
VecEnvIndices = Union[None, int, Iterable[int]]
|
||||
# VecEnvObs is what is returned by the reset() method
|
||||
# it contains the observation for each env
|
||||
VecEnvObs = Union[np.ndarray, Dict[str, Any]]
|
||||
# VecEnvStepReturn is what is returned by the step() method
|
||||
# it contains the observation, reward, done, info for each env
|
||||
VecEnvStepReturn = Tuple[VecEnvObs, np.ndarray, np.ndarray, List[Dict]]
|
||||
|
||||
|
||||
def tile_images(img_nhwc: Sequence[np.ndarray]) -> np.ndarray: # pragma: no cover
|
||||
"""
|
||||
|
|
@ -34,46 +45,24 @@ def tile_images(img_nhwc: Sequence[np.ndarray]) -> np.ndarray: # pragma: no cov
|
|||
return out_image
|
||||
|
||||
|
||||
class AlreadySteppingError(Exception):
|
||||
"""
|
||||
Raised when an asynchronous step is running while
|
||||
step_async() is called again.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
msg = "already running an async step"
|
||||
Exception.__init__(self, msg)
|
||||
|
||||
|
||||
class NotSteppingError(Exception):
|
||||
"""
|
||||
Raised when an asynchronous step is not running but
|
||||
step_wait() is called.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
msg = "not running an async step"
|
||||
Exception.__init__(self, msg)
|
||||
|
||||
|
||||
class VecEnv(ABC):
|
||||
"""
|
||||
An abstract asynchronous, vectorized environment.
|
||||
|
||||
:param num_envs: (int) the number of environments
|
||||
:param observation_space: (Gym Space) the observation space
|
||||
:param action_space: (Gym Space) the action space
|
||||
:param observation_space: (gym.spaces.Space) the observation space
|
||||
:param action_space: (gym.spaces.Space) the action space
|
||||
"""
|
||||
|
||||
metadata = {"render.modes": ["human", "rgb_array"]}
|
||||
|
||||
def __init__(self, num_envs, observation_space, action_space):
|
||||
def __init__(self, num_envs: int, observation_space: gym.spaces.Space, action_space: gym.spaces.Space):
|
||||
self.num_envs = num_envs
|
||||
self.observation_space = observation_space
|
||||
self.action_space = action_space
|
||||
|
||||
@abstractmethod
|
||||
def reset(self):
|
||||
def reset(self) -> VecEnvObs:
|
||||
"""
|
||||
Reset all the environments and return an array of
|
||||
observations, or a tuple of observation arrays.
|
||||
|
|
@ -82,12 +71,12 @@ class VecEnv(ABC):
|
|||
be cancelled and step_wait() should not be called
|
||||
until step_async() is invoked again.
|
||||
|
||||
:return: ([int] or [float]) observation
|
||||
:return: (VecEnvObs) observation
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def step_async(self, actions):
|
||||
def step_async(self, actions: np.ndarray):
|
||||
"""
|
||||
Tell all the environments to start taking a step
|
||||
with the given actions.
|
||||
|
|
@ -99,23 +88,23 @@ class VecEnv(ABC):
|
|||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def step_wait(self):
|
||||
def step_wait(self) -> VecEnvStepReturn:
|
||||
"""
|
||||
Wait for the step taken with step_async().
|
||||
|
||||
:return: ([int] or [float], [float], [bool], dict) observation, reward, done, information
|
||||
:return: observation, reward, done, information
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def close(self):
|
||||
def close(self) -> None:
|
||||
"""
|
||||
Clean up the environment's resources.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def get_attr(self, attr_name, indices=None):
|
||||
def get_attr(self, attr_name: str, indices: "VecEnvIndices" = None) -> List[Any]:
|
||||
"""
|
||||
Return attribute from vectorized environment.
|
||||
|
||||
|
|
@ -126,7 +115,7 @@ class VecEnv(ABC):
|
|||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def set_attr(self, attr_name, value, indices=None):
|
||||
def set_attr(self, attr_name: str, value: Any, indices: "VecEnvIndices" = None) -> None:
|
||||
"""
|
||||
Set attribute inside vectorized environments.
|
||||
|
||||
|
|
@ -138,7 +127,7 @@ class VecEnv(ABC):
|
|||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def env_method(self, method_name, *method_args, indices=None, **method_kwargs):
|
||||
def env_method(self, method_name: str, *method_args, indices: "VecEnvIndices" = None, **method_kwargs) -> List[Any]:
|
||||
"""
|
||||
Call instance methods of vectorized environments.
|
||||
|
||||
|
|
@ -150,12 +139,12 @@ class VecEnv(ABC):
|
|||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def step(self, actions):
|
||||
def step(self, actions: np.ndarray) -> VecEnvStepReturn:
|
||||
"""
|
||||
Step the environments with the given action
|
||||
|
||||
:param actions: ([int] or [float]) the action
|
||||
:return: ([int] or [float], [float], [bool], dict) observation, reward, done, information
|
||||
:param actions: (np.ndarray) the action
|
||||
:return: (VecEnvStepReturn) observation, reward, done, information
|
||||
"""
|
||||
self.step_async(actions)
|
||||
return self.step_wait()
|
||||
|
|
@ -166,7 +155,7 @@ class VecEnv(ABC):
|
|||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def render(self, mode: str = "human"):
|
||||
def render(self, mode: str = "human") -> Optional[np.ndarray]:
|
||||
"""
|
||||
Gym environment rendering
|
||||
|
||||
|
|
@ -203,25 +192,25 @@ class VecEnv(ABC):
|
|||
pass
|
||||
|
||||
@property
|
||||
def unwrapped(self):
|
||||
def unwrapped(self) -> "VecEnv":
|
||||
if isinstance(self, VecEnvWrapper):
|
||||
return self.venv.unwrapped
|
||||
else:
|
||||
return self
|
||||
|
||||
def getattr_depth_check(self, name, already_found):
|
||||
def getattr_depth_check(self, name: str, already_found: bool) -> Optional[str]:
|
||||
"""Check if an attribute reference is being hidden in a recursive call to __getattr__
|
||||
|
||||
:param name: (str) name of attribute to check for
|
||||
:param already_found: (bool) whether this attribute has already been found in a wrapper
|
||||
:return: (str or None) name of module whose attribute is being shadowed, if any.
|
||||
:return: (Optional[str]) name of module whose attribute is being shadowed, if any.
|
||||
"""
|
||||
if hasattr(self, name) and already_found:
|
||||
return f"{type(self).__module__}.{type(self).__name__}"
|
||||
else:
|
||||
return None
|
||||
|
||||
def _get_indices(self, indices):
|
||||
def _get_indices(self, indices: "VecEnvIndices") -> Iterable[int]:
|
||||
"""
|
||||
Convert a flexibly-typed reference to environment indices to an implied list of indices.
|
||||
|
||||
|
|
@ -240,11 +229,16 @@ class VecEnvWrapper(VecEnv):
|
|||
Vectorized environment base class
|
||||
|
||||
:param venv: (VecEnv) the vectorized environment to wrap
|
||||
:param observation_space: (Gym Space) the observation space (can be None to load from venv)
|
||||
:param action_space: (Gym Space) the action space (can be None to load from venv)
|
||||
:param observation_space: (Optional[gym.spaces.Space]) the observation space (can be None to load from venv)
|
||||
:param action_space: (Optional[gym.spaces.Space]) the action space (can be None to load from venv)
|
||||
"""
|
||||
|
||||
def __init__(self, venv, observation_space=None, action_space=None):
|
||||
def __init__(
|
||||
self,
|
||||
venv: VecEnv,
|
||||
observation_space: Optional[gym.spaces.Space] = None,
|
||||
action_space: Optional[gym.spaces.Space] = None,
|
||||
):
|
||||
self.venv = venv
|
||||
VecEnv.__init__(
|
||||
self,
|
||||
|
|
@ -254,27 +248,27 @@ class VecEnvWrapper(VecEnv):
|
|||
)
|
||||
self.class_attributes = dict(inspect.getmembers(self.__class__))
|
||||
|
||||
def step_async(self, actions):
|
||||
def step_async(self, actions: np.ndarray):
|
||||
self.venv.step_async(actions)
|
||||
|
||||
@abstractmethod
|
||||
def reset(self):
|
||||
def reset(self) -> VecEnvObs:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def step_wait(self):
|
||||
def step_wait(self) -> VecEnvStepReturn:
|
||||
pass
|
||||
|
||||
def seed(self, seed=None):
|
||||
def seed(self, seed: Optional[int] = None):
|
||||
return self.venv.seed(seed)
|
||||
|
||||
def close(self):
|
||||
def close(self) -> None:
|
||||
return self.venv.close()
|
||||
|
||||
def render(self, mode: str = "human"):
|
||||
def render(self, mode: str = "human") -> Optional[np.ndarray]:
|
||||
return self.venv.render(mode=mode)
|
||||
|
||||
def get_images(self):
|
||||
def get_images(self) -> Sequence[np.ndarray]:
|
||||
return self.venv.get_images()
|
||||
|
||||
def get_attr(self, attr_name, indices=None):
|
||||
|
|
@ -286,7 +280,7 @@ class VecEnvWrapper(VecEnv):
|
|||
def env_method(self, method_name, *method_args, indices=None, **method_kwargs):
|
||||
return self.venv.env_method(method_name, *method_args, indices=indices, **method_kwargs)
|
||||
|
||||
def __getattr__(self, name):
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
"""Find attribute from wrapped venv(s) if this wrapper does not have it.
|
||||
Useful for accessing attributes from venvs which are wrapped with multiple wrappers
|
||||
which have unique attributes of interest.
|
||||
|
|
@ -302,16 +296,16 @@ class VecEnvWrapper(VecEnv):
|
|||
|
||||
return self.getattr_recursive(name)
|
||||
|
||||
def _get_all_attributes(self):
|
||||
def _get_all_attributes(self) -> Dict[str, Any]:
|
||||
"""Get all (inherited) instance and class attributes
|
||||
|
||||
:return: (dict<str, object>) all_attributes
|
||||
:return: (Dict[str, Any]) all_attributes
|
||||
"""
|
||||
all_attributes = self.__dict__.copy()
|
||||
all_attributes.update(self.class_attributes)
|
||||
return all_attributes
|
||||
|
||||
def getattr_recursive(self, name):
|
||||
def getattr_recursive(self, name: str):
|
||||
"""Recursively check wrappers to find attribute.
|
||||
|
||||
:param name (str) name of attribute to look for
|
||||
|
|
@ -329,7 +323,7 @@ class VecEnvWrapper(VecEnv):
|
|||
|
||||
return attr
|
||||
|
||||
def getattr_depth_check(self, name, already_found):
|
||||
def getattr_depth_check(self, name: str, already_found: bool):
|
||||
"""See base class.
|
||||
|
||||
:return: (str or None) name of module whose attribute is being shadowed, if any.
|
||||
|
|
@ -349,16 +343,17 @@ class VecEnvWrapper(VecEnv):
|
|||
|
||||
|
||||
class CloudpickleWrapper:
|
||||
def __init__(self, var):
|
||||
"""
|
||||
Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle)
|
||||
"""
|
||||
Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle)
|
||||
|
||||
:param var: (Any) the variable you wish to wrap for pickling with cloudpickle
|
||||
"""
|
||||
:param var: (Any) the variable you wish to wrap for pickling with cloudpickle
|
||||
"""
|
||||
|
||||
def __init__(self, var: Any):
|
||||
self.var = var
|
||||
|
||||
def __getstate__(self):
|
||||
def __getstate__(self) -> Any:
|
||||
return cloudpickle.dumps(self.var)
|
||||
|
||||
def __setstate__(self, obs):
|
||||
self.var = cloudpickle.loads(obs)
|
||||
def __setstate__(self, var: Any) -> None:
|
||||
self.var = cloudpickle.loads(var)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
from collections import OrderedDict
|
||||
from copy import deepcopy
|
||||
from typing import Sequence
|
||||
from typing import Callable, List, Optional, Sequence
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
|
||||
from stable_baselines3.common.vec_env.base_vec_env import VecEnv
|
||||
|
|
@ -16,10 +17,11 @@ class DummyVecEnv(VecEnv):
|
|||
This can also be used for RL methods that
|
||||
require a vectorized environment, but that you want a single environments to train with.
|
||||
|
||||
:param env_fns: ([Gym Environment]) the list of environments to vectorize
|
||||
:param env_fns: (List[Callable[[], gym.Env]]) a list of functions
|
||||
that return environments to vectorize
|
||||
"""
|
||||
|
||||
def __init__(self, env_fns):
|
||||
def __init__(self, env_fns: List[Callable[[], gym.Env]]):
|
||||
self.envs = [fn() for fn in env_fns]
|
||||
env = self.envs[0]
|
||||
VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space)
|
||||
|
|
@ -33,7 +35,7 @@ class DummyVecEnv(VecEnv):
|
|||
self.actions = None
|
||||
self.metadata = env.metadata
|
||||
|
||||
def step_async(self, actions):
|
||||
def step_async(self, actions: np.ndarray):
|
||||
self.actions = actions
|
||||
|
||||
def step_wait(self):
|
||||
|
|
@ -48,7 +50,7 @@ class DummyVecEnv(VecEnv):
|
|||
self._save_obs(env_idx, obs)
|
||||
return (self._obs_from_buf(), np.copy(self.buf_rews), np.copy(self.buf_dones), deepcopy(self.buf_infos))
|
||||
|
||||
def seed(self, seed=None):
|
||||
def seed(self, seed: Optional[int] = None) -> List[int]:
|
||||
seeds = list()
|
||||
for idx, env in enumerate(self.envs):
|
||||
seeds.append(env.seed(seed + idx))
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import warnings
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
from gym import spaces
|
||||
|
|
@ -18,14 +19,17 @@ class VecFrameStack(VecEnvWrapper):
|
|||
self.venv = venv
|
||||
self.n_stack = n_stack
|
||||
wrapped_obs_space = venv.observation_space
|
||||
assert isinstance(wrapped_obs_space, spaces.Box), "VecFrameStack only work with gym.spaces.Box observation space"
|
||||
low = np.repeat(wrapped_obs_space.low, self.n_stack, axis=-1)
|
||||
high = np.repeat(wrapped_obs_space.high, self.n_stack, axis=-1)
|
||||
self.stackedobs = np.zeros((venv.num_envs,) + low.shape, low.dtype)
|
||||
observation_space = spaces.Box(low=low, high=high, dtype=venv.observation_space.dtype)
|
||||
VecEnvWrapper.__init__(self, venv, observation_space=observation_space)
|
||||
|
||||
def step_wait(self):
|
||||
def step_wait(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray, List[Dict[str, Any]]]:
|
||||
observations, rewards, dones, infos = self.venv.step_wait()
|
||||
# Let pytype know that observation is not a dict
|
||||
assert isinstance(observations, np.ndarray)
|
||||
last_ax_size = observations.shape[-1]
|
||||
self.stackedobs = np.roll(self.stackedobs, shift=-last_ax_size, axis=-1)
|
||||
for i, done in enumerate(dones):
|
||||
|
|
@ -40,14 +44,14 @@ class VecFrameStack(VecEnvWrapper):
|
|||
self.stackedobs[..., -observations.shape[-1] :] = observations
|
||||
return self.stackedobs, rewards, dones, infos
|
||||
|
||||
def reset(self):
|
||||
def reset(self) -> np.ndarray:
|
||||
"""
|
||||
Reset all environments
|
||||
"""
|
||||
obs = self.venv.reset()
|
||||
obs: np.ndarray = self.venv.reset() # pytype:disable=annotation-type-mismatch
|
||||
self.stackedobs[...] = 0
|
||||
self.stackedobs[..., -obs.shape[-1] :] = obs
|
||||
return self.stackedobs
|
||||
|
||||
def close(self):
|
||||
def close(self) -> None:
|
||||
self.venv.close()
|
||||
|
|
|
|||
|
|
@ -1,13 +1,8 @@
|
|||
import typing
|
||||
|
||||
import numpy as np
|
||||
from gym import spaces
|
||||
|
||||
from stable_baselines3.common.preprocessing import is_image_space
|
||||
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvWrapper
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from stable_baselines3.common.type_aliases import GymStepReturn # noqa: F401
|
||||
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvStepReturn, VecEnvWrapper
|
||||
|
||||
|
||||
class VecTransposeImage(VecEnvWrapper):
|
||||
|
|
@ -49,7 +44,7 @@ class VecTransposeImage(VecEnvWrapper):
|
|||
return np.transpose(image, (2, 0, 1))
|
||||
return np.transpose(image, (0, 3, 1, 2))
|
||||
|
||||
def step_wait(self) -> "GymStepReturn":
|
||||
def step_wait(self) -> VecEnvStepReturn:
|
||||
observations, rewards, dones, infos = self.venv.step_wait()
|
||||
return self.transpose_image(observations), rewards, dones, infos
|
||||
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
0.8.0
|
||||
0.9.0a0
|
||||
|
|
|
|||
Loading…
Reference in a new issue