mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-17 21:20:11 +00:00
* Add `has_attr` for `VecEnv` * Add special case for gymnasium<1.0 * Update changelog.rst * Update black version
501 lines
19 KiB
Python
501 lines
19 KiB
Python
import inspect
|
|
import warnings
|
|
from abc import ABC, abstractmethod
|
|
from collections.abc import Iterable, Sequence
|
|
from copy import deepcopy
|
|
from typing import Any, Optional, Union
|
|
|
|
import cloudpickle
|
|
import gymnasium as gym
|
|
import numpy as np
|
|
from gymnasium import spaces
|
|
|
|
# The aliases below live in this module (rather than in a shared typing module)
# so that importing them never creates a circular import.

# Selects which sub-environment(s) of a VecEnv to address:
# None means "all of them", an int means one, an iterable of ints means several.
VecEnvIndices = Optional[Union[int, Iterable[int]]]

# Observation batch produced by ``reset()`` (one entry per environment), either as a
# single array, a dict of arrays (dict observation spaces) or a tuple of arrays.
VecEnvObs = Union[np.ndarray, dict[str, np.ndarray], tuple[np.ndarray, ...]]

# Result of ``step()``: (observations, rewards, dones, infos), one entry per environment.
VecEnvStepReturn = tuple[VecEnvObs, np.ndarray, np.ndarray, list[dict]]
|
|
|
|
|
|
def tile_images(images_nhwc: Sequence[np.ndarray]) -> np.ndarray:  # pragma: no cover
    """
    Lay out N images as a single mosaic of P rows by Q columns.

    P and Q are chosen as close to each other as possible, so a perfect-square N
    yields P == Q. Grid cells left over after the N images are filled with
    blank frames obtained by multiplying the first image by zero.

    :param images_nhwc: list or array of images, ndim=4 once turned into array.
        n = batch index, h = height, w = width, c = channel
    :return: img_HWc, ndim=3
    """
    batch = np.asarray(images_nhwc)
    n_images, height, width, n_channels = batch.shape

    grid_rows = int(np.ceil(np.sqrt(n_images)))
    grid_cols = int(np.ceil(float(n_images) / grid_rows))
    n_missing = grid_rows * grid_cols - n_images

    frames = list(batch)
    frames.extend(batch[0] * 0 for _ in range(n_missing))
    padded = np.array(frames)

    # (rows, cols, h, w, c) -> (rows, h, cols, w, c) -> (rows * h, cols * w, c)
    grid = padded.reshape((grid_rows, grid_cols, height, width, n_channels))
    grid = grid.transpose(0, 2, 1, 3, 4)
    mosaic = grid.reshape((grid_rows * height, grid_cols * width, n_channels))
    return mosaic  # type: ignore[return-value]
