diff --git a/tests/test_vec_envs.py b/tests/test_vec_envs.py new file mode 100644 index 0000000..2147e78 --- /dev/null +++ b/tests/test_vec_envs.py @@ -0,0 +1,326 @@ +import collections +import functools +import itertools +import multiprocessing + +import pytest +import gym +import numpy as np + +from torchy_baselines.common.vec_env import DummyVecEnv, SubprocVecEnv, VecNormalize, VecFrameStack + +N_ENVS = 3 +VEC_ENV_CLASSES = [DummyVecEnv, SubprocVecEnv] +VEC_ENV_WRAPPERS = [None, VecNormalize, VecFrameStack] + + +class CustomGymEnv(gym.Env): + def __init__(self, space): + """ + Custom gym environment for testing purposes + """ + self.action_space = space + self.observation_space = space + self.current_step = 0 + self.ep_length = 4 + + def reset(self): + self.current_step = 0 + self._choose_next_state() + return self.state + + def step(self, action): + reward = 1 + self._choose_next_state() + self.current_step += 1 + done = self.current_step >= self.ep_length + return self.state, reward, done, {} + + def _choose_next_state(self): + self.state = self.observation_space.sample() + + def render(self, mode='human'): + pass + + @staticmethod + def custom_method(dim_0=1, dim_1=1): + """ + Dummy method to test call to custom method + from VecEnv + + :param dim_0: (int) + :param dim_1: (int) + :return: (np.ndarray) + """ + return np.ones((dim_0, dim_1)) + + +@pytest.mark.parametrize('vec_env_class', VEC_ENV_CLASSES) +@pytest.mark.parametrize('vec_env_wrapper', VEC_ENV_WRAPPERS) +def test_vecenv_custom_calls(vec_env_class, vec_env_wrapper): + """Test access to methods/attributes of vectorized environments""" + def make_env(): + return CustomGymEnv(gym.spaces.Box(low=np.zeros(2), high=np.ones(2))) + vec_env = vec_env_class([make_env for _ in range(N_ENVS)]) + + if vec_env_wrapper is not None: + if vec_env_wrapper == VecFrameStack: + vec_env = vec_env_wrapper(vec_env, n_stack=2) + else: + vec_env = vec_env_wrapper(vec_env) + + env_method_results = vec_env.env_method('custom_method', 1, indices=None, dim_1=2) + setattr_results = [] + # Set current_step to an arbitrary value + for env_idx in range(N_ENVS): + setattr_results.append(vec_env.set_attr('current_step', env_idx, indices=env_idx)) + # Retrieve the value for each environment + getattr_results = vec_env.get_attr('current_step') + + assert len(env_method_results) == N_ENVS + assert len(setattr_results) == N_ENVS + assert len(getattr_results) == N_ENVS + + for env_idx in range(N_ENVS): + assert (env_method_results[env_idx] == np.ones((1, 2))).all() + assert setattr_results[env_idx] is None + assert getattr_results[env_idx] == env_idx + + # Call env_method on a subset of the VecEnv + env_method_subset = vec_env.env_method('custom_method', 1, indices=[0, 2], dim_1=3) + assert (env_method_subset[0] == np.ones((1, 3))).all() + assert (env_method_subset[1] == np.ones((1, 3))).all() + assert len(env_method_subset) == 2 + + + # Test to change value for all the environments + setattr_result = vec_env.set_attr('current_step', 42, indices=None) + getattr_result = vec_env.get_attr('current_step') + assert setattr_result is None + assert getattr_result == [42 for _ in range(N_ENVS)] + + # Additional tests for setattr that does not affect all the environments + vec_env.reset() + setattr_result = vec_env.set_attr('current_step', 12, indices=[0, 1]) + getattr_result = vec_env.get_attr('current_step') + getattr_result_subset = vec_env.get_attr('current_step', indices=[0, 1]) + assert setattr_result is None + assert getattr_result == [12 for _ in range(2)] + [0 
for _ in range(N_ENVS - 2)] + assert getattr_result_subset == [12, 12] + assert vec_env.get_attr('current_step', indices=[0, 2]) == [12, 0] + + vec_env.reset() + # Change value only for first and last environment + setattr_result = vec_env.set_attr('current_step', 12, indices=[0, -1]) + getattr_result = vec_env.get_attr('current_step') + assert setattr_result is None + assert getattr_result == [12] + [0 for _ in range(N_ENVS - 2)] + [12] + assert vec_env.get_attr('current_step', indices=[-1]) == [12] + + vec_env.close() + + +class StepEnv(gym.Env): + def __init__(self, max_steps): + """Gym environment for testing that terminal observation is inserted + correctly.""" + self.action_space = gym.spaces.Discrete(2) + self.observation_space = gym.spaces.Box(np.array([0]), np.array([999]), + dtype='int') + self.max_steps = max_steps + self.current_step = 0 + + def reset(self): + self.current_step = 0 + return np.array([self.current_step], dtype='int') + + def step(self, action): + prev_step = self.current_step + self.current_step += 1 + done = self.current_step >= self.max_steps + return np.array([prev_step], dtype='int'), 0.0, done, {} + + +@pytest.mark.parametrize('vec_env_class', VEC_ENV_CLASSES) +@pytest.mark.parametrize('vec_env_wrapper', VEC_ENV_WRAPPERS) +def test_vecenv_terminal_obs(vec_env_class, vec_env_wrapper): + """Test that 'terminal_observation' gets added to info dict upon + termination.""" + step_nums = [i + 5 for i in range(N_ENVS)] + vec_env = vec_env_class([functools.partial(StepEnv, n) for n in step_nums]) + + if vec_env_wrapper is not None: + if vec_env_wrapper == VecFrameStack: + vec_env = vec_env_wrapper(vec_env, n_stack=2) + else: + vec_env = vec_env_wrapper(vec_env) + + zero_acts = np.zeros((N_ENVS,), dtype='int') + prev_obs_b = vec_env.reset() + for step_num in range(1, max(step_nums) + 1): + obs_b, _, done_b, info_b = vec_env.step(zero_acts) + assert len(obs_b) == N_ENVS + assert len(done_b) == N_ENVS + assert len(info_b) == N_ENVS + env_iter = zip(prev_obs_b, obs_b, done_b, info_b, step_nums) + for prev_obs, obs, done, info, final_step_num in env_iter: + assert done == (step_num == final_step_num) + if not done: + assert 'terminal_observation' not in info + else: + terminal_obs = info['terminal_observation'] + + # do some rough ordering checks that should work for all + # wrappers, including VecNormalize + assert np.all(prev_obs < terminal_obs) + assert np.all(obs < prev_obs) + + if not isinstance(vec_env, VecNormalize): + # more precise tests that we can't do with VecNormalize + # (which changes observation values) + assert np.all(prev_obs + 1 == terminal_obs) + assert np.all(obs == 0) + + prev_obs_b = obs_b + + vec_env.close() + + +SPACES = collections.OrderedDict([ + ('discrete', gym.spaces.Discrete(2)), + ('multidiscrete', gym.spaces.MultiDiscrete([2, 3])), + ('multibinary', gym.spaces.MultiBinary(3)), + ('continuous', gym.spaces.Box(low=np.zeros(2), high=np.ones(2))), +]) + +def check_vecenv_spaces(vec_env_class, space, obs_assert): + """Helper method to check observation spaces in vectorized environments.""" + def make_env(): + return CustomGymEnv(space) + + vec_env = vec_env_class([make_env for _ in range(N_ENVS)]) + obs = vec_env.reset() + obs_assert(obs) + + dones = [False] * N_ENVS + while not any(dones): + actions = [vec_env.action_space.sample() for _ in range(N_ENVS)] + obs, _rews, dones, _infos = vec_env.step(actions) + obs_assert(obs) + vec_env.close() + + +def check_vecenv_obs(obs, space): + """Helper method to check observations from multiple 
environments each belong to + the appropriate observation space.""" + assert obs.shape[0] == N_ENVS + for value in obs: + assert space.contains(value) + + +@pytest.mark.parametrize('vec_env_class,space', itertools.product(VEC_ENV_CLASSES, SPACES.values())) +def test_vecenv_single_space(vec_env_class, space): + def obs_assert(obs): + return check_vecenv_obs(obs, space) + + check_vecenv_spaces(vec_env_class, space, obs_assert) + + +class _UnorderedDictSpace(gym.spaces.Dict): + """Like DictSpace, but returns an unordered dict when sampling.""" + def sample(self): + return dict(super().sample()) + + +@pytest.mark.parametrize('vec_env_class', VEC_ENV_CLASSES) +def test_vecenv_dict_spaces(vec_env_class): + """Test dictionary observation spaces with vectorized environments.""" + space = gym.spaces.Dict(SPACES) + + def obs_assert(obs): + assert isinstance(obs, collections.OrderedDict) + assert obs.keys() == space.spaces.keys() + for key, values in obs.items(): + check_vecenv_obs(values, space.spaces[key]) + + check_vecenv_spaces(vec_env_class, space, obs_assert) + + unordered_space = _UnorderedDictSpace(SPACES) + # Check that vec_env_class can accept unordered dict observations (and convert to OrderedDict) + check_vecenv_spaces(vec_env_class, unordered_space, obs_assert) + + +@pytest.mark.parametrize('vec_env_class', VEC_ENV_CLASSES) +def test_vecenv_tuple_spaces(vec_env_class): + """Test tuple observation spaces with vectorized environments.""" + space = gym.spaces.Tuple(tuple(SPACES.values())) + + def obs_assert(obs): + assert isinstance(obs, tuple) + assert len(obs) == len(space.spaces) + for values, inner_space in zip(obs, space.spaces): + check_vecenv_obs(values, inner_space) + + return check_vecenv_spaces(vec_env_class, space, obs_assert) + + +def test_subproc_start_method(): + start_methods = [None] + # Only test thread-safe methods. Others may deadlock tests! 
(gh/428) + safe_methods = {'forkserver', 'spawn'} + available_methods = multiprocessing.get_all_start_methods() + start_methods += list(safe_methods.intersection(available_methods)) + space = gym.spaces.Discrete(2) + + def obs_assert(obs): + return check_vecenv_obs(obs, space) + + for start_method in start_methods: + vec_env_class = functools.partial(SubprocVecEnv, start_method=start_method) + check_vecenv_spaces(vec_env_class, space, obs_assert) + + with pytest.raises(ValueError, match="cannot find context for 'illegal_method'"): + vec_env_class = functools.partial(SubprocVecEnv, start_method='illegal_method') + check_vecenv_spaces(vec_env_class, space, obs_assert) + + +class CustomWrapperA(VecNormalize): + def __init__(self, venv): + VecNormalize.__init__(self, venv) + self.var_a = 'a' + + +class CustomWrapperB(VecNormalize): + def __init__(self, venv): + VecNormalize.__init__(self, venv) + self.var_b = 'b' + + def func_b(self): + return self.var_b + + def name_test(self): + return self.__class__ + +class CustomWrapperBB(CustomWrapperB): + def __init__(self, venv): + CustomWrapperB.__init__(self, venv) + self.var_bb = 'bb' + +def test_vecenv_wrapper_getattr(): + def make_env(): + return CustomGymEnv(gym.spaces.Box(low=np.zeros(2), high=np.ones(2))) + vec_env = DummyVecEnv([make_env for _ in range(N_ENVS)]) + wrapped = CustomWrapperA(CustomWrapperBB(vec_env)) + assert wrapped.var_a == 'a' + assert wrapped.var_b == 'b' + assert wrapped.var_bb == 'bb' + assert wrapped.func_b() == 'b' + assert wrapped.name_test() == CustomWrapperBB + + double_wrapped = CustomWrapperA(CustomWrapperB(wrapped)) + dummy = double_wrapped.var_a # should not raise as it is directly defined here + with pytest.raises(AttributeError): # should raise due to ambiguity + dummy = double_wrapped.var_b + with pytest.raises(AttributeError): # should raise as does not exist + dummy = double_wrapped.nonexistent_attribute + del dummy # keep linter happy diff --git a/tests/test_vec_normalize.py b/tests/test_vec_normalize.py new file mode 100644 index 0000000..fbcee54 --- /dev/null +++ b/tests/test_vec_normalize.py @@ -0,0 +1,41 @@ +import gym +import numpy as np + +from torchy_baselines.common.running_mean_std import RunningMeanStd +from torchy_baselines.common.vec_env.dummy_vec_env import DummyVecEnv +from torchy_baselines.common.vec_env.vec_normalize import VecNormalize + +ENV_ID = 'Pendulum-v0' + + +def test_runningmeanstd(): + """Test RunningMeanStd object""" + for (x_1, x_2, x_3) in [ + (np.random.randn(3), np.random.randn(4), np.random.randn(5)), + (np.random.randn(3, 2), np.random.randn(4, 2), np.random.randn(5, 2))]: + rms = RunningMeanStd(epsilon=0.0, shape=x_1.shape[1:]) + + x_cat = np.concatenate([x_1, x_2, x_3], axis=0) + moments_1 = [x_cat.mean(axis=0), x_cat.var(axis=0)] + rms.update(x_1) + rms.update(x_2) + rms.update(x_3) + moments_2 = [rms.mean, rms.var] + + assert np.allclose(moments_1, moments_2) + + +def test_vec_env(): + """Test VecNormalize Object""" + + def make_env(): + return gym.make(ENV_ID) + + env = DummyVecEnv([make_env]) + env = VecNormalize(env, norm_obs=True, norm_reward=True, clip_obs=10., clip_reward=10.) 
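+    # VecNormalize (defined in torchy_baselines/common/vec_env/vec_normalize.py)
+    # keeps running mean/std estimates of observations and discounted returns,
+    # and clips the normalized values to [-clip_obs, clip_obs] and
+    # [-clip_reward, clip_reward]; the assertion below relies on that clipping.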
+    _, done = env.reset(), [False]
+    obs = None
+    while not done[0]:
+        actions = [env.action_space.sample()]
+        obs, _, done, _ = env.step(actions)
+    assert np.max(obs) <= 10
diff --git a/torchy_baselines/common/running_mean_std.py b/torchy_baselines/common/running_mean_std.py
new file mode 100644
index 0000000..d6a03d6
--- /dev/null
+++ b/torchy_baselines/common/running_mean_std.py
@@ -0,0 +1,37 @@
+import numpy as np
+
+
+class RunningMeanStd(object):
+    def __init__(self, epsilon=1e-4, shape=()):
+        """
+        Calculates the running mean and std of a data stream
+        https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
+
+        :param epsilon: (float) helps with arithmetic issues
+        :param shape: (tuple) the shape of the data stream's output
+        """
+        self.mean = np.zeros(shape, 'float64')
+        self.var = np.ones(shape, 'float64')
+        self.count = epsilon
+
+    def update(self, arr):
+        batch_mean = np.mean(arr, axis=0)
+        batch_var = np.var(arr, axis=0)
+        batch_count = arr.shape[0]
+        self.update_from_moments(batch_mean, batch_var, batch_count)
+
+    def update_from_moments(self, batch_mean, batch_var, batch_count):
+        delta = batch_mean - self.mean
+        tot_count = self.count + batch_count
+
+        new_mean = self.mean + delta * batch_count / tot_count
+        m_a = self.var * self.count
+        m_b = batch_var * batch_count
+        m_2 = m_a + m_b + np.square(delta) * self.count * batch_count / tot_count
+        new_var = m_2 / tot_count
+
+        new_count = batch_count + self.count
+
+        self.mean = new_mean
+        self.var = new_var
+        self.count = new_count
diff --git a/torchy_baselines/common/vec_env/__init__.py b/torchy_baselines/common/vec_env/__init__.py
new file mode 100644
index 0000000..3a58d12
--- /dev/null
+++ b/torchy_baselines/common/vec_env/__init__.py
@@ -0,0 +1,7 @@
+# flake8: noqa F401
+from torchy_baselines.common.vec_env.base_vec_env import AlreadySteppingError, NotSteppingError, VecEnv, VecEnvWrapper, \
+    CloudpickleWrapper
+from torchy_baselines.common.vec_env.dummy_vec_env import DummyVecEnv
+from torchy_baselines.common.vec_env.subproc_vec_env import SubprocVecEnv
+from torchy_baselines.common.vec_env.vec_frame_stack import VecFrameStack
+from torchy_baselines.common.vec_env.vec_normalize import VecNormalize
diff --git a/torchy_baselines/common/vec_env/base_vec_env.py b/torchy_baselines/common/vec_env/base_vec_env.py
new file mode 100644
index 0000000..26fec4b
--- /dev/null
+++ b/torchy_baselines/common/vec_env/base_vec_env.py
@@ -0,0 +1,301 @@
+from abc import ABC, abstractmethod
+import inspect
+import pickle
+
+import cloudpickle
+
+
+class AlreadySteppingError(Exception):
+    """
+    Raised when an asynchronous step is running while
+    step_async() is called again.
+    """
+
+    def __init__(self):
+        msg = 'already running an async step'
+        Exception.__init__(self, msg)
+
+
+class NotSteppingError(Exception):
+    """
+    Raised when an asynchronous step is not running but
+    step_wait() is called.
+    """
+
+    def __init__(self):
+        msg = 'not running an async step'
+        Exception.__init__(self, msg)
+
+
+class VecEnv(ABC):
+    """
+    An abstract asynchronous, vectorized environment.
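+
+    A rough usage sketch with a concrete subclass (``DummyVecEnv`` is one, see
+    below); ``make_env``, ``n_envs`` and ``actions`` are placeholders::
+
+        vec_env = DummyVecEnv([make_env for _ in range(n_envs)])
+        obs = vec_env.reset()
+        obs, rewards, dones, infos = vec_env.step(actions)
+        vec_env.close()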
+
+    :param num_envs: (int) the number of environments
+    :param observation_space: (Gym Space) the observation space
+    :param action_space: (Gym Space) the action space
+    """
+    metadata = {
+        'render.modes': ['human', 'rgb_array']
+    }
+
+    def __init__(self, num_envs, observation_space, action_space):
+        self.num_envs = num_envs
+        self.observation_space = observation_space
+        self.action_space = action_space
+
+    @abstractmethod
+    def reset(self):
+        """
+        Reset all the environments and return an array of
+        observations, or a tuple of observation arrays.
+
+        If step_async is still doing work, that work will
+        be cancelled and step_wait() should not be called
+        until step_async() is invoked again.
+
+        :return: ([int] or [float]) observation
+        """
+        pass
+
+    @abstractmethod
+    def step_async(self, actions):
+        """
+        Tell all the environments to start taking a step
+        with the given actions.
+        Call step_wait() to get the results of the step.
+
+        You should not call this if a step_async run is
+        already pending.
+        """
+        pass
+
+    @abstractmethod
+    def step_wait(self):
+        """
+        Wait for the step taken with step_async().
+
+        :return: ([int] or [float], [float], [bool], dict) observation, reward, done, information
+        """
+        pass
+
+    @abstractmethod
+    def close(self):
+        """
+        Clean up the environment's resources.
+        """
+        pass
+
+    @abstractmethod
+    def get_attr(self, attr_name, indices=None):
+        """
+        Return attribute from vectorized environment.
+
+        :param attr_name: (str) The name of the attribute whose value to return
+        :param indices: (list,int) Indices of envs to get attribute from
+        :return: (list) List of values of 'attr_name' in all environments
+        """
+        pass
+
+    @abstractmethod
+    def set_attr(self, attr_name, value, indices=None):
+        """
+        Set attribute inside vectorized environments.
+
+        :param attr_name: (str) The name of attribute to assign new value
+        :param value: (obj) Value to assign to `attr_name`
+        :param indices: (list,int) Indices of envs to assign value
+        :return: (NoneType)
+        """
+        pass
+
+    @abstractmethod
+    def env_method(self, method_name, *method_args, indices=None, **method_kwargs):
+        """
+        Call instance methods of vectorized environments.
+
+        :param method_name: (str) The name of the environment method to invoke.
+ :param indices: (list,int) Indices of envs whose method to call + :param method_args: (tuple) Any positional arguments to provide in the call + :param method_kwargs: (dict) Any keyword arguments to provide in the call + :return: (list) List of items returned by the environment's method call + """ + pass + + def step(self, actions): + """ + Step the environments with the given action + + :param actions: ([int] or [float]) the action + :return: ([int] or [float], [float], [bool], dict) observation, reward, done, information + """ + self.step_async(actions) + return self.step_wait() + + def get_images(self): + """ + Return RGB images from each environment + """ + raise NotImplementedError + + def render(self, *args, **kwargs): + """ + Gym environment rendering + + :param mode: (str) the rendering type + """ + raise NotImplementedError() + + @property + def unwrapped(self): + if isinstance(self, VecEnvWrapper): + return self.venv.unwrapped + else: + return self + + def getattr_depth_check(self, name, already_found): + """Check if an attribute reference is being hidden in a recursive call to __getattr__ + + :param name: (str) name of attribute to check for + :param already_found: (bool) whether this attribute has already been found in a wrapper + :return: (str or None) name of module whose attribute is being shadowed, if any. + """ + if hasattr(self, name) and already_found: + return "{0}.{1}".format(type(self).__module__, type(self).__name__) + else: + return None + + def _get_indices(self, indices): + """ + Convert a flexibly-typed reference to environment indices to an implied list of indices. + + :param indices: (None,int,Iterable) refers to indices of envs. + :return: (list) the implied list of indices. + """ + if indices is None: + indices = range(self.num_envs) + elif isinstance(indices, int): + indices = [indices] + return indices + + +class VecEnvWrapper(VecEnv): + """ + Vectorized environment base class + + :param venv: (VecEnv) the vectorized environment to wrap + :param observation_space: (Gym Space) the observation space (can be None to load from venv) + :param action_space: (Gym Space) the action space (can be None to load from venv) + """ + + def __init__(self, venv, observation_space=None, action_space=None): + self.venv = venv + VecEnv.__init__(self, num_envs=venv.num_envs, observation_space=observation_space or venv.observation_space, + action_space=action_space or venv.action_space) + self.class_attributes = dict(inspect.getmembers(self.__class__)) + + def step_async(self, actions): + self.venv.step_async(actions) + + @abstractmethod + def reset(self): + pass + + @abstractmethod + def step_wait(self): + pass + + def close(self): + return self.venv.close() + + def render(self, *args, **kwargs): + return self.venv.render(*args, **kwargs) + + def get_images(self): + return self.venv.get_images() + + def get_attr(self, attr_name, indices=None): + return self.venv.get_attr(attr_name, indices) + + def set_attr(self, attr_name, value, indices=None): + return self.venv.set_attr(attr_name, value, indices) + + def env_method(self, method_name, *method_args, indices=None, **method_kwargs): + return self.venv.env_method(method_name, *method_args, indices=indices, **method_kwargs) + + def __getattr__(self, name): + """Find attribute from wrapped venv(s) if this wrapper does not have it. + Useful for accessing attributes from venvs which are wrapped with multiple wrappers + which have unique attributes of interest. 
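+
+        For example, if two nested wrappers both define ``var_b``, looking it
+        up raises an ``AttributeError`` instead of silently returning the
+        outermost value (see ``getattr_depth_check``).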
+ """ + blocked_class = self.getattr_depth_check(name, already_found=False) + if blocked_class is not None: + own_class = "{0}.{1}".format(type(self).__module__, type(self).__name__) + format_str = ("Error: Recursive attribute lookup for {0} from {1} is " + "ambiguous and hides attribute from {2}") + raise AttributeError(format_str.format(name, own_class, blocked_class)) + + return self.getattr_recursive(name) + + def _get_all_attributes(self): + """Get all (inherited) instance and class attributes + + :return: (dict) all_attributes + """ + all_attributes = self.__dict__.copy() + all_attributes.update(self.class_attributes) + return all_attributes + + def getattr_recursive(self, name): + """Recursively check wrappers to find attribute. + + :param name (str) name of attribute to look for + :return: (object) attribute + """ + all_attributes = self._get_all_attributes() + if name in all_attributes: # attribute is present in this wrapper + attr = getattr(self, name) + elif hasattr(self.venv, 'getattr_recursive'): + # Attribute not present, child is wrapper. Call getattr_recursive rather than getattr + # to avoid a duplicate call to getattr_depth_check. + attr = self.venv.getattr_recursive(name) + else: # attribute not present, child is an unwrapped VecEnv + attr = getattr(self.venv, name) + + return attr + + def getattr_depth_check(self, name, already_found): + """See base class. + + :return: (str or None) name of module whose attribute is being shadowed, if any. + """ + all_attributes = self._get_all_attributes() + if name in all_attributes and already_found: + # this venv's attribute is being hidden because of a higher venv. + shadowed_wrapper_class = "{0}.{1}".format(type(self).__module__, type(self).__name__) + elif name in all_attributes and not already_found: + # we have found the first reference to the attribute. Now check for duplicates. + shadowed_wrapper_class = self.venv.getattr_depth_check(name, True) + else: + # this wrapper does not have the attribute. Keep searching. + shadowed_wrapper_class = self.venv.getattr_depth_check(name, already_found) + + return shadowed_wrapper_class + + +class CloudpickleWrapper(object): + def __init__(self, var): + """ + Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle) + + :param var: (Any) the variable you wish to wrap for pickling with cloudpickle + """ + self.var = var + + def __getstate__(self): + return cloudpickle.dumps(self.var) + + def __setstate__(self, obs): + self.var = pickle.loads(obs) diff --git a/torchy_baselines/common/vec_env/dummy_vec_env.py b/torchy_baselines/common/vec_env/dummy_vec_env.py new file mode 100644 index 0000000..a257b51 --- /dev/null +++ b/torchy_baselines/common/vec_env/dummy_vec_env.py @@ -0,0 +1,107 @@ +from collections import OrderedDict +import numpy as np + +from torchy_baselines.common.vec_env import VecEnv +from torchy_baselines.common.vec_env.util import copy_obs_dict, dict_to_obs, obs_space_info + + +class DummyVecEnv(VecEnv): + """ + Creates a simple vectorized wrapper for multiple environments, calling each environment in sequence on the current + Python process. This is useful for computationally simple environment such as ``cartpole-v1``, as the overhead of + multiprocess or multithread outweighs the environment computation time. This can also be used for RL methods that + require a vectorized environment, but that you want a single environments to train with. 
+
+    :param env_fns: ([Gym Environment]) the list of environments to vectorize
+    """
+
+    def __init__(self, env_fns):
+        self.envs = [fn() for fn in env_fns]
+        env = self.envs[0]
+        VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space)
+        obs_space = env.observation_space
+        self.keys, shapes, dtypes = obs_space_info(obs_space)
+
+        self.buf_obs = OrderedDict([
+            (k, np.zeros((self.num_envs,) + tuple(shapes[k]), dtype=dtypes[k]))
+            for k in self.keys])
+        self.buf_dones = np.zeros((self.num_envs,), dtype=bool)
+        self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32)
+        self.buf_infos = [{} for _ in range(self.num_envs)]
+        self.actions = None
+        self.metadata = env.metadata
+
+    def step_async(self, actions):
+        self.actions = actions
+
+    def step_wait(self):
+        for env_idx in range(self.num_envs):
+            obs, self.buf_rews[env_idx], self.buf_dones[env_idx], self.buf_infos[env_idx] =\
+                self.envs[env_idx].step(self.actions[env_idx])
+            if self.buf_dones[env_idx]:
+                # save final observation where user can get it, then reset
+                self.buf_infos[env_idx]['terminal_observation'] = obs
+                obs = self.envs[env_idx].reset()
+            self._save_obs(env_idx, obs)
+        return (self._obs_from_buf(), np.copy(self.buf_rews), np.copy(self.buf_dones),
+                self.buf_infos.copy())
+
+    def reset(self):
+        for env_idx in range(self.num_envs):
+            obs = self.envs[env_idx].reset()
+            self._save_obs(env_idx, obs)
+        return self._obs_from_buf()
+
+    def seed(self, seed, indices=None):
+        """
+        :param seed: (int or [int]) an int is broadcast to all selected envs;
+            a sequence must have one seed per selected env
+        :param indices: ([int])
+        """
+        indices = self._get_indices(indices)
+        if not hasattr(seed, '__len__'):
+            seed = [seed] * len(indices)
+        assert len(seed) == len(indices)
+        return [self.envs[i].seed(seed[i]) for i in indices]
+
+    def close(self):
+        for env in self.envs:
+            env.close()
+
+    def get_images(self):
+        return [env.render(mode='rgb_array') for env in self.envs]
+
+    def render(self, *args, **kwargs):
+        if self.num_envs == 1:
+            return self.envs[0].render(*args, **kwargs)
+        else:
+            return super().render(*args, **kwargs)
+
+    def _save_obs(self, env_idx, obs):
+        for key in self.keys:
+            if key is None:
+                self.buf_obs[key][env_idx] = obs
+            else:
+                self.buf_obs[key][env_idx] = obs[key]
+
+    def _obs_from_buf(self):
+        return dict_to_obs(self.observation_space, copy_obs_dict(self.buf_obs))
+
+    def get_attr(self, attr_name, indices=None):
+        """Return attribute from vectorized environment (see base class)."""
+        target_envs = self._get_target_envs(indices)
+        return [getattr(env_i, attr_name) for env_i in target_envs]
+
+    def set_attr(self, attr_name, value, indices=None):
+        """Set attribute inside vectorized environments (see base class)."""
+        target_envs = self._get_target_envs(indices)
+        for env_i in target_envs:
+            setattr(env_i, attr_name, value)
+
+    def env_method(self, method_name, *method_args, indices=None, **method_kwargs):
+        """Call instance methods of vectorized environments."""
+        target_envs = self._get_target_envs(indices)
+        return [getattr(env_i, method_name)(*method_args, **method_kwargs) for env_i in target_envs]
+
+    def _get_target_envs(self, indices):
+        indices = self._get_indices(indices)
+        return [self.envs[i] for i in indices]
diff --git a/torchy_baselines/common/vec_env/subproc_vec_env.py b/torchy_baselines/common/vec_env/subproc_vec_env.py
new file mode 100644
index 0000000..b746dac
--- /dev/null
+++ b/torchy_baselines/common/vec_env/subproc_vec_env.py
@@ -0,0 +1,232 @@
+import multiprocessing
+from collections import OrderedDict
+
+import gym
+import numpy as np
+
+from torchy_baselines.common.vec_env import VecEnv, CloudpickleWrapper
+
+
+def _worker(remote, parent_remote, env_fn_wrapper):
+    parent_remote.close()
+    env = env_fn_wrapper.var()
+    while True:
+        try:
+            cmd, data = remote.recv()
+            if cmd == 'step':
+                observation, reward, done, info = env.step(data)
+                if done:
+                    # save final observation where user can get it, then reset
+                    info['terminal_observation'] = observation
+                    observation = env.reset()
+                remote.send((observation, reward, done, info))
+            elif cmd == 'reset':
+                observation = env.reset()
+                remote.send(observation)
+            elif cmd == 'render':
+                remote.send(env.render(*data[0], **data[1]))
+            elif cmd == 'close':
+                remote.close()
+                break
+            elif cmd == 'get_spaces':
+                remote.send((env.observation_space, env.action_space))
+            elif cmd == 'env_method':
+                method = getattr(env, data[0])
+                remote.send(method(*data[1], **data[2]))
+            elif cmd == 'get_attr':
+                remote.send(getattr(env, data))
+            elif cmd == 'set_attr':
+                remote.send(setattr(env, data[0], data[1]))
+            else:
+                raise NotImplementedError
+        except EOFError:
+            break


+def tile_images(img_nhwc):
+    """
+    Tile N images into one big PxQ image
+    (P,Q) are chosen to be as close as possible, and if N
+    is square, then P=Q.
+
+    :param img_nhwc: (list) list or array of images, ndim=4 once turned into array. img nhwc
+        n = batch index, h = height, w = width, c = channel
+    :return: (numpy float) img_HWc, ndim=3
+    """
+    img_nhwc = np.asarray(img_nhwc)
+    n_images, height, width, n_channels = img_nhwc.shape
+    # new_height (H) and new_width (W) of the tiling grid
+    new_height = int(np.ceil(np.sqrt(n_images)))
+    new_width = int(np.ceil(float(n_images) / new_height))
+    # pad with zero-images so the batch fills the new_height x new_width grid
+    img_nhwc = np.array(list(img_nhwc) + [img_nhwc[0] * 0 for _ in range(n_images, new_height * new_width)])
+    # img_HWhwc
+    out_image = img_nhwc.reshape(new_height, new_width, height, width, n_channels)
+    # img_HhWwc
+    out_image = out_image.transpose(0, 2, 1, 3, 4)
+    # img_Hh_Ww_c
+    out_image = out_image.reshape(new_height * height, new_width * width, n_channels)
+    return out_image
+
+
+class SubprocVecEnv(VecEnv):
+    """
+    Creates a multiprocess vectorized wrapper for multiple environments, distributing each environment to its own
+    process, allowing a significant speed-up when the environment is computationally complex.
+
+    For performance reasons, if your environment is not IO bound, the number of environments should not exceed the
+    number of logical cores on your CPU.
+
+    .. warning::
+
+        Only 'forkserver' and 'spawn' start methods are thread-safe,
+        which is important when TensorFlow sessions or other non thread-safe
+        libraries are used in the parent (see issue #217). However, compared to
+        'fork' they incur a small start-up cost and have restrictions on
+        global variables. With those methods, users must wrap the code in an
+        ``if __name__ == "__main__":`` block.
+        For more information, see the multiprocessing documentation.
+
+    :param env_fns: ([Gym Environment]) Environments to run in subprocesses
+    :param start_method: (str) method used to start the subprocesses.
+        Must be one of the methods returned by multiprocessing.get_all_start_methods().
+        Defaults to 'forkserver' on available platforms, and 'spawn' otherwise.
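+
+    A usage sketch; with 'spawn' or 'forkserver' the guard below is required::
+
+        if __name__ == "__main__":
+            env = SubprocVecEnv([lambda: gym.make('CartPole-v1')
+                                 for _ in range(4)])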
+ """ + + def __init__(self, env_fns, start_method=None): + self.waiting = False + self.closed = False + n_envs = len(env_fns) + + if start_method is None: + # Fork is not a thread safe method (see issue #217) + # but is more user friendly (does not require to wrap the code in + # a `if __name__ == "__main__":`) + forkserver_available = 'forkserver' in multiprocessing.get_all_start_methods() + start_method = 'forkserver' if forkserver_available else 'spawn' + ctx = multiprocessing.get_context(start_method) + + self.remotes, self.work_remotes = zip(*[ctx.Pipe() for _ in range(n_envs)]) + self.processes = [] + for work_remote, remote, env_fn in zip(self.work_remotes, self.remotes, env_fns): + args = (work_remote, remote, CloudpickleWrapper(env_fn)) + # daemon=True: if the main process crashes, we should not cause things to hang + process = ctx.Process(target=_worker, args=args, daemon=True) + process.start() + self.processes.append(process) + work_remote.close() + + self.remotes[0].send(('get_spaces', None)) + observation_space, action_space = self.remotes[0].recv() + VecEnv.__init__(self, len(env_fns), observation_space, action_space) + + def step_async(self, actions): + for remote, action in zip(self.remotes, actions): + remote.send(('step', action)) + self.waiting = True + + def step_wait(self): + results = [remote.recv() for remote in self.remotes] + self.waiting = False + obs, rews, dones, infos = zip(*results) + return _flatten_obs(obs, self.observation_space), np.stack(rews), np.stack(dones), infos + + def reset(self): + for remote in self.remotes: + remote.send(('reset', None)) + obs = [remote.recv() for remote in self.remotes] + return _flatten_obs(obs, self.observation_space) + + def close(self): + if self.closed: + return + if self.waiting: + for remote in self.remotes: + remote.recv() + for remote in self.remotes: + remote.send(('close', None)) + for process in self.processes: + process.join() + self.closed = True + + def render(self, mode='human', *args, **kwargs): + for pipe in self.remotes: + # gather images from subprocesses + # `mode` will be taken into account later + pipe.send(('render', (args, {'mode': 'rgb_array', **kwargs}))) + imgs = [pipe.recv() for pipe in self.remotes] + # Create a big image by tiling images from subprocesses + bigimg = tile_images(imgs) + if mode == 'human': + import cv2 + cv2.imshow('vecenv', bigimg[:, :, ::-1]) + cv2.waitKey(1) + elif mode == 'rgb_array': + return bigimg + else: + raise NotImplementedError + + def get_images(self): + for pipe in self.remotes: + pipe.send(('render', {"mode": 'rgb_array'})) + imgs = [pipe.recv() for pipe in self.remotes] + return imgs + + def get_attr(self, attr_name, indices=None): + """Return attribute from vectorized environment (see base class).""" + target_remotes = self._get_target_remotes(indices) + for remote in target_remotes: + remote.send(('get_attr', attr_name)) + return [remote.recv() for remote in target_remotes] + + def set_attr(self, attr_name, value, indices=None): + """Set attribute inside vectorized environments (see base class).""" + target_remotes = self._get_target_remotes(indices) + for remote in target_remotes: + remote.send(('set_attr', (attr_name, value))) + for remote in target_remotes: + remote.recv() + + def env_method(self, method_name, *method_args, indices=None, **method_kwargs): + """Call instance methods of vectorized environments.""" + target_remotes = self._get_target_remotes(indices) + for remote in target_remotes: + remote.send(('env_method', (method_name, method_args, 
+        return [remote.recv() for remote in target_remotes]
+
+    def _get_target_remotes(self, indices):
+        """
+        Get the connection objects needed to communicate with the requested
+        envs that live in subprocesses.
+
+        :param indices: (None,int,Iterable) refers to indices of envs.
+        :return: ([multiprocessing.Connection]) Connection objects to communicate between processes.
+        """
+        indices = self._get_indices(indices)
+        return [self.remotes[i] for i in indices]
+
+
+def _flatten_obs(obs, space):
+    """
+    Flatten observations, depending on the observation space.
+
+    :param obs: (list or tuple of X, where X is a dict, tuple or ndarray) observations.
+        A list or tuple of observations, one per environment.
+        Each environment observation may be a NumPy array, or a dict or tuple of NumPy arrays.
+    :return: (OrderedDict, tuple or ndarray) flattened observations.
+        A flattened NumPy array or an OrderedDict or tuple of flattened numpy arrays.
+        Each NumPy array has the environment index as its first axis.
+    """
+    assert isinstance(obs, (list, tuple)), "expected list or tuple of observations per environment"
+    assert len(obs) > 0, "need observations from at least one environment"
+
+    if isinstance(space, gym.spaces.Dict):
+        assert isinstance(space.spaces, OrderedDict), "Dict space must have ordered subspaces"
+        assert isinstance(obs[0], dict), "non-dict observation for environment with Dict observation space"
+        return OrderedDict([(k, np.stack([o[k] for o in obs])) for k in space.spaces.keys()])
+    elif isinstance(space, gym.spaces.Tuple):
+        assert isinstance(obs[0], tuple), "non-tuple observation for environment with Tuple observation space"
+        obs_len = len(space.spaces)
+        return tuple((np.stack([o[i] for o in obs]) for i in range(obs_len)))
+    else:
+        return np.stack(obs)
diff --git a/torchy_baselines/common/vec_env/util.py b/torchy_baselines/common/vec_env/util.py
new file mode 100644
index 0000000..03ce286
--- /dev/null
+++ b/torchy_baselines/common/vec_env/util.py
@@ -0,0 +1,73 @@
+"""
+Helpers for dealing with vectorized environments.
+"""
+
+from collections import OrderedDict
+
+import gym
+import numpy as np
+
+
+def copy_obs_dict(obs):
+    """
+    Deep-copy a dict of numpy arrays.
+
+    :param obs: (OrderedDict) a dict of numpy arrays.
+    :return: (OrderedDict) a dict of copied numpy arrays.
+    """
+    assert isinstance(obs, OrderedDict), "unexpected type for observations '{}'".format(type(obs))
+    return OrderedDict([(k, np.copy(v)) for k, v in obs.items()])
+
+
+def dict_to_obs(space, obs_dict):
+    """
+    Convert an internal dict representation `obs_dict` into the
+    observation type specified by `space`.
+
+    :param space: (gym.spaces.Space) an observation space.
+    :param obs_dict: (OrderedDict) a dict of numpy arrays.
+    :return: (ndarray, tuple or dict) an observation of the same type as
+        space. If space is Dict, the function is the identity; if space is
+        Tuple, the dict is converted to a tuple; otherwise, space is
+        unstructured and the value `obs_dict[None]` is returned.
+    """
+    if isinstance(space, gym.spaces.Dict):
+        return obs_dict
+    elif isinstance(space, gym.spaces.Tuple):
+        assert len(obs_dict) == len(space.spaces), "size of observation does not match size of observation space"
+        return tuple((obs_dict[i] for i in range(len(space.spaces))))
+    else:
+        assert set(obs_dict.keys()) == {None}, "multiple observation keys for unstructured observation space"
+        return obs_dict[None]
+
+
+def obs_space_info(obs_space):
+    """
+    Get dict-structured information about a gym.Space.
+ + Dict spaces are represented directly by their dict of subspaces. + Tuple spaces are converted into a dict with keys indexing into the tuple. + Unstructured spaces are represented by {None: obs_space}. + + :param obs_space: (gym.spaces.Space) an observation space + :return (tuple) A tuple (keys, shapes, dtypes): + keys: a list of dict keys. + shapes: a dict mapping keys to shapes. + dtypes: a dict mapping keys to dtypes. + """ + if isinstance(obs_space, gym.spaces.Dict): + assert isinstance(obs_space.spaces, OrderedDict), "Dict space must have ordered subspaces" + subspaces = obs_space.spaces + elif isinstance(obs_space, gym.spaces.Tuple): + subspaces = {i: space for i, space in enumerate(obs_space.spaces)} + else: + assert not hasattr(obs_space, 'spaces'), "Unsupported structured space '{}'".format(type(obs_space)) + subspaces = {None: obs_space} + keys = [] + shapes = {} + dtypes = {} + for key, box in subspaces.items(): + keys.append(key) + shapes[key] = box.shape + dtypes[key] = box.dtype + return keys, shapes, dtypes diff --git a/torchy_baselines/common/vec_env/vec_frame_stack.py b/torchy_baselines/common/vec_env/vec_frame_stack.py new file mode 100644 index 0000000..2610f42 --- /dev/null +++ b/torchy_baselines/common/vec_env/vec_frame_stack.py @@ -0,0 +1,55 @@ +import warnings + +import numpy as np +from gym import spaces + +from torchy_baselines.common.vec_env import VecEnvWrapper + + +class VecFrameStack(VecEnvWrapper): + """ + Frame stacking wrapper for vectorized environment + + :param venv: (VecEnv) the vectorized environment to wrap + :param n_stack: (int) Number of frames to stack + """ + + def __init__(self, venv, n_stack): + self.venv = venv + self.n_stack = n_stack + wrapped_obs_space = venv.observation_space + low = np.repeat(wrapped_obs_space.low, self.n_stack, axis=-1) + high = np.repeat(wrapped_obs_space.high, self.n_stack, axis=-1) + self.stackedobs = np.zeros((venv.num_envs,) + low.shape, low.dtype) + observation_space = spaces.Box(low=low, high=high, dtype=venv.observation_space.dtype) + VecEnvWrapper.__init__(self, venv, observation_space=observation_space) + + def step_wait(self): + observations, rewards, dones, infos = self.venv.step_wait() + last_ax_size = observations.shape[-1] + self.stackedobs = np.roll(self.stackedobs, shift=-last_ax_size, axis=-1) + for i, done in enumerate(dones): + if done: + if 'terminal_observation' in infos[i]: + old_terminal = infos[i]['terminal_observation'] + new_terminal = np.concatenate( + (self.stackedobs[i, ..., :-last_ax_size], old_terminal), axis=-1) + infos[i]['terminal_observation'] = new_terminal + else: + warnings.warn( + "VecFrameStack wrapping a VecEnv without terminal_observation info") + self.stackedobs[i] = 0 + self.stackedobs[..., -observations.shape[-1]:] = observations + return self.stackedobs, rewards, dones, infos + + def reset(self): + """ + Reset all environments + """ + obs = self.venv.reset() + self.stackedobs[...] 
= 0
+        self.stackedobs[..., -obs.shape[-1]:] = obs
+        return self.stackedobs
+
+    def close(self):
+        self.venv.close()
diff --git a/torchy_baselines/common/vec_env/vec_normalize.py b/torchy_baselines/common/vec_env/vec_normalize.py
new file mode 100644
index 0000000..6f2b2b3
--- /dev/null
+++ b/torchy_baselines/common/vec_env/vec_normalize.py
@@ -0,0 +1,105 @@
+import pickle
+
+import numpy as np
+
+from torchy_baselines.common.vec_env import VecEnvWrapper
+from torchy_baselines.common.running_mean_std import RunningMeanStd
+
+
+class VecNormalize(VecEnvWrapper):
+    """
+    A moving-average, normalizing wrapper for vectorized environments,
+    with support for saving/loading the moving averages.
+
+    :param venv: (VecEnv) the vectorized environment to wrap
+    :param training: (bool) Whether to update the moving averages or not
+    :param norm_obs: (bool) Whether to normalize observations or not (default: True)
+    :param norm_reward: (bool) Whether to normalize rewards or not (default: True)
+    :param clip_obs: (float) Max absolute value for observation
+    :param clip_reward: (float) Max absolute value for discounted reward
+    :param gamma: (float) discount factor
+    :param epsilon: (float) To avoid division by zero
+    """
+
+    def __init__(self, venv, training=True, norm_obs=True, norm_reward=True,
+                 clip_obs=10., clip_reward=10., gamma=0.99, epsilon=1e-8):
+        VecEnvWrapper.__init__(self, venv)
+        self.obs_rms = RunningMeanStd(shape=self.observation_space.shape)
+        self.ret_rms = RunningMeanStd(shape=())
+        self.clip_obs = clip_obs
+        self.clip_reward = clip_reward
+        # Returns: discounted rewards
+        self.ret = np.zeros(self.num_envs)
+        self.gamma = gamma
+        self.epsilon = epsilon
+        self.training = training
+        self.norm_obs = norm_obs
+        self.norm_reward = norm_reward
+        self.old_obs = np.array([])
+
+    def step_wait(self):
+        """
+        Apply a sequence of actions to the sequence of environments:
+        actions -> (observations, rewards, news),
+
+        where 'news' is a boolean vector indicating whether each episode just ended.
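+
+        When norm_reward is True, rewards are rescaled by the running standard
+        deviation of the discounted returns, ``rews / sqrt(ret_rms.var +
+        epsilon)``, then clipped to ``[-clip_reward, clip_reward]`` (see the
+        code below).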
+ """ + obs, rews, news, infos = self.venv.step_wait() + self.ret = self.ret * self.gamma + rews + self.old_obs = obs + obs = self._normalize_observation(obs) + if self.norm_reward: + if self.training: + self.ret_rms.update(self.ret) + rews = np.clip(rews / np.sqrt(self.ret_rms.var + self.epsilon), -self.clip_reward, self.clip_reward) + self.ret[news] = 0 + return obs, rews, news, infos + + def _normalize_observation(self, obs): + """ + :param obs: (numpy tensor) + """ + if self.norm_obs: + if self.training: + self.obs_rms.update(obs) + obs = np.clip((obs - self.obs_rms.mean) / np.sqrt(self.obs_rms.var + self.epsilon), -self.clip_obs, + self.clip_obs) + return obs + else: + return obs + + def get_original_obs(self): + """ + returns the unnormalized observation + + :return: (numpy float) + """ + return self.old_obs + + def reset(self): + """ + Reset all environments + """ + obs = self.venv.reset() + if len(np.array(obs).shape) == 1: # for when num_cpu is 1 + self.old_obs = [obs] + else: + self.old_obs = obs + self.ret = np.zeros(self.num_envs) + return self._normalize_observation(obs) + + def save_running_average(self, path): + """ + :param path: (str) path to log dir + """ + for rms, name in zip([self.obs_rms, self.ret_rms], ['obs_rms', 'ret_rms']): + with open("{}/{}.pkl".format(path, name), 'wb') as file_handler: + pickle.dump(rms, file_handler) + + def load_running_average(self, path): + """ + :param path: (str) path to log dir + """ + for name in ['obs_rms', 'ret_rms']: + with open("{}/{}.pkl".format(path, name), 'rb') as file_handler: + setattr(self, name, pickle.load(file_handler))