mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-18 21:30:19 +00:00
359 lines
12 KiB
Python
359 lines
12 KiB
Python
import inspect
|
|
import pickle
|
|
from abc import ABC, abstractmethod
|
|
from typing import Sequence, Optional, List, Union
|
|
|
|
import numpy as np
|
|
import cloudpickle
|
|
|
|
from stable_baselines3.common import logger
|
|
|
|
|
|
def tile_images(img_nhwc: Sequence[np.ndarray]) -> np.ndarray:  # pragma: no cover
    """
    Tile N images into one big PxQ image
    (P,Q) are chosen to be as close as possible, and if N
    is square, then P=Q.

    Grid cells left over when N is not exactly P*Q are filled with zeros of the
    images' dtype. The result is always a new array (never a view of the input).

    :param img_nhwc: (Sequence[np.ndarray]) list or array of images, ndim=4 once turned into array. img nhwc
        n = batch index, h = height, w = width, c = channel
    :return: (np.ndarray) img_HWc, ndim=3
    :raises ValueError: if no images are given
    """
    img_nhwc = np.asarray(img_nhwc)
    n_images, height, width, n_channels = img_nhwc.shape
    if n_images == 0:
        raise ValueError("tile_images() requires at least one image")
    # Number of grid rows: smallest integer whose square covers all images.
    new_height = int(np.ceil(np.sqrt(n_images)))
    # Number of grid columns needed to hold all images in new_height rows.
    new_width = int(np.ceil(float(n_images) / new_height))
    # Pad with explicit zeros rather than `image * 0`: the latter yields NaN for
    # inf/NaN pixels and promotes boolean images to int. concatenate always
    # copies, so the output never aliases the caller's data.
    n_padding = new_height * new_width - n_images
    padding = np.zeros((n_padding, height, width, n_channels), dtype=img_nhwc.dtype)
    img_nhwc = np.concatenate([img_nhwc, padding], axis=0)
    # img_HWhwc
    out_image = img_nhwc.reshape((new_height, new_width, height, width, n_channels))
    # img_HhWwc
    out_image = out_image.transpose(0, 2, 1, 3, 4)
    # img_Hh_Ww_c
    return out_image.reshape((new_height * height, new_width * width, n_channels))
|
|
|
|
|
|
class AlreadySteppingError(Exception):
    """
    Signals that ``step_async()`` was invoked while a previously started
    asynchronous step has not been collected yet.
    """

    def __init__(self):
        super().__init__('already running an async step')
|
|
|
|
|
|
class NotSteppingError(Exception):
    """
    Signals that ``step_wait()`` was invoked although no asynchronous step
    is currently in progress.
    """

    def __init__(self):
        super().__init__('not running an async step')