|
|
|
|
|
|
class VecEnv(ABC):
    """
    An abstract asynchronous, vectorized environment.

    :param num_envs: Number of environments
    :param observation_space: Observation space
    :param action_space: Action space
    """

    def __init__(
        self,
        num_envs: int,
        observation_space: spaces.Space,
        action_space: spaces.Space,
    ):
        self.num_envs = num_envs
        self.observation_space = observation_space
        self.action_space = action_space
        # store info returned by the reset method
        self.reset_infos: list[dict[str, Any]] = [{} for _ in range(num_envs)]
        # seeds to be used in the next call to env.reset()
        self._seeds: list[Optional[int]] = [None for _ in range(num_envs)]
        # options to be used in the next call to env.reset()
        self._options: list[dict[str, Any]] = [{} for _ in range(num_envs)]

        # ``get_attr`` is abstract here: concrete subclasses must already be able to
        # answer it when this constructor runs.
        try:
            render_modes = self.get_attr("render_mode")
        except AttributeError:
            warnings.warn("The `render_mode` attribute is not defined in your environment. It will be set to None.")
            render_modes = [None for _ in range(num_envs)]

        assert all(
            render_mode == render_modes[0] for render_mode in render_modes
        ), "render_mode mode should be the same for all environments"
        self.render_mode = render_modes[0]

        render_modes = []
        if self.render_mode is not None:
            if self.render_mode == "rgb_array":
                # SB3 uses OpenCV for the "human" mode
                render_modes = ["human", "rgb_array"]
            else:
                render_modes = [self.render_mode]

        self.metadata = {"render_modes": render_modes}

    def _reset_seeds(self) -> None:
        """
        Reset the seeds that are going to be used at the next reset.
        """
        self._seeds = [None for _ in range(self.num_envs)]

    def _reset_options(self) -> None:
        """
        Reset the options that are going to be used at the next reset.
        """
        self._options = [{} for _ in range(self.num_envs)]

    @abstractmethod
    def reset(self) -> VecEnvObs:
        """
        Reset all the environments and return an array of
        observations, or a tuple of observation arrays.

        If step_async is still doing work, that work will
        be cancelled and step_wait() should not be called
        until step_async() is invoked again.

        :return: observation
        """
        raise NotImplementedError()

    @abstractmethod
    def step_async(self, actions: np.ndarray) -> None:
        """
        Tell all the environments to start taking a step
        with the given actions.
        Call step_wait() to get the results of the step.

        You should not call this if a step_async run is
        already pending.

        :param actions: the actions to apply, one per environment
        """
        raise NotImplementedError()

    @abstractmethod
    def step_wait(self) -> VecEnvStepReturn:
        """
        Wait for the step taken with step_async().

        :return: observation, reward, done, information
        """
        raise NotImplementedError()

    @abstractmethod
    def close(self) -> None:
        """
        Clean up the environment's resources.
        """
        raise NotImplementedError()

    def has_attr(self, attr_name: str) -> bool:
        """
        Check if an attribute exists for a vectorized environment.

        :param attr_name: The name of the attribute to check
        :return: True if 'attr_name' exists in all environments
        """
        # Default implementation, will not work with things that cannot be pickled:
        # https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/issues/49
        try:
            self.get_attr(attr_name)
            return True
        except AttributeError:
            return False

    @abstractmethod
    def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> list[Any]:
        """
        Return attribute from vectorized environment.

        :param attr_name: The name of the attribute whose value to return
        :param indices: Indices of envs to get attribute from
        :return: List of values of 'attr_name' in all environments
        """
        raise NotImplementedError()

    @abstractmethod
    def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None:
        """
        Set attribute inside vectorized environments.

        :param attr_name: The name of attribute to assign new value
        :param value: Value to assign to `attr_name`
        :param indices: Indices of envs to assign value
        """
        raise NotImplementedError()

    @abstractmethod
    def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> list[Any]:
        """
        Call instance methods of vectorized environments.

        :param method_name: The name of the environment method to invoke.
        :param indices: Indices of envs whose method to call
        :param method_args: Any positional arguments to provide in the call
        :param method_kwargs: Any keyword arguments to provide in the call
        :return: List of items returned by the environment's method call
        """
        raise NotImplementedError()

    @abstractmethod
    def env_is_wrapped(self, wrapper_class: type[gym.Wrapper], indices: VecEnvIndices = None) -> list[bool]:
        """
        Check if environments are wrapped with a given wrapper.

        :param wrapper_class: The gymnasium wrapper class to look for
        :param indices: Indices of envs to check (None means all envs)
        :return: True if the env is wrapped, False otherwise, for each env queried.
        """
        raise NotImplementedError()

    def step(self, actions: np.ndarray) -> VecEnvStepReturn:
        """
        Step the environments with the given action

        :param actions: the action
        :return: observation, reward, done, information
        """
        self.step_async(actions)
        return self.step_wait()

    def get_images(self) -> Sequence[Optional[np.ndarray]]:
        """
        Return RGB images from each environment when available
        """
        raise NotImplementedError

    def render(self, mode: Optional[str] = None) -> Optional[np.ndarray]:
        """
        Gym environment rendering

        :param mode: the rendering type; when None, the ``render_mode`` given at
            construction is used
        :return: the tiled image of all environments in "rgb_array" mode, otherwise None
        """

        if mode == "human" and self.render_mode != mode:
            # Special case, if the render_mode="rgb_array"
            # we can still display that image using opencv
            if self.render_mode != "rgb_array":
                warnings.warn(
                    f"You tried to render a VecEnv with mode='{mode}' "
                    "but the render mode defined when initializing the environment must be "
                    f"'human' or 'rgb_array', not '{self.render_mode}'."
                )
                return None

        elif mode and self.render_mode != mode:
            warnings.warn(
                "Starting from gymnasium v0.26, render modes are determined during the initialization of the environment.\n"
                f"We allow to pass a mode argument to maintain a backwards compatible VecEnv API, but the mode ({mode})\n"
                f"has to be the same as the environment render mode ({self.render_mode}) which is not the case."
            )
            return None

        mode = mode or self.render_mode

        if mode is None:
            warnings.warn("You tried to call render() but no `render_mode` was passed to the env constructor.")
            return None

        # mode == self.render_mode == "human"
        # In that case, we try to call `self.env.render()` but it might
        # crash for subprocesses
        if self.render_mode == "human":
            self.env_method("render")
            return None

        if mode == "rgb_array" or mode == "human":
            # call the render method of the environments
            images = self.get_images()
            # Create a big image by tiling images from subprocesses
            bigimg = tile_images(images)  # type: ignore[arg-type]

            if mode == "human":
                # Display it using OpenCV (imported lazily: only needed for this mode)
                import cv2

                # Reverse the channel axis: RGB -> BGR, the channel order OpenCV displays
                cv2.imshow("vecenv", bigimg[:, :, ::-1])
                cv2.waitKey(1)
            else:
                return bigimg

        else:
            # Other render modes:
            # In that case, we try to call `self.env.render()` but it might
            # crash for subprocesses
            # and we don't return the values
            self.env_method("render")
        return None

    def seed(self, seed: Optional[int] = None) -> Sequence[Union[None, int]]:
        """
        Sets the random seeds for all environments, based on a given seed.
        Each individual environment will still get its own seed, by incrementing the given seed.
        WARNING: since gym 0.26, those seeds will only be passed to the environment
        at the next reset.

        :param seed: The random seed. May be None for completely random seeding.
        :return: The list of seeds (``seed + env_index``) that will be used
            for each individual env at the next reset.
        """
        if seed is None:
            # To ensure that subprocesses have different seeds,
            # we still populate the seed variable when no argument is passed
            seed = int(np.random.randint(0, np.iinfo(np.uint32).max, dtype=np.uint32))

        self._seeds = [seed + idx for idx in range(self.num_envs)]
        return self._seeds

    def set_options(self, options: Optional[Union[list[dict], dict]] = None) -> None:
        """
        Set environment options for all environments.
        If a dict is passed instead of a list, the same options will be used for all environments.
        WARNING: Those options will only be passed to the environment at the next reset.

        :param options: A dictionary of environment options to pass to each environment at the next reset.
        """
        if options is None:
            options = {}
        # Use deepcopy to avoid side effects, copying each entry on its own:
        # ``deepcopy([options] * n)`` memoizes the repeated reference and would give every
        # env the *same* copied dict, so one env mutating its options would affect the others.
        if isinstance(options, dict):
            self._options = [deepcopy(options) for _ in range(self.num_envs)]
        else:
            self._options = [deepcopy(env_options) for env_options in options]

    @property
    def unwrapped(self) -> "VecEnv":
        """The innermost VecEnv, obtained by following ``venv`` through every VecEnvWrapper."""
        if isinstance(self, VecEnvWrapper):
            return self.venv.unwrapped
        else:
            return self

    def getattr_depth_check(self, name: str, already_found: bool) -> Optional[str]:
        """Check if an attribute reference is being hidden in a recursive call to __getattr__

        :param name: name of attribute to check for
        :param already_found: whether this attribute has already been found in a wrapper
        :return: name of module whose attribute is being shadowed, if any.
        """
        if hasattr(self, name) and already_found:
            return f"{type(self).__module__}.{type(self).__name__}"
        else:
            return None

    def _get_indices(self, indices: VecEnvIndices) -> Iterable[int]:
        """
        Convert a flexibly-typed reference to environment indices to an implied list of indices.

        :param indices: refers to indices of envs.
        :return: the implied list of indices.
        """
        if indices is None:
            indices = range(self.num_envs)
        elif isinstance(indices, int):
            indices = [indices]
        return indices
