Add stable-baselines VecEnvs

Antonin Raffin 2019-09-20 15:18:25 +02:00
parent cc4380eccd
commit 56053bc692
10 changed files with 1284 additions and 0 deletions

tests/test_vec_envs.py (new file)

@@ -0,0 +1,326 @@
import collections
import functools
import itertools
import multiprocessing
import pytest
import gym
import numpy as np
from torchy_baselines.common.vec_env import DummyVecEnv, SubprocVecEnv, VecNormalize, VecFrameStack
N_ENVS = 3
VEC_ENV_CLASSES = [DummyVecEnv, SubprocVecEnv]
VEC_ENV_WRAPPERS = [None, VecNormalize, VecFrameStack]
class CustomGymEnv(gym.Env):
def __init__(self, space):
"""
Custom gym environment for testing purposes
"""
self.action_space = space
self.observation_space = space
self.current_step = 0
self.ep_length = 4
def reset(self):
self.current_step = 0
self._choose_next_state()
return self.state
def step(self, action):
reward = 1
self._choose_next_state()
self.current_step += 1
done = self.current_step >= self.ep_length
return self.state, reward, done, {}
def _choose_next_state(self):
self.state = self.observation_space.sample()
def render(self, mode='human'):
pass
@staticmethod
def custom_method(dim_0=1, dim_1=1):
"""
Dummy method to test call to custom method
from VecEnv
:param dim_0: (int)
:param dim_1: (int)
:return: (np.ndarray)
"""
return np.ones((dim_0, dim_1))
@pytest.mark.parametrize('vec_env_class', VEC_ENV_CLASSES)
@pytest.mark.parametrize('vec_env_wrapper', VEC_ENV_WRAPPERS)
def test_vecenv_custom_calls(vec_env_class, vec_env_wrapper):
"""Test access to methods/attributes of vectorized environments"""
def make_env():
return CustomGymEnv(gym.spaces.Box(low=np.zeros(2), high=np.ones(2)))
vec_env = vec_env_class([make_env for _ in range(N_ENVS)])
if vec_env_wrapper is not None:
if vec_env_wrapper == VecFrameStack:
vec_env = vec_env_wrapper(vec_env, n_stack=2)
else:
vec_env = vec_env_wrapper(vec_env)
env_method_results = vec_env.env_method('custom_method', 1, indices=None, dim_1=2)
setattr_results = []
# Set current_step to an arbitrary value
for env_idx in range(N_ENVS):
setattr_results.append(vec_env.set_attr('current_step', env_idx, indices=env_idx))
# Retrieve the value for each environment
getattr_results = vec_env.get_attr('current_step')
assert len(env_method_results) == N_ENVS
assert len(setattr_results) == N_ENVS
assert len(getattr_results) == N_ENVS
for env_idx in range(N_ENVS):
assert (env_method_results[env_idx] == np.ones((1, 2))).all()
assert setattr_results[env_idx] is None
assert getattr_results[env_idx] == env_idx
# Call env_method on a subset of the VecEnv
env_method_subset = vec_env.env_method('custom_method', 1, indices=[0, 2], dim_1=3)
assert (env_method_subset[0] == np.ones((1, 3))).all()
assert (env_method_subset[1] == np.ones((1, 3))).all()
assert len(env_method_subset) == 2
# Test to change value for all the environments
setattr_result = vec_env.set_attr('current_step', 42, indices=None)
getattr_result = vec_env.get_attr('current_step')
assert setattr_result is None
assert getattr_result == [42 for _ in range(N_ENVS)]
# Additional tests for setattr that does not affect all the environments
vec_env.reset()
setattr_result = vec_env.set_attr('current_step', 12, indices=[0, 1])
getattr_result = vec_env.get_attr('current_step')
getattr_result_subset = vec_env.get_attr('current_step', indices=[0, 1])
assert setattr_result is None
assert getattr_result == [12 for _ in range(2)] + [0 for _ in range(N_ENVS - 2)]
assert getattr_result_subset == [12, 12]
assert vec_env.get_attr('current_step', indices=[0, 2]) == [12, 0]
vec_env.reset()
# Change value only for first and last environment
setattr_result = vec_env.set_attr('current_step', 12, indices=[0, -1])
getattr_result = vec_env.get_attr('current_step')
assert setattr_result is None
assert getattr_result == [12] + [0 for _ in range(N_ENVS - 2)] + [12]
assert vec_env.get_attr('current_step', indices=[-1]) == [12]
vec_env.close()
class StepEnv(gym.Env):
def __init__(self, max_steps):
"""Gym environment for testing that terminal observation is inserted
correctly."""
self.action_space = gym.spaces.Discrete(2)
self.observation_space = gym.spaces.Box(np.array([0]), np.array([999]),
dtype='int')
self.max_steps = max_steps
self.current_step = 0
def reset(self):
self.current_step = 0
return np.array([self.current_step], dtype='int')
def step(self, action):
prev_step = self.current_step
self.current_step += 1
done = self.current_step >= self.max_steps
return np.array([prev_step], dtype='int'), 0.0, done, {}
@pytest.mark.parametrize('vec_env_class', VEC_ENV_CLASSES)
@pytest.mark.parametrize('vec_env_wrapper', VEC_ENV_WRAPPERS)
def test_vecenv_terminal_obs(vec_env_class, vec_env_wrapper):
"""Test that 'terminal_observation' gets added to info dict upon
termination."""
step_nums = [i + 5 for i in range(N_ENVS)]
vec_env = vec_env_class([functools.partial(StepEnv, n) for n in step_nums])
if vec_env_wrapper is not None:
if vec_env_wrapper == VecFrameStack:
vec_env = vec_env_wrapper(vec_env, n_stack=2)
else:
vec_env = vec_env_wrapper(vec_env)
zero_acts = np.zeros((N_ENVS,), dtype='int')
prev_obs_b = vec_env.reset()
for step_num in range(1, max(step_nums) + 1):
obs_b, _, done_b, info_b = vec_env.step(zero_acts)
assert len(obs_b) == N_ENVS
assert len(done_b) == N_ENVS
assert len(info_b) == N_ENVS
env_iter = zip(prev_obs_b, obs_b, done_b, info_b, step_nums)
for prev_obs, obs, done, info, final_step_num in env_iter:
assert done == (step_num == final_step_num)
if not done:
assert 'terminal_observation' not in info
else:
terminal_obs = info['terminal_observation']
# do some rough ordering checks that should work for all
# wrappers, including VecNormalize
assert np.all(prev_obs < terminal_obs)
assert np.all(obs < prev_obs)
if not isinstance(vec_env, VecNormalize):
# more precise tests that we can't do with VecNormalize
# (which changes observation values)
assert np.all(prev_obs + 1 == terminal_obs)
assert np.all(obs == 0)
prev_obs_b = obs_b
vec_env.close()
SPACES = collections.OrderedDict([
('discrete', gym.spaces.Discrete(2)),
('multidiscrete', gym.spaces.MultiDiscrete([2, 3])),
('multibinary', gym.spaces.MultiBinary(3)),
('continuous', gym.spaces.Box(low=np.zeros(2), high=np.ones(2))),
])
def check_vecenv_spaces(vec_env_class, space, obs_assert):
"""Helper method to check observation spaces in vectorized environments."""
def make_env():
return CustomGymEnv(space)
vec_env = vec_env_class([make_env for _ in range(N_ENVS)])
obs = vec_env.reset()
obs_assert(obs)
dones = [False] * N_ENVS
while not any(dones):
actions = [vec_env.action_space.sample() for _ in range(N_ENVS)]
obs, _rews, dones, _infos = vec_env.step(actions)
obs_assert(obs)
vec_env.close()
def check_vecenv_obs(obs, space):
"""Helper method to check observations from multiple environments each belong to
the appropriate observation space."""
assert obs.shape[0] == N_ENVS
for value in obs:
assert space.contains(value)
@pytest.mark.parametrize('vec_env_class,space', itertools.product(VEC_ENV_CLASSES, SPACES.values()))
def test_vecenv_single_space(vec_env_class, space):
def obs_assert(obs):
return check_vecenv_obs(obs, space)
check_vecenv_spaces(vec_env_class, space, obs_assert)
class _UnorderedDictSpace(gym.spaces.Dict):
"""Like DictSpace, but returns an unordered dict when sampling."""
def sample(self):
return dict(super().sample())
@pytest.mark.parametrize('vec_env_class', VEC_ENV_CLASSES)
def test_vecenv_dict_spaces(vec_env_class):
"""Test dictionary observation spaces with vectorized environments."""
space = gym.spaces.Dict(SPACES)
def obs_assert(obs):
assert isinstance(obs, collections.OrderedDict)
assert obs.keys() == space.spaces.keys()
for key, values in obs.items():
check_vecenv_obs(values, space.spaces[key])
check_vecenv_spaces(vec_env_class, space, obs_assert)
unordered_space = _UnorderedDictSpace(SPACES)
# Check that vec_env_class can accept unordered dict observations (and convert to OrderedDict)
check_vecenv_spaces(vec_env_class, unordered_space, obs_assert)
@pytest.mark.parametrize('vec_env_class', VEC_ENV_CLASSES)
def test_vecenv_tuple_spaces(vec_env_class):
"""Test tuple observation spaces with vectorized environments."""
space = gym.spaces.Tuple(tuple(SPACES.values()))
def obs_assert(obs):
assert isinstance(obs, tuple)
assert len(obs) == len(space.spaces)
for values, inner_space in zip(obs, space.spaces):
check_vecenv_obs(values, inner_space)
return check_vecenv_spaces(vec_env_class, space, obs_assert)
def test_subproc_start_method():
start_methods = [None]
# Only test thread-safe methods. Others may deadlock tests! (gh/428)
safe_methods = {'forkserver', 'spawn'}
available_methods = multiprocessing.get_all_start_methods()
start_methods += list(safe_methods.intersection(available_methods))
space = gym.spaces.Discrete(2)
def obs_assert(obs):
return check_vecenv_obs(obs, space)
for start_method in start_methods:
vec_env_class = functools.partial(SubprocVecEnv, start_method=start_method)
check_vecenv_spaces(vec_env_class, space, obs_assert)
with pytest.raises(ValueError, match="cannot find context for 'illegal_method'"):
vec_env_class = functools.partial(SubprocVecEnv, start_method='illegal_method')
check_vecenv_spaces(vec_env_class, space, obs_assert)
class CustomWrapperA(VecNormalize):
def __init__(self, venv):
VecNormalize.__init__(self, venv)
self.var_a = 'a'
class CustomWrapperB(VecNormalize):
def __init__(self, venv):
VecNormalize.__init__(self, venv)
self.var_b = 'b'
def func_b(self):
return self.var_b
def name_test(self):
return self.__class__
class CustomWrapperBB(CustomWrapperB):
def __init__(self, venv):
CustomWrapperB.__init__(self, venv)
self.var_bb = 'bb'
def test_vecenv_wrapper_getattr():
def make_env():
return CustomGymEnv(gym.spaces.Box(low=np.zeros(2), high=np.ones(2)))
vec_env = DummyVecEnv([make_env for _ in range(N_ENVS)])
wrapped = CustomWrapperA(CustomWrapperBB(vec_env))
assert wrapped.var_a == 'a'
assert wrapped.var_b == 'b'
assert wrapped.var_bb == 'bb'
assert wrapped.func_b() == 'b'
assert wrapped.name_test() == CustomWrapperBB
double_wrapped = CustomWrapperA(CustomWrapperB(wrapped))
dummy = double_wrapped.var_a # should not raise as it is directly defined here
with pytest.raises(AttributeError): # should raise due to ambiguity
dummy = double_wrapped.var_b
with pytest.raises(AttributeError): # should raise as does not exist
dummy = double_wrapped.nonexistent_attribute
del dummy # keep linter happy

(new test file)

@@ -0,0 +1,41 @@
import gym
import numpy as np
from torchy_baselines.common.running_mean_std import RunningMeanStd
from torchy_baselines.common.vec_env.dummy_vec_env import DummyVecEnv
from torchy_baselines.common.vec_env.vec_normalize import VecNormalize
ENV_ID = 'Pendulum-v0'
def test_runningmeanstd():
"""Test RunningMeanStd object"""
for (x_1, x_2, x_3) in [
(np.random.randn(3), np.random.randn(4), np.random.randn(5)),
(np.random.randn(3, 2), np.random.randn(4, 2), np.random.randn(5, 2))]:
rms = RunningMeanStd(epsilon=0.0, shape=x_1.shape[1:])
x_cat = np.concatenate([x_1, x_2, x_3], axis=0)
moments_1 = [x_cat.mean(axis=0), x_cat.var(axis=0)]
rms.update(x_1)
rms.update(x_2)
rms.update(x_3)
moments_2 = [rms.mean, rms.var]
assert np.allclose(moments_1, moments_2)
def test_vec_env():
"""Test VecNormalize Object"""
def make_env():
return gym.make(ENV_ID)
env = DummyVecEnv([make_env])
env = VecNormalize(env, norm_obs=True, norm_reward=True, clip_obs=10., clip_reward=10.)
_, done = env.reset(), [False]
obs = None
while not done[0]:
actions = [env.action_space.sample()]
obs, _, done, _ = env.step(actions)
assert np.max(obs) <= 10

torchy_baselines/common/running_mean_std.py (new file)

@@ -0,0 +1,37 @@
import numpy as np
class RunningMeanStd(object):
def __init__(self, epsilon=1e-4, shape=()):
"""
Calculates the running mean and std of a data stream
https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
:param epsilon: (float) helps with arithmetic issues
:param shape: (tuple) the shape of the data stream's output
"""
self.mean = np.zeros(shape, 'float64')
self.var = np.ones(shape, 'float64')
self.count = epsilon
def update(self, arr):
batch_mean = np.mean(arr, axis=0)
batch_var = np.var(arr, axis=0)
batch_count = arr.shape[0]
self.update_from_moments(batch_mean, batch_var, batch_count)
def update_from_moments(self, batch_mean, batch_var, batch_count):
delta = batch_mean - self.mean
tot_count = self.count + batch_count
new_mean = self.mean + delta * batch_count / tot_count
m_a = self.var * self.count
m_b = batch_var * batch_count
m_2 = m_a + m_b + np.square(delta) * self.count * batch_count / tot_count
new_var = m_2 / tot_count
new_count = tot_count
self.mean = new_mean
self.var = new_var
self.count = new_count
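For reference, update_from_moments above is the parallel mean/variance combination from the Wikipedia article cited in the docstring: merging the running statistics (mean mu_A, variance sigma_A^2, weight n_A) with the batch statistics (mu_B, sigma_B^2, n_B) gives

\[
\delta = \mu_B - \mu_A, \qquad n = n_A + n_B, \qquad \mu' = \mu_A + \delta \, \frac{n_B}{n},
\]
\[
M_2 = n_A \sigma_A^2 + n_B \sigma_B^2 + \delta^2 \, \frac{n_A n_B}{n}, \qquad \sigma'^2 = \frac{M_2}{n}, \qquad n' = n.
\]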

torchy_baselines/common/vec_env/__init__.py (new file)

@@ -0,0 +1,7 @@
# flake8: noqa F401
from torchy_baselines.common.vec_env.base_vec_env import AlreadySteppingError, NotSteppingError, VecEnv, VecEnvWrapper, \
CloudpickleWrapper
from torchy_baselines.common.vec_env.dummy_vec_env import DummyVecEnv
from torchy_baselines.common.vec_env.subproc_vec_env import SubprocVecEnv
from torchy_baselines.common.vec_env.vec_frame_stack import VecFrameStack
from torchy_baselines.common.vec_env.vec_normalize import VecNormalize

torchy_baselines/common/vec_env/base_vec_env.py (new file)

@@ -0,0 +1,301 @@
from abc import ABCMeta, abstractmethod
import inspect
import pickle
import cloudpickle
class AlreadySteppingError(Exception):
"""
Raised when an asynchronous step is running while
step_async() is called again.
"""
def __init__(self):
msg = 'already running an async step'
Exception.__init__(self, msg)
class NotSteppingError(Exception):
"""
Raised when an asynchronous step is not running but
step_wait() is called.
"""
def __init__(self):
msg = 'not running an async step'
Exception.__init__(self, msg)
class VecEnv(object):
"""
An abstract asynchronous, vectorized environment.
:param num_envs: (int) the number of environments
:param observation_space: (Gym Space) the observation space
:param action_space: (Gym Space) the action space
"""
metadata = {
'render.modes': ['human', 'rgb_array']
}
__metaclass__ = ABCMeta
def __init__(self, num_envs, observation_space, action_space):
self.num_envs = num_envs
self.observation_space = observation_space
self.action_space = action_space
@abstractmethod
def reset(self):
"""
Reset all the environments and return an array of
observations, or a tuple of observation arrays.
If step_async is still doing work, that work will
be cancelled and step_wait() should not be called
until step_async() is invoked again.
:return: ([int] or [float]) observation
"""
pass
@abstractmethod
def step_async(self, actions):
"""
Tell all the environments to start taking a step
with the given actions.
Call step_wait() to get the results of the step.
You should not call this if a step_async run is
already pending.
"""
pass
@abstractmethod
def step_wait(self):
"""
Wait for the step taken with step_async().
:return: ([int] or [float], [float], [bool], dict) observation, reward, done, information
"""
pass
@abstractmethod
def close(self):
"""
Clean up the environment's resources.
"""
pass
@abstractmethod
def get_attr(self, attr_name, indices=None):
"""
Return attribute from vectorized environment.
:param attr_name: (str) The name of the attribute whose value to return
:param indices: (list,int) Indices of envs to get attribute from
:return: (list) List of values of 'attr_name' in all environments
"""
pass
@abstractmethod
def set_attr(self, attr_name, value, indices=None):
"""
Set attribute inside vectorized environments.
:param attr_name: (str) The name of attribute to assign new value
:param value: (obj) Value to assign to `attr_name`
:param indices: (list,int) Indices of envs to assign value
:return: (NoneType)
"""
pass
@abstractmethod
def env_method(self, method_name, *method_args, indices=None, **method_kwargs):
"""
Call instance methods of vectorized environments.
:param method_name: (str) The name of the environment method to invoke.
:param indices: (list,int) Indices of envs whose method to call
:param method_args: (tuple) Any positional arguments to provide in the call
:param method_kwargs: (dict) Any keyword arguments to provide in the call
:return: (list) List of items returned by the environment's method call
"""
pass
def step(self, actions):
"""
Step the environments with the given action
:param actions: ([int] or [float]) the action
:return: ([int] or [float], [float], [bool], dict) observation, reward, done, information
"""
self.step_async(actions)
return self.step_wait()
def get_images(self):
"""
Return RGB images from each environment
"""
raise NotImplementedError
def render(self, *args, **kwargs):
"""
Gym environment rendering
:param mode: (str) the rendering type
"""
raise NotImplementedError()
@property
def unwrapped(self):
if isinstance(self, VecEnvWrapper):
return self.venv.unwrapped
else:
return self
def getattr_depth_check(self, name, already_found):
"""Check if an attribute reference is being hidden in a recursive call to __getattr__
:param name: (str) name of attribute to check for
:param already_found: (bool) whether this attribute has already been found in a wrapper
:return: (str or None) name of module whose attribute is being shadowed, if any.
"""
if hasattr(self, name) and already_found:
return "{0}.{1}".format(type(self).__module__, type(self).__name__)
else:
return None
def _get_indices(self, indices):
"""
Convert a flexibly-typed reference to environment indices to an implied list of indices.
:param indices: (None,int,Iterable) refers to indices of envs.
:return: (list) the implied list of indices.
"""
if indices is None:
indices = range(self.num_envs)
elif isinstance(indices, int):
indices = [indices]
return indices
class VecEnvWrapper(VecEnv):
"""
Base class for wrappers of vectorized environments
:param venv: (VecEnv) the vectorized environment to wrap
:param observation_space: (Gym Space) the observation space (can be None to load from venv)
:param action_space: (Gym Space) the action space (can be None to load from venv)
"""
def __init__(self, venv, observation_space=None, action_space=None):
self.venv = venv
VecEnv.__init__(self, num_envs=venv.num_envs, observation_space=observation_space or venv.observation_space,
action_space=action_space or venv.action_space)
self.class_attributes = dict(inspect.getmembers(self.__class__))
def step_async(self, actions):
self.venv.step_async(actions)
@abstractmethod
def reset(self):
pass
@abstractmethod
def step_wait(self):
pass
def close(self):
return self.venv.close()
def render(self, *args, **kwargs):
return self.venv.render(*args, **kwargs)
def get_images(self):
return self.venv.get_images()
def get_attr(self, attr_name, indices=None):
return self.venv.get_attr(attr_name, indices)
def set_attr(self, attr_name, value, indices=None):
return self.venv.set_attr(attr_name, value, indices)
def env_method(self, method_name, *method_args, indices=None, **method_kwargs):
return self.venv.env_method(method_name, *method_args, indices=indices, **method_kwargs)
def __getattr__(self, name):
"""Find attribute from wrapped venv(s) if this wrapper does not have it.
Useful for accessing attributes from venvs which are wrapped with multiple wrappers
which have unique attributes of interest.
"""
blocked_class = self.getattr_depth_check(name, already_found=False)
if blocked_class is not None:
own_class = "{0}.{1}".format(type(self).__module__, type(self).__name__)
format_str = ("Error: Recursive attribute lookup for {0} from {1} is "
"ambiguous and hides attribute from {2}")
raise AttributeError(format_str.format(name, own_class, blocked_class))
return self.getattr_recursive(name)
def _get_all_attributes(self):
"""Get all (inherited) instance and class attributes
:return: (dict<str, object>) all_attributes
"""
all_attributes = self.__dict__.copy()
all_attributes.update(self.class_attributes)
return all_attributes
def getattr_recursive(self, name):
"""Recursively check wrappers to find attribute.
:param name: (str) name of attribute to look for
:return: (object) attribute
"""
all_attributes = self._get_all_attributes()
if name in all_attributes: # attribute is present in this wrapper
attr = getattr(self, name)
elif hasattr(self.venv, 'getattr_recursive'):
# Attribute not present, child is wrapper. Call getattr_recursive rather than getattr
# to avoid a duplicate call to getattr_depth_check.
attr = self.venv.getattr_recursive(name)
else: # attribute not present, child is an unwrapped VecEnv
attr = getattr(self.venv, name)
return attr
def getattr_depth_check(self, name, already_found):
"""See base class.
:return: (str or None) name of module whose attribute is being shadowed, if any.
"""
all_attributes = self._get_all_attributes()
if name in all_attributes and already_found:
# this venv's attribute is being hidden because of a higher venv.
shadowed_wrapper_class = "{0}.{1}".format(type(self).__module__, type(self).__name__)
elif name in all_attributes and not already_found:
# we have found the first reference to the attribute. Now check for duplicates.
shadowed_wrapper_class = self.venv.getattr_depth_check(name, True)
else:
# this wrapper does not have the attribute. Keep searching.
shadowed_wrapper_class = self.venv.getattr_depth_check(name, already_found)
return shadowed_wrapper_class
class CloudpickleWrapper(object):
def __init__(self, var):
"""
Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle)
:param var: (Any) the variable you wish to wrap for pickling with cloudpickle
"""
self.var = var
def __getstate__(self):
return cloudpickle.dumps(self.var)
def __setstate__(self, obs):
self.var = pickle.loads(obs)
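As an illustration of the wrapper API above, here is a minimal sketch of a custom wrapper (VecRewardScale and its scale argument are hypothetical names, not part of this commit). Only reset and step_wait need to be overridden; step_async and the recursive attribute forwarding are inherited from VecEnvWrapper.

import numpy as np
from torchy_baselines.common.vec_env import VecEnvWrapper

class VecRewardScale(VecEnvWrapper):
    """Hypothetical example wrapper: multiplies every reward by a constant factor."""
    def __init__(self, venv, scale=0.1):
        super(VecRewardScale, self).__init__(venv)
        self.scale = scale

    def reset(self):
        return self.venv.reset()

    def step_wait(self):
        # Forward the result of the wrapped vec env, rescaling only the rewards
        obs, rewards, dones, infos = self.venv.step_wait()
        return obs, self.scale * np.asarray(rewards), dones, infos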

torchy_baselines/common/vec_env/dummy_vec_env.py (new file)

@@ -0,0 +1,107 @@
from collections import OrderedDict
import numpy as np
from torchy_baselines.common.vec_env import VecEnv
from torchy_baselines.common.vec_env.util import copy_obs_dict, dict_to_obs, obs_space_info
class DummyVecEnv(VecEnv):
"""
Creates a simple vectorized wrapper for multiple environments, calling each environment in sequence on the current
Python process. This is useful for computationally simple environments such as ``CartPole-v1``, where the overhead of
multiprocessing or multithreading outweighs the environment computation time. It can also be used for RL methods that
require a vectorized environment, but where only a single environment is needed for training.
:param env_fns: ([Gym Environment]) the list of environments to vectorize
"""
def __init__(self, env_fns):
self.envs = [fn() for fn in env_fns]
env = self.envs[0]
VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space)
obs_space = env.observation_space
self.keys, shapes, dtypes = obs_space_info(obs_space)
self.buf_obs = OrderedDict([
(k, np.zeros((self.num_envs,) + tuple(shapes[k]), dtype=dtypes[k]))
for k in self.keys])
self.buf_dones = np.zeros((self.num_envs,), dtype=np.bool)
self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32)
self.buf_infos = [{} for _ in range(self.num_envs)]
self.actions = None
self.metadata = env.metadata
def step_async(self, actions):
self.actions = actions
def step_wait(self):
for env_idx in range(self.num_envs):
obs, self.buf_rews[env_idx], self.buf_dones[env_idx], self.buf_infos[env_idx] =\
self.envs[env_idx].step(self.actions[env_idx])
if self.buf_dones[env_idx]:
# save final observation where user can get it, then reset
self.buf_infos[env_idx]['terminal_observation'] = obs
obs = self.envs[env_idx].reset()
self._save_obs(env_idx, obs)
return (self._obs_from_buf(), np.copy(self.buf_rews), np.copy(self.buf_dones),
self.buf_infos.copy())
def reset(self):
for env_idx in range(self.num_envs):
obs = self.envs[env_idx].reset()
self._save_obs(env_idx, obs)
return self._obs_from_buf()
def seed(self, seed, indices=None):
"""
:param seed: (int or [int])
:param indices: ([int])
"""
indices = self._get_indices(indices)
if not hasattr(seed, '__len__'):
seed = [seed] * len(indices)
assert len(seed) == len(indices)
return [self.envs[env_idx].seed(env_seed) for env_idx, env_seed in zip(indices, seed)]
def close(self):
for env in self.envs:
env.close()
def get_images(self):
return [env.render(mode='rgb_array') for env in self.envs]
def render(self, *args, **kwargs):
if self.num_envs == 1:
return self.envs[0].render(*args, **kwargs)
else:
return super().render(*args, **kwargs)
def _save_obs(self, env_idx, obs):
for key in self.keys:
if key is None:
self.buf_obs[key][env_idx] = obs
else:
self.buf_obs[key][env_idx] = obs[key]
def _obs_from_buf(self):
return dict_to_obs(self.observation_space, copy_obs_dict(self.buf_obs))
def get_attr(self, attr_name, indices=None):
"""Return attribute from vectorized environment (see base class)."""
target_envs = self._get_target_envs(indices)
return [getattr(env_i, attr_name) for env_i in target_envs]
def set_attr(self, attr_name, value, indices=None):
"""Set attribute inside vectorized environments (see base class)."""
target_envs = self._get_target_envs(indices)
for env_i in target_envs:
setattr(env_i, attr_name, value)
def env_method(self, method_name, *method_args, indices=None, **method_kwargs):
"""Call instance methods of vectorized environments."""
target_envs = self._get_target_envs(indices)
return [getattr(env_i, method_name)(*method_args, **method_kwargs) for env_i in target_envs]
def _get_target_envs(self, indices):
indices = self._get_indices(indices)
return [self.envs[i] for i in indices]
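A minimal usage sketch for DummyVecEnv, assuming gym and its CartPole-v1 environment are available:

import gym
from torchy_baselines.common.vec_env import DummyVecEnv

# Four copies of CartPole-v1, stepped sequentially in the current process
vec_env = DummyVecEnv([lambda: gym.make('CartPole-v1') for _ in range(4)])
obs = vec_env.reset()  # shape (4, 4): one row of the 4-dimensional observation per environment
actions = [vec_env.action_space.sample() for _ in range(4)]
obs, rewards, dones, infos = vec_env.step(actions)
vec_env.close()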

torchy_baselines/common/vec_env/subproc_vec_env.py (new file)

@@ -0,0 +1,232 @@
import multiprocessing
from collections import OrderedDict
import gym
import numpy as np
from torchy_baselines.common.vec_env import VecEnv, CloudpickleWrapper
def _worker(remote, parent_remote, env_fn_wrapper):
parent_remote.close()
env = env_fn_wrapper.var()
while True:
try:
cmd, data = remote.recv()
if cmd == 'step':
observation, reward, done, info = env.step(data)
if done:
# save final observation where user can get it, then reset
info['terminal_observation'] = observation
observation = env.reset()
remote.send((observation, reward, done, info))
elif cmd == 'reset':
observation = env.reset()
remote.send(observation)
elif cmd == 'render':
remote.send(env.render(*data[0], **data[1]))
elif cmd == 'close':
remote.close()
break
elif cmd == 'get_spaces':
remote.send((env.observation_space, env.action_space))
elif cmd == 'env_method':
method = getattr(env, data[0])
remote.send(method(*data[1], **data[2]))
elif cmd == 'get_attr':
remote.send(getattr(env, data))
elif cmd == 'set_attr':
remote.send(setattr(env, data[0], data[1]))
else:
raise NotImplementedError
except EOFError:
break
def tile_images(img_nhwc):
"""
Tile N images into one big PxQ image.
(P, Q) are chosen to be as close to each other as possible, and if N
is a perfect square, then P = Q.
:param img_nhwc: (list) list or array of images, ndim=4 once turned into array. img nhwc
n = batch index, h = height, w = width, c = channel
:return: (numpy float) img_HWc, ndim=3
"""
img_nhwc = np.asarray(img_nhwc)
n_images, height, width, n_channels = img_nhwc.shape
# new_height was named H before
new_height = int(np.ceil(np.sqrt(n_images)))
# new_width was named W before
new_width = int(np.ceil(float(n_images) / new_height))
img_nhwc = np.array(list(img_nhwc) + [img_nhwc[0] * 0 for _ in range(n_images, new_height * new_width)])
# img_HWhwc
out_image = img_nhwc.reshape(new_height, new_width, height, width, n_channels)
# img_HhWwc
out_image = out_image.transpose(0, 2, 1, 3, 4)
# img_Hh_Ww_c
out_image = out_image.reshape(new_height * height, new_width * width, n_channels)
return out_image
class SubprocVecEnv(VecEnv):
"""
Creates a multiprocess vectorized wrapper for multiple environments, distributing each environment to its own
process, allowing significant speed up when the environment is computationally complex.
For performance reasons, if your environment is not IO bound, the number of environments should not exceed the
number of logical cores on your CPU.
.. warning::
Only 'forkserver' and 'spawn' start methods are thread-safe,
which is important when TensorFlow sessions or other non thread-safe
libraries are used in the parent (see issue #217). However, compared to
'fork' they incur a small start-up cost and have restrictions on
global variables. With those methods, users must wrap the code in an
``if __name__ == "__main__":`` block.
For more information, see the multiprocessing documentation.
:param env_fns: ([Gym Environment]) Environments to run in subprocesses
:param start_method: (str) method used to start the subprocesses.
Must be one of the methods returned by multiprocessing.get_all_start_methods().
Defaults to 'forkserver' on available platforms, and 'spawn' otherwise.
"""
def __init__(self, env_fns, start_method=None):
self.waiting = False
self.closed = False
n_envs = len(env_fns)
if start_method is None:
# Fork is not a thread-safe method (see issue #217)
# but is more user-friendly (it does not require wrapping the code
# in an `if __name__ == "__main__":` guard)
forkserver_available = 'forkserver' in multiprocessing.get_all_start_methods()
start_method = 'forkserver' if forkserver_available else 'spawn'
ctx = multiprocessing.get_context(start_method)
self.remotes, self.work_remotes = zip(*[ctx.Pipe() for _ in range(n_envs)])
self.processes = []
for work_remote, remote, env_fn in zip(self.work_remotes, self.remotes, env_fns):
args = (work_remote, remote, CloudpickleWrapper(env_fn))
# daemon=True: if the main process crashes, we should not cause things to hang
process = ctx.Process(target=_worker, args=args, daemon=True)
process.start()
self.processes.append(process)
work_remote.close()
self.remotes[0].send(('get_spaces', None))
observation_space, action_space = self.remotes[0].recv()
VecEnv.__init__(self, len(env_fns), observation_space, action_space)
def step_async(self, actions):
for remote, action in zip(self.remotes, actions):
remote.send(('step', action))
self.waiting = True
def step_wait(self):
results = [remote.recv() for remote in self.remotes]
self.waiting = False
obs, rews, dones, infos = zip(*results)
return _flatten_obs(obs, self.observation_space), np.stack(rews), np.stack(dones), infos
def reset(self):
for remote in self.remotes:
remote.send(('reset', None))
obs = [remote.recv() for remote in self.remotes]
return _flatten_obs(obs, self.observation_space)
def close(self):
if self.closed:
return
if self.waiting:
for remote in self.remotes:
remote.recv()
for remote in self.remotes:
remote.send(('close', None))
for process in self.processes:
process.join()
self.closed = True
def render(self, mode='human', *args, **kwargs):
for pipe in self.remotes:
# gather images from subprocesses
# `mode` will be taken into account later
pipe.send(('render', (args, {'mode': 'rgb_array', **kwargs})))
imgs = [pipe.recv() for pipe in self.remotes]
# Create a big image by tiling images from subprocesses
bigimg = tile_images(imgs)
if mode == 'human':
import cv2
cv2.imshow('vecenv', bigimg[:, :, ::-1])
cv2.waitKey(1)
elif mode == 'rgb_array':
return bigimg
else:
raise NotImplementedError
def get_images(self):
for pipe in self.remotes:
# send (args, kwargs) as expected by the worker's 'render' command
pipe.send(('render', ((), {'mode': 'rgb_array'})))
imgs = [pipe.recv() for pipe in self.remotes]
return imgs
def get_attr(self, attr_name, indices=None):
"""Return attribute from vectorized environment (see base class)."""
target_remotes = self._get_target_remotes(indices)
for remote in target_remotes:
remote.send(('get_attr', attr_name))
return [remote.recv() for remote in target_remotes]
def set_attr(self, attr_name, value, indices=None):
"""Set attribute inside vectorized environments (see base class)."""
target_remotes = self._get_target_remotes(indices)
for remote in target_remotes:
remote.send(('set_attr', (attr_name, value)))
for remote in target_remotes:
remote.recv()
def env_method(self, method_name, *method_args, indices=None, **method_kwargs):
"""Call instance methods of vectorized environments."""
target_remotes = self._get_target_remotes(indices)
for remote in target_remotes:
remote.send(('env_method', (method_name, method_args, method_kwargs)))
return [remote.recv() for remote in target_remotes]
def _get_target_remotes(self, indices):
"""
Get the connection objects needed to communicate with the wanted
envs that are in subprocesses.
:param indices: (None,int,Iterable) refers to indices of envs.
:return: ([multiprocessing.Connection]) Connection object to communicate between processes.
"""
indices = self._get_indices(indices)
return [self.remotes[i] for i in indices]
def _flatten_obs(obs, space):
"""
Flatten observations, depending on the observation space.
:param obs: (list<X> or tuple<X> where X is dict<ndarray>, tuple<ndarray> or ndarray) observations.
A list or tuple of observations, one per environment.
Each environment observation may be a NumPy array, or a dict or tuple of NumPy arrays.
:return (OrderedDict<ndarray>, tuple<ndarray> or ndarray) flattened observations.
A flattened NumPy array or an OrderedDict or tuple of flattened numpy arrays.
Each NumPy array has the environment index as its first axis.
"""
assert isinstance(obs, (list, tuple)), "expected list or tuple of observations per environment"
assert len(obs) > 0, "need observations from at least one environment"
if isinstance(space, gym.spaces.Dict):
assert isinstance(space.spaces, OrderedDict), "Dict space must have ordered subspaces"
assert isinstance(obs[0], dict), "non-dict observation for environment with Dict observation space"
return OrderedDict([(k, np.stack([o[k] for o in obs])) for k in space.spaces.keys()])
elif isinstance(space, gym.spaces.Tuple):
assert isinstance(obs[0], tuple), "non-tuple observation for environment with Tuple observation space"
obs_len = len(space.spaces)
return tuple((np.stack([o[i] for o in obs]) for i in range(obs_len)))
else:
return np.stack(obs)
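A minimal usage sketch for SubprocVecEnv; as the class docstring notes, the 'spawn' and 'forkserver' start methods require the entry point to be guarded by if __name__ == '__main__':. Pendulum-v0 is just an example environment.

import gym
from torchy_baselines.common.vec_env import SubprocVecEnv

def make_env():
    return gym.make('Pendulum-v0')

if __name__ == '__main__':
    # 'spawn' re-imports this module in each worker process,
    # hence the __main__ guard around the construction code
    vec_env = SubprocVecEnv([make_env for _ in range(4)], start_method='spawn')
    obs = vec_env.reset()
    actions = [vec_env.action_space.sample() for _ in range(4)]
    obs, rewards, dones, infos = vec_env.step(actions)
    vec_env.close()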

torchy_baselines/common/vec_env/util.py (new file)

@@ -0,0 +1,73 @@
"""
Helpers for dealing with vectorized environments.
"""
from collections import OrderedDict
import gym
import numpy as np
def copy_obs_dict(obs):
"""
Deep-copy a dict of numpy arrays.
:param obs: (OrderedDict<ndarray>): a dict of numpy arrays.
:return (OrderedDict<ndarray>) a dict of copied numpy arrays.
"""
assert isinstance(obs, OrderedDict), "unexpected type for observations '{}'".format(type(obs))
return OrderedDict([(k, np.copy(v)) for k, v in obs.items()])
def dict_to_obs(space, obs_dict):
"""
Convert an internal dict representation of observations into the type
specified by the observation space.
:param space: (gym.spaces.Space) an observation space.
:param obs_dict: (OrderedDict<ndarray>) a dict of numpy arrays.
:return (ndarray, tuple<ndarray> or dict<ndarray>): returns an observation
of the same type as space. If space is Dict, the function is the identity;
if space is Tuple, the dict is converted to a tuple; otherwise, space is
unstructured and the value stored under the None key is returned.
"""
if isinstance(space, gym.spaces.Dict):
return obs_dict
elif isinstance(space, gym.spaces.Tuple):
assert len(obs_dict) == len(space.spaces), "size of observation does not match size of observation space"
return tuple((obs_dict[i] for i in range(len(space.spaces))))
else:
assert set(obs_dict.keys()) == {None}, "multiple observation keys for unstructured observation space"
return obs_dict[None]
def obs_space_info(obs_space):
"""
Get dict-structured information about a gym.Space.
Dict spaces are represented directly by their dict of subspaces.
Tuple spaces are converted into a dict with keys indexing into the tuple.
Unstructured spaces are represented by {None: obs_space}.
:param obs_space: (gym.spaces.Space) an observation space
:return (tuple) A tuple (keys, shapes, dtypes):
keys: a list of dict keys.
shapes: a dict mapping keys to shapes.
dtypes: a dict mapping keys to dtypes.
"""
if isinstance(obs_space, gym.spaces.Dict):
assert isinstance(obs_space.spaces, OrderedDict), "Dict space must have ordered subspaces"
subspaces = obs_space.spaces
elif isinstance(obs_space, gym.spaces.Tuple):
subspaces = {i: space for i, space in enumerate(obs_space.spaces)}
else:
assert not hasattr(obs_space, 'spaces'), "Unsupported structured space '{}'".format(type(obs_space))
subspaces = {None: obs_space}
keys = []
shapes = {}
dtypes = {}
for key, box in subspaces.items():
keys.append(key)
shapes[key] = box.shape
dtypes[key] = box.dtype
return keys, shapes, dtypes
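A small illustration of obs_space_info, assuming a Dict space with hypothetical 'position' and 'goal' keys:

from collections import OrderedDict
import gym
import numpy as np
from torchy_baselines.common.vec_env.util import obs_space_info

space = gym.spaces.Dict(OrderedDict([
    ('position', gym.spaces.Box(low=-1.0, high=1.0, shape=(3,), dtype=np.float32)),
    ('goal', gym.spaces.Discrete(5)),
]))
keys, shapes, dtypes = obs_space_info(space)
# keys   == ['position', 'goal']
# shapes == {'position': (3,), 'goal': ()}
# dtypes == {'position': float32, 'goal': int64}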

torchy_baselines/common/vec_env/vec_frame_stack.py (new file)

@@ -0,0 +1,55 @@
import warnings
import numpy as np
from gym import spaces
from torchy_baselines.common.vec_env import VecEnvWrapper
class VecFrameStack(VecEnvWrapper):
"""
Frame stacking wrapper for vectorized environment
:param venv: (VecEnv) the vectorized environment to wrap
:param n_stack: (int) Number of frames to stack
"""
def __init__(self, venv, n_stack):
self.venv = venv
self.n_stack = n_stack
wrapped_obs_space = venv.observation_space
low = np.repeat(wrapped_obs_space.low, self.n_stack, axis=-1)
high = np.repeat(wrapped_obs_space.high, self.n_stack, axis=-1)
self.stackedobs = np.zeros((venv.num_envs,) + low.shape, low.dtype)
observation_space = spaces.Box(low=low, high=high, dtype=venv.observation_space.dtype)
VecEnvWrapper.__init__(self, venv, observation_space=observation_space)
def step_wait(self):
observations, rewards, dones, infos = self.venv.step_wait()
last_ax_size = observations.shape[-1]
self.stackedobs = np.roll(self.stackedobs, shift=-last_ax_size, axis=-1)
for i, done in enumerate(dones):
if done:
if 'terminal_observation' in infos[i]:
old_terminal = infos[i]['terminal_observation']
new_terminal = np.concatenate(
(self.stackedobs[i, ..., :-last_ax_size], old_terminal), axis=-1)
infos[i]['terminal_observation'] = new_terminal
else:
warnings.warn(
"VecFrameStack wrapping a VecEnv without terminal_observation info")
self.stackedobs[i] = 0
self.stackedobs[..., -observations.shape[-1]:] = observations
return self.stackedobs, rewards, dones, infos
def reset(self):
"""
Reset all environments
"""
obs = self.venv.reset()
self.stackedobs[...] = 0
self.stackedobs[..., -obs.shape[-1]:] = obs
return self.stackedobs
def close(self):
self.venv.close()
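A minimal usage sketch for VecFrameStack, assuming Pendulum-v0 (3-dimensional observations); frames are stacked along the last axis, which is the channel axis for image observations.

import gym
from torchy_baselines.common.vec_env import DummyVecEnv, VecFrameStack

vec_env = DummyVecEnv([lambda: gym.make('Pendulum-v0') for _ in range(2)])
vec_env = VecFrameStack(vec_env, n_stack=4)
obs = vec_env.reset()  # shape (2, 12): 4 stacked copies of the 3-dimensional observation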

torchy_baselines/common/vec_env/vec_normalize.py (new file)

@@ -0,0 +1,105 @@
import pickle
import numpy as np
from torchy_baselines.common.vec_env import VecEnvWrapper
from torchy_baselines.common.running_mean_std import RunningMeanStd
class VecNormalize(VecEnvWrapper):
"""
A moving-average, normalizing wrapper for vectorized environments.
It also supports saving and loading the moving averages.
:param venv: (VecEnv) the vectorized environment to wrap
:param training: (bool) Whether or not to update the moving averages
:param norm_obs: (bool) Whether to normalize observation or not (default: True)
:param norm_reward: (bool) Whether to normalize rewards or not (default: True)
:param clip_obs: (float) Max absolute value for observation
:param clip_reward: (float) Max absolute value for the discounted reward
:param gamma: (float) discount factor
:param epsilon: (float) To avoid division by zero
"""
def __init__(self, venv, training=True, norm_obs=True, norm_reward=True,
clip_obs=10., clip_reward=10., gamma=0.99, epsilon=1e-8):
VecEnvWrapper.__init__(self, venv)
self.obs_rms = RunningMeanStd(shape=self.observation_space.shape)
self.ret_rms = RunningMeanStd(shape=())
self.clip_obs = clip_obs
self.clip_reward = clip_reward
# Returns: discounted rewards
self.ret = np.zeros(self.num_envs)
self.gamma = gamma
self.epsilon = epsilon
self.training = training
self.norm_obs = norm_obs
self.norm_reward = norm_reward
self.old_obs = np.array([])
def step_wait(self):
"""
Apply sequence of actions to sequence of environments
actions -> (observations, rewards, news)
where 'news' is a boolean vector indicating whether each environment's episode just ended (the environment is then reset automatically).
"""
obs, rews, news, infos = self.venv.step_wait()
self.ret = self.ret * self.gamma + rews
self.old_obs = obs
obs = self._normalize_observation(obs)
if self.norm_reward:
if self.training:
self.ret_rms.update(self.ret)
rews = np.clip(rews / np.sqrt(self.ret_rms.var + self.epsilon), -self.clip_reward, self.clip_reward)
self.ret[news] = 0
return obs, rews, news, infos
def _normalize_observation(self, obs):
"""
:param obs: (numpy tensor)
"""
if self.norm_obs:
if self.training:
self.obs_rms.update(obs)
obs = np.clip((obs - self.obs_rms.mean) / np.sqrt(self.obs_rms.var + self.epsilon), -self.clip_obs,
self.clip_obs)
return obs
else:
return obs
def get_original_obs(self):
"""
returns the unnormalized observation
:return: (numpy float)
"""
return self.old_obs
def reset(self):
"""
Reset all environments
"""
obs = self.venv.reset()
if len(np.array(obs).shape) == 1: # for when num_cpu is 1
self.old_obs = [obs]
else:
self.old_obs = obs
self.ret = np.zeros(self.num_envs)
return self._normalize_observation(obs)
def save_running_average(self, path):
"""
:param path: (str) path to log dir
"""
for rms, name in zip([self.obs_rms, self.ret_rms], ['obs_rms', 'ret_rms']):
with open("{}/{}.pkl".format(path, name), 'wb') as file_handler:
pickle.dump(rms, file_handler)
def load_running_average(self, path):
"""
:param path: (str) path to log dir
"""
for name in ['obs_rms', 'ret_rms']:
with open("{}/{}.pkl".format(path, name), 'rb') as file_handler:
setattr(self, name, pickle.load(file_handler))
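A minimal usage sketch for VecNormalize, including saving and reloading the running averages; the save directory is an arbitrary example path.

import os
import gym
from torchy_baselines.common.vec_env import DummyVecEnv, VecNormalize

vec_env = VecNormalize(DummyVecEnv([lambda: gym.make('Pendulum-v0')]),
                       norm_obs=True, norm_reward=True, clip_obs=10., clip_reward=10.)
obs = vec_env.reset()
for _ in range(100):
    obs, rewards, dones, infos = vec_env.step([vec_env.action_space.sample()])

# Persist obs_rms / ret_rms, then reuse them in an evaluation env
# that no longer updates the statistics (training=False)
save_dir = '/tmp/vec_normalize_stats'
os.makedirs(save_dir, exist_ok=True)
vec_env.save_running_average(save_dir)
eval_env = VecNormalize(DummyVecEnv([lambda: gym.make('Pendulum-v0')]), training=False)
eval_env.load_running_average(save_dir)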