|
|
|
|
|
|
class VecEnv(ABC):
    """
    Abstract base for asynchronous, vectorized environments: ``num_envs``
    environments that are reset and stepped together as one batch.

    :param num_envs: (int) the number of environments
    :param observation_space: (Gym Space) the observation space
    :param action_space: (Gym Space) the action space
    """

    metadata = {'render.modes': ['human', 'rgb_array']}

    def __init__(self, num_envs, observation_space, action_space):
        self.num_envs = num_envs
        self.observation_space = observation_space
        self.action_space = action_space

    @abstractmethod
    def reset(self):
        """
        Reset every environment and return the initial observations, either as
        one array or as a tuple of observation arrays.

        A step still pending from step_async() is abandoned; step_wait() must
        not be called again until a new step_async() has been issued.

        :return: ([int] or [float]) observation
        """
        raise NotImplementedError()

    @abstractmethod
    def step_async(self, actions):
        """
        Start stepping all environments with ``actions``. The outcome is
        collected with step_wait().

        Must not be called while an earlier step_async() is still pending.
        """
        raise NotImplementedError()

    @abstractmethod
    def step_wait(self):
        """
        Block until the step started by step_async() completes.

        :return: ([int] or [float], [float], [bool], dict) observation, reward, done, information
        """
        raise NotImplementedError()

    @abstractmethod
    def close(self):
        """
        Release whatever resources the environments hold.
        """
        raise NotImplementedError()

    @abstractmethod
    def get_attr(self, attr_name, indices=None):
        """
        Read an attribute from (a subset of) the vectorized environments.

        :param attr_name: (str) name of the attribute to read
        :param indices: (list,int) indices of the envs to read from
        :return: (list) the values of 'attr_name', one per selected environment
        """
        raise NotImplementedError()

    @abstractmethod
    def set_attr(self, attr_name, value, indices=None):
        """
        Assign an attribute inside (a subset of) the vectorized environments.

        :param attr_name: (str) name of the attribute to assign
        :param value: (obj) value to store in `attr_name`
        :param indices: (list,int) indices of the envs to modify
        :return: (NoneType)
        """
        raise NotImplementedError()

    @abstractmethod
    def env_method(self, method_name, *method_args, indices=None, **method_kwargs):
        """
        Invoke a method on (a subset of) the vectorized environments.

        :param method_name: (str) name of the environment method to call
        :param indices: (list,int) indices of the envs whose method is called
        :param method_args: (tuple) positional arguments forwarded to the call
        :param method_kwargs: (dict) keyword arguments forwarded to the call
        :return: (list) the return value of each environment's method call
        """
        raise NotImplementedError()

    def step(self, actions):
        """
        Synchronous step: issue step_async() and immediately wait for it.

        :param actions: ([int] or [float]) the action
        :return: ([int] or [float], [float], [bool], dict) observation, reward, done, information
        """
        self.step_async(actions)
        return self.step_wait()

    def get_images(self) -> Sequence[np.ndarray]:
        """
        Return one RGB image per environment. Not supported by this base class.
        """
        raise NotImplementedError

    def render(self, mode: str = 'human'):
        """
        Render all environments as a single tiled image.

        :param mode: the rendering type ('human' shows a window, 'rgb_array' returns the image)
        :return: the tiled image for 'rgb_array', otherwise None
        """
        try:
            frames = self.get_images()
        except NotImplementedError:
            logger.warn(f'Render not defined for {self}')
            return None

        # Stitch the per-environment frames into one grid image.
        mosaic = tile_images(frames)
        if mode == 'human':
            import cv2  # pytype:disable=import-error
            # Reverse the channel axis (RGB -> BGR) before handing the image to OpenCV.
            cv2.imshow('vecenv', mosaic[:, :, ::-1])
            cv2.waitKey(1)
            return None
        if mode == 'rgb_array':
            return mosaic
        raise NotImplementedError(f'Render mode {mode} is not supported by VecEnvs')

    @abstractmethod
    def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]:
        """
        Seed every environment starting from ``seed``; each individual
        environment gets its own seed derived by incrementing the given one.

        :param seed: (Optional[int]) base random seed; None requests fully random seeding.
        :return: (List[Union[None, int]]) the seed used by each individual env.
            Entries may all be None when the env returns nothing on seeding.
        """
        pass

    @property
    def unwrapped(self):
        """The innermost VecEnv, reached by unwrapping every VecEnvWrapper layer."""
        return self.venv.unwrapped if isinstance(self, VecEnvWrapper) else self

    def getattr_depth_check(self, name, already_found):
        """
        Report whether attribute ``name`` on this env is hidden by an outer wrapper
        during a recursive __getattr__ lookup.

        :param name: (str) attribute name to check
        :param already_found: (bool) whether an outer wrapper already provides it
        :return: (str or None) qualified class name whose attribute is shadowed, if any.
        """
        if not (hasattr(self, name) and already_found):
            return None
        return f"{type(self).__module__}.{type(self).__name__}"

    def _get_indices(self, indices):
        """
        Normalise a flexible environment-index argument into an iterable of indices.

        :param indices: (None,int,Iterable) None selects every env, an int selects one env.
        :return: (list) the selected indices (a range object when ``indices`` is None).
        """
        if indices is None:
            return range(self.num_envs)
        if isinstance(indices, int):
            return [indices]
        return indices