|
|
|
|
|
|
class VecEnvWrapper(VecEnv):
    """
    Vectorized environment base class

    :param venv: the vectorized environment to wrap
    :param observation_space: the observation space (can be None to load from venv)
    :param action_space: the action space (can be None to load from venv)
    """

    def __init__(
        self,
        venv: VecEnv,
        observation_space: Optional[spaces.Space] = None,
        action_space: Optional[spaces.Space] = None,
    ):
        # Must be set before ``super().__init__``: VecEnv.__init__ calls ``self.get_attr``,
        # which this class forwards to ``self.venv``.
        self.venv = venv

        super().__init__(
            num_envs=venv.num_envs,
            # Explicit ``is not None`` checks: spaces defining ``__len__`` (e.g. gymnasium's
            # ``Dict``/``Tuple``) are falsy when empty, and a plain ``or`` would silently
            # replace such a space with the wrapped venv's one.
            observation_space=observation_space if observation_space is not None else venv.observation_space,
            action_space=action_space if action_space is not None else venv.action_space,
        )
        # Snapshot of every (inherited) class member, used by the attribute-forwarding logic
        self.class_attributes = dict(inspect.getmembers(self.__class__))

    def step_async(self, actions: np.ndarray) -> None:
        self.venv.step_async(actions)

    @abstractmethod
    def reset(self) -> VecEnvObs:
        pass

    @abstractmethod
    def step_wait(self) -> VecEnvStepReturn:
        pass

    def seed(self, seed: Optional[int] = None) -> Sequence[Union[None, int]]:
        return self.venv.seed(seed)

    def set_options(self, options: Optional[Union[list[dict], dict]] = None) -> None:
        return self.venv.set_options(options)

    def close(self) -> None:
        return self.venv.close()

    def render(self, mode: Optional[str] = None) -> Optional[np.ndarray]:
        return self.venv.render(mode=mode)

    def get_images(self) -> Sequence[Optional[np.ndarray]]:
        return self.venv.get_images()

    def has_attr(self, attr_name: str) -> bool:
        return self.venv.has_attr(attr_name)

    def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> list[Any]:
        return self.venv.get_attr(attr_name, indices)

    def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None:
        return self.venv.set_attr(attr_name, value, indices)

    def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> list[Any]:
        return self.venv.env_method(method_name, *method_args, indices=indices, **method_kwargs)

    def env_is_wrapped(self, wrapper_class: type[gym.Wrapper], indices: VecEnvIndices = None) -> list[bool]:
        return self.venv.env_is_wrapped(wrapper_class, indices=indices)

    def __getattr__(self, name: str) -> Any:
        """Find attribute from wrapped venv(s) if this wrapper does not have it.
        Useful for accessing attributes from venvs which are wrapped with multiple wrappers
        which have unique attributes of interest.

        Python only calls this method when normal attribute lookup fails.

        :param name: name of the attribute to look up
        :return: the attribute found by walking down the chain of wrapped venvs
        :raises AttributeError: if the lookup is ambiguous (shadowed by several wrappers),
            if no wrapped venv provides the attribute, or if this wrapper is not fully
            initialized yet
        """
        instance_attributes = self.__dict__
        if "venv" not in instance_attributes or "class_attributes" not in instance_attributes:
            # Partially-initialized instance, e.g. ``pickle``/``copy.deepcopy`` probing for
            # ``__setstate__`` on a freshly allocated object, or an attribute read in a
            # subclass ``__init__`` before this base ``__init__`` completed. Delegating now
            # would re-enter ``__getattr__`` via ``self.venv`` / ``self.class_attributes``
            # and recurse until RecursionError, so fail with a plain AttributeError instead.
            raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

        blocked_class = self.getattr_depth_check(name, already_found=False)
        if blocked_class is not None:
            own_class = f"{type(self).__module__}.{type(self).__name__}"
            error_str = (
                f"Error: Recursive attribute lookup for {name} from {own_class} is "
                f"ambiguous and hides attribute from {blocked_class}"
            )
            raise AttributeError(error_str)

        return self.getattr_recursive(name)

    def _get_all_attributes(self) -> dict[str, Any]:
        """Get all (inherited) instance and class attributes

        :return: all_attributes
        """
        all_attributes = self.__dict__.copy()
        all_attributes.update(self.class_attributes)
        return all_attributes

    def getattr_recursive(self, name: str) -> Any:
        """Recursively check wrappers to find attribute.

        :param name: name of attribute to look for
        :return: attribute
        """
        all_attributes = self._get_all_attributes()
        if name in all_attributes:  # attribute is present in this wrapper
            attr = getattr(self, name)
        elif hasattr(self.venv, "getattr_recursive"):
            # Attribute not present, child is wrapper. Call getattr_recursive rather than getattr
            # to avoid a duplicate call to getattr_depth_check.
            attr = self.venv.getattr_recursive(name)
        else:  # attribute not present, child is an unwrapped VecEnv
            attr = getattr(self.venv, name)

        return attr

    def getattr_depth_check(self, name: str, already_found: bool) -> Optional[str]:
        """See base class.

        :return: name of module whose attribute is being shadowed, if any.
        """
        all_attributes = self._get_all_attributes()
        if name in all_attributes and already_found:
            # this venv's attribute is being hidden because of a higher venv.
            shadowed_wrapper_class: Optional[str] = f"{type(self).__module__}.{type(self).__name__}"
        elif name in all_attributes and not already_found:
            # we have found the first reference to the attribute. Now check for duplicates.
            shadowed_wrapper_class = self.venv.getattr_depth_check(name, True)
        else:
            # this wrapper does not have the attribute. Keep searching.
            shadowed_wrapper_class = self.venv.getattr_depth_check(name, already_found)

        return shadowed_wrapper_class
|
|
|
|
|
|
class CloudpickleWrapper:
|
|
"""
|
|
Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle)
|
|
|
|
:param var: the variable you wish to wrap for pickling with cloudpickle
|
|
"""
|
|
|
|
def __init__(self, var: Any):
|
|
self.var = var
|
|
|
|
def __getstate__(self) -> Any:
|
|
return cloudpickle.dumps(self.var)
|
|
|
|
def __setstate__(self, var: Any) -> None:
|
|
self.var = cloudpickle.loads(var)
|