|
|
|
|
|
|
class VecEnvWrapper(VecEnv):
    """
    Vectorized environment base class

    Wraps another VecEnv and forwards every operation to it by default.
    Attributes that the wrapper itself does not define are looked up
    recursively on the wrapped ``venv`` (see ``__getattr__``).

    :param venv: (VecEnv) the vectorized environment to wrap
    :param observation_space: (Gym Space) the observation space (can be None to load from venv)
    :param action_space: (Gym Space) the action space (can be None to load from venv)
    """

    def __init__(self, venv, observation_space=None, action_space=None):
        self.venv = venv
        VecEnv.__init__(self, num_envs=venv.num_envs, observation_space=observation_space or venv.observation_space,
                        action_space=action_space or venv.action_space)
        # Snapshot of the class-level members (methods, properties, class attributes);
        # used by the recursive lookup to decide whether this wrapper defines a name.
        self.class_attributes = dict(inspect.getmembers(self.__class__))

    def step_async(self, actions):
        """Forward the asynchronous step request to the wrapped venv."""
        self.venv.step_async(actions)

    @abstractmethod
    def reset(self):
        pass

    @abstractmethod
    def step_wait(self):
        pass

    def seed(self, seed=None):
        """Delegate seeding to the wrapped venv and return its result."""
        return self.venv.seed(seed)

    def close(self):
        """Delegate closing to the wrapped venv."""
        return self.venv.close()

    def render(self, mode: str = 'human'):
        """Delegate rendering to the wrapped venv."""
        return self.venv.render(mode=mode)

    def get_images(self):
        """Delegate image collection to the wrapped venv."""
        return self.venv.get_images()

    def get_attr(self, attr_name, indices=None):
        """Delegate attribute reads to the wrapped venv."""
        return self.venv.get_attr(attr_name, indices)

    def set_attr(self, attr_name, value, indices=None):
        """Delegate attribute writes to the wrapped venv."""
        return self.venv.set_attr(attr_name, value, indices)

    def env_method(self, method_name, *method_args, indices=None, **method_kwargs):
        """Delegate method calls to the wrapped venv."""
        return self.venv.env_method(method_name, *method_args, indices=indices, **method_kwargs)

    def __getattr__(self, name):
        """Find attribute from wrapped venv(s) if this wrapper does not have it.
        Useful for accessing attributes from venvs which are wrapped with multiple wrappers
        which have unique attributes of interest.

        :raises AttributeError: if the attribute is missing, or if the lookup is ambiguous
            because an inner venv's attribute would be hidden by this one.
        """
        # __getattr__ only runs when normal lookup fails. If the attributes the recursive
        # lookup itself relies on are missing (e.g. an instance built without __init__, as
        # pickle/copy do), resolving them here would recurse forever; fail cleanly instead.
        if name in ('venv', 'class_attributes'):
            raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

        blocked_class = self.getattr_depth_check(name, already_found=False)
        if blocked_class is not None:
            own_class = f"{type(self).__module__}.{type(self).__name__}"
            error_str = (f"Error: Recursive attribute lookup for {name} from {own_class} is "
                         f"ambiguous and hides attribute from {blocked_class}")
            raise AttributeError(error_str)

        return self.getattr_recursive(name)

    def _get_all_attributes(self):
        """Get all (inherited) instance and class attributes

        :return: (dict<str, object>) all_attributes
        """
        all_attributes = self.__dict__.copy()
        all_attributes.update(self.class_attributes)
        return all_attributes

    def getattr_recursive(self, name):
        """Recursively check wrappers to find attribute.

        :param name: (str) name of attribute to look for
        :return: (object) attribute
        """
        all_attributes = self._get_all_attributes()
        if name in all_attributes:  # attribute is present in this wrapper
            attr = getattr(self, name)
        elif hasattr(self.venv, 'getattr_recursive'):
            # Attribute not present, child is wrapper. Call getattr_recursive rather than getattr
            # to avoid a duplicate call to getattr_depth_check.
            attr = self.venv.getattr_recursive(name)
        else:  # attribute not present, child is an unwrapped VecEnv
            attr = getattr(self.venv, name)

        return attr

    def getattr_depth_check(self, name, already_found):
        """See base class.

        :param name: (str) name of attribute to check for
        :param already_found: (bool) whether an outer wrapper already provides the attribute
        :return: (str or None) name of module whose attribute is being shadowed, if any.
        """
        all_attributes = self._get_all_attributes()
        if name in all_attributes and already_found:
            # this venv's attribute is being hidden because of a higher venv.
            shadowed_wrapper_class = f"{type(self).__module__}.{type(self).__name__}"
        elif name in all_attributes and not already_found:
            # we have found the first reference to the attribute. Now check for duplicates.
            shadowed_wrapper_class = self.venv.getattr_depth_check(name, True)
        else:
            # this wrapper does not have the attribute. Keep searching.
            shadowed_wrapper_class = self.venv.getattr_depth_check(name, already_found)

        return shadowed_wrapper_class
|
|
|
|
|
|
class CloudpickleWrapper:
    """
    Holder that serializes its payload with cloudpickle, so that objects plain
    pickle cannot handle (e.g. lambdas, closures) can cross multiprocessing
    boundaries.

    :param var: (Any) the object to carry through pickling
    """

    def __init__(self, var):
        self.var = var

    def __getstate__(self):
        # Serialize the payload with cloudpickle instead of the default pickler.
        return cloudpickle.dumps(self.var)

    def __setstate__(self, obs):
        # cloudpickle output is standard pickle data, so the stdlib loader can restore it.
        # `obs` is the byte string produced by __getstate__.
        self.var = pickle.loads(obs)
|