Mirror of https://github.com/saymrwulf/stable-baselines3.git, synced 2026-05-14 20:58:03 +00:00

Add stable-baselines VecEnvs
Commit 56053bc692 (parent cc4380eccd)

10 changed files with 1284 additions and 0 deletions
326  tests/test_vec_envs.py  Normal file
@@ -0,0 +1,326 @@
import collections
import functools
import itertools
import multiprocessing

import pytest
import gym
import numpy as np

from torchy_baselines.common.vec_env import DummyVecEnv, SubprocVecEnv, VecNormalize, VecFrameStack

N_ENVS = 3
VEC_ENV_CLASSES = [DummyVecEnv, SubprocVecEnv]
VEC_ENV_WRAPPERS = [None, VecNormalize, VecFrameStack]


class CustomGymEnv(gym.Env):
    def __init__(self, space):
        """
        Custom gym environment for testing purposes
        """
        self.action_space = space
        self.observation_space = space
        self.current_step = 0
        self.ep_length = 4

    def reset(self):
        self.current_step = 0
        self._choose_next_state()
        return self.state

    def step(self, action):
        reward = 1
        self._choose_next_state()
        self.current_step += 1
        done = self.current_step >= self.ep_length
        return self.state, reward, done, {}

    def _choose_next_state(self):
        self.state = self.observation_space.sample()

    def render(self, mode='human'):
        pass

    @staticmethod
    def custom_method(dim_0=1, dim_1=1):
        """
        Dummy method to test call to custom method
        from VecEnv

        :param dim_0: (int)
        :param dim_1: (int)
        :return: (np.ndarray)
        """
        return np.ones((dim_0, dim_1))


@pytest.mark.parametrize('vec_env_class', VEC_ENV_CLASSES)
@pytest.mark.parametrize('vec_env_wrapper', VEC_ENV_WRAPPERS)
def test_vecenv_custom_calls(vec_env_class, vec_env_wrapper):
    """Test access to methods/attributes of vectorized environments"""
    def make_env():
        return CustomGymEnv(gym.spaces.Box(low=np.zeros(2), high=np.ones(2)))
    vec_env = vec_env_class([make_env for _ in range(N_ENVS)])

    if vec_env_wrapper is not None:
        if vec_env_wrapper == VecFrameStack:
            vec_env = vec_env_wrapper(vec_env, n_stack=2)
        else:
            vec_env = vec_env_wrapper(vec_env)

    env_method_results = vec_env.env_method('custom_method', 1, indices=None, dim_1=2)
    setattr_results = []
    # Set current_step to an arbitrary value
    for env_idx in range(N_ENVS):
        setattr_results.append(vec_env.set_attr('current_step', env_idx, indices=env_idx))
    # Retrieve the value for each environment
    getattr_results = vec_env.get_attr('current_step')

    assert len(env_method_results) == N_ENVS
    assert len(setattr_results) == N_ENVS
    assert len(getattr_results) == N_ENVS

    for env_idx in range(N_ENVS):
        assert (env_method_results[env_idx] == np.ones((1, 2))).all()
        assert setattr_results[env_idx] is None
        assert getattr_results[env_idx] == env_idx

    # Call env_method on a subset of the VecEnv
    env_method_subset = vec_env.env_method('custom_method', 1, indices=[0, 2], dim_1=3)
    assert (env_method_subset[0] == np.ones((1, 3))).all()
    assert (env_method_subset[1] == np.ones((1, 3))).all()
    assert len(env_method_subset) == 2

    # Test to change value for all the environments
    setattr_result = vec_env.set_attr('current_step', 42, indices=None)
    getattr_result = vec_env.get_attr('current_step')
    assert setattr_result is None
    assert getattr_result == [42 for _ in range(N_ENVS)]

    # Additional tests for setattr that does not affect all the environments
    vec_env.reset()
    setattr_result = vec_env.set_attr('current_step', 12, indices=[0, 1])
    getattr_result = vec_env.get_attr('current_step')
    getattr_result_subset = vec_env.get_attr('current_step', indices=[0, 1])
    assert setattr_result is None
    assert getattr_result == [12 for _ in range(2)] + [0 for _ in range(N_ENVS - 2)]
    assert getattr_result_subset == [12, 12]
    assert vec_env.get_attr('current_step', indices=[0, 2]) == [12, 0]

    vec_env.reset()
    # Change value only for first and last environment
    setattr_result = vec_env.set_attr('current_step', 12, indices=[0, -1])
    getattr_result = vec_env.get_attr('current_step')
    assert setattr_result is None
    assert getattr_result == [12] + [0 for _ in range(N_ENVS - 2)] + [12]
    assert vec_env.get_attr('current_step', indices=[-1]) == [12]

    vec_env.close()


class StepEnv(gym.Env):
    def __init__(self, max_steps):
        """Gym environment for testing that terminal observation is inserted
        correctly."""
        self.action_space = gym.spaces.Discrete(2)
        self.observation_space = gym.spaces.Box(np.array([0]), np.array([999]),
                                                dtype='int')
        self.max_steps = max_steps
        self.current_step = 0

    def reset(self):
        self.current_step = 0
        return np.array([self.current_step], dtype='int')

    def step(self, action):
        prev_step = self.current_step
        self.current_step += 1
        done = self.current_step >= self.max_steps
        return np.array([prev_step], dtype='int'), 0.0, done, {}


@pytest.mark.parametrize('vec_env_class', VEC_ENV_CLASSES)
@pytest.mark.parametrize('vec_env_wrapper', VEC_ENV_WRAPPERS)
def test_vecenv_terminal_obs(vec_env_class, vec_env_wrapper):
    """Test that 'terminal_observation' gets added to info dict upon
    termination."""
    step_nums = [i + 5 for i in range(N_ENVS)]
    vec_env = vec_env_class([functools.partial(StepEnv, n) for n in step_nums])

    if vec_env_wrapper is not None:
        if vec_env_wrapper == VecFrameStack:
            vec_env = vec_env_wrapper(vec_env, n_stack=2)
        else:
            vec_env = vec_env_wrapper(vec_env)

    zero_acts = np.zeros((N_ENVS,), dtype='int')
    prev_obs_b = vec_env.reset()
    for step_num in range(1, max(step_nums) + 1):
        obs_b, _, done_b, info_b = vec_env.step(zero_acts)
        assert len(obs_b) == N_ENVS
        assert len(done_b) == N_ENVS
        assert len(info_b) == N_ENVS
        env_iter = zip(prev_obs_b, obs_b, done_b, info_b, step_nums)
        for prev_obs, obs, done, info, final_step_num in env_iter:
            assert done == (step_num == final_step_num)
            if not done:
                assert 'terminal_observation' not in info
            else:
                terminal_obs = info['terminal_observation']

                # do some rough ordering checks that should work for all
                # wrappers, including VecNormalize
                assert np.all(prev_obs < terminal_obs)
                assert np.all(obs < prev_obs)

                if not isinstance(vec_env, VecNormalize):
                    # more precise tests that we can't do with VecNormalize
                    # (which changes observation values)
                    assert np.all(prev_obs + 1 == terminal_obs)
                    assert np.all(obs == 0)

        prev_obs_b = obs_b

    vec_env.close()


SPACES = collections.OrderedDict([
    ('discrete', gym.spaces.Discrete(2)),
    ('multidiscrete', gym.spaces.MultiDiscrete([2, 3])),
    ('multibinary', gym.spaces.MultiBinary(3)),
    ('continuous', gym.spaces.Box(low=np.zeros(2), high=np.ones(2))),
])


def check_vecenv_spaces(vec_env_class, space, obs_assert):
    """Helper method to check observation spaces in vectorized environments."""
    def make_env():
        return CustomGymEnv(space)

    vec_env = vec_env_class([make_env for _ in range(N_ENVS)])
    obs = vec_env.reset()
    obs_assert(obs)

    dones = [False] * N_ENVS
    while not any(dones):
        actions = [vec_env.action_space.sample() for _ in range(N_ENVS)]
        obs, _rews, dones, _infos = vec_env.step(actions)
        obs_assert(obs)
    vec_env.close()


def check_vecenv_obs(obs, space):
    """Helper method to check observations from multiple environments each belong to
    the appropriate observation space."""
    assert obs.shape[0] == N_ENVS
    for value in obs:
        assert space.contains(value)


@pytest.mark.parametrize('vec_env_class,space', itertools.product(VEC_ENV_CLASSES, SPACES.values()))
def test_vecenv_single_space(vec_env_class, space):
    def obs_assert(obs):
        return check_vecenv_obs(obs, space)

    check_vecenv_spaces(vec_env_class, space, obs_assert)


class _UnorderedDictSpace(gym.spaces.Dict):
    """Like DictSpace, but returns an unordered dict when sampling."""
    def sample(self):
        return dict(super().sample())


@pytest.mark.parametrize('vec_env_class', VEC_ENV_CLASSES)
def test_vecenv_dict_spaces(vec_env_class):
    """Test dictionary observation spaces with vectorized environments."""
    space = gym.spaces.Dict(SPACES)

    def obs_assert(obs):
        assert isinstance(obs, collections.OrderedDict)
        assert obs.keys() == space.spaces.keys()
        for key, values in obs.items():
            check_vecenv_obs(values, space.spaces[key])

    check_vecenv_spaces(vec_env_class, space, obs_assert)

    unordered_space = _UnorderedDictSpace(SPACES)
    # Check that vec_env_class can accept unordered dict observations (and convert to OrderedDict)
    check_vecenv_spaces(vec_env_class, unordered_space, obs_assert)


@pytest.mark.parametrize('vec_env_class', VEC_ENV_CLASSES)
def test_vecenv_tuple_spaces(vec_env_class):
    """Test tuple observation spaces with vectorized environments."""
    space = gym.spaces.Tuple(tuple(SPACES.values()))

    def obs_assert(obs):
        assert isinstance(obs, tuple)
        assert len(obs) == len(space.spaces)
        for values, inner_space in zip(obs, space.spaces):
            check_vecenv_obs(values, inner_space)

    return check_vecenv_spaces(vec_env_class, space, obs_assert)


def test_subproc_start_method():
    start_methods = [None]
    # Only test thread-safe methods. Others may deadlock tests! (gh/428)
    safe_methods = {'forkserver', 'spawn'}
    available_methods = multiprocessing.get_all_start_methods()
    start_methods += list(safe_methods.intersection(available_methods))
    space = gym.spaces.Discrete(2)

    def obs_assert(obs):
        return check_vecenv_obs(obs, space)

    for start_method in start_methods:
        vec_env_class = functools.partial(SubprocVecEnv, start_method=start_method)
        check_vecenv_spaces(vec_env_class, space, obs_assert)

    with pytest.raises(ValueError, match="cannot find context for 'illegal_method'"):
        vec_env_class = functools.partial(SubprocVecEnv, start_method='illegal_method')
        check_vecenv_spaces(vec_env_class, space, obs_assert)


class CustomWrapperA(VecNormalize):
    def __init__(self, venv):
        VecNormalize.__init__(self, venv)
        self.var_a = 'a'


class CustomWrapperB(VecNormalize):
    def __init__(self, venv):
        VecNormalize.__init__(self, venv)
        self.var_b = 'b'

    def func_b(self):
        return self.var_b

    def name_test(self):
        return self.__class__


class CustomWrapperBB(CustomWrapperB):
    def __init__(self, venv):
        CustomWrapperB.__init__(self, venv)
        self.var_bb = 'bb'


def test_vecenv_wrapper_getattr():
    def make_env():
        return CustomGymEnv(gym.spaces.Box(low=np.zeros(2), high=np.ones(2)))
    vec_env = DummyVecEnv([make_env for _ in range(N_ENVS)])
    wrapped = CustomWrapperA(CustomWrapperBB(vec_env))
    assert wrapped.var_a == 'a'
    assert wrapped.var_b == 'b'
    assert wrapped.var_bb == 'bb'
    assert wrapped.func_b() == 'b'
    assert wrapped.name_test() == CustomWrapperBB

    double_wrapped = CustomWrapperA(CustomWrapperB(wrapped))
    dummy = double_wrapped.var_a  # should not raise as it is directly defined here
    with pytest.raises(AttributeError):  # should raise due to ambiguity
        dummy = double_wrapped.var_b
    with pytest.raises(AttributeError):  # should raise as does not exist
        dummy = double_wrapped.nonexistent_attribute
    del dummy  # keep linter happy
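The get_attr/set_attr/env_method API exercised above also works outside pytest. A minimal sketch, assuming the torchy_baselines package from this commit is importable; the 'CartPole-v1' env id and the 'marker' attribute are illustrative only:

    import gym
    from torchy_baselines.common.vec_env import DummyVecEnv

    vec_env = DummyVecEnv([lambda: gym.make('CartPole-v1') for _ in range(2)])
    vec_env.set_attr('marker', 'env-0', indices=0)      # set an attribute on the first env only
    print(vec_env.get_attr('marker', indices=[0]))      # -> ['env-0']
    print(vec_env.env_method('seed', 0))                 # one return value per environment
    vec_env.close()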
41  tests/test_vec_normalize.py  Normal file
@@ -0,0 +1,41 @@
import gym
import numpy as np

from torchy_baselines.common.running_mean_std import RunningMeanStd
from torchy_baselines.common.vec_env.dummy_vec_env import DummyVecEnv
from torchy_baselines.common.vec_env.vec_normalize import VecNormalize

ENV_ID = 'Pendulum-v0'


def test_runningmeanstd():
    """Test RunningMeanStd object"""
    for (x_1, x_2, x_3) in [
            (np.random.randn(3), np.random.randn(4), np.random.randn(5)),
            (np.random.randn(3, 2), np.random.randn(4, 2), np.random.randn(5, 2))]:
        rms = RunningMeanStd(epsilon=0.0, shape=x_1.shape[1:])

        x_cat = np.concatenate([x_1, x_2, x_3], axis=0)
        moments_1 = [x_cat.mean(axis=0), x_cat.var(axis=0)]
        rms.update(x_1)
        rms.update(x_2)
        rms.update(x_3)
        moments_2 = [rms.mean, rms.var]

        assert np.allclose(moments_1, moments_2)


def test_vec_env():
    """Test VecNormalize Object"""

    def make_env():
        return gym.make(ENV_ID)

    env = DummyVecEnv([make_env])
    env = VecNormalize(env, norm_obs=True, norm_reward=True, clip_obs=10., clip_reward=10.)
    _, done = env.reset(), [False]
    obs = None
    while not done[0]:
        actions = [env.action_space.sample()]
        obs, _, done, _ = env.step(actions)
    assert np.max(obs) <= 10
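The final assertion holds because VecNormalize clips each normalized observation to [-clip_obs, clip_obs]. A small sketch of the bound being tested, with illustrative running statistics:

    import numpy as np

    clip_obs = 10.0
    raw_obs = np.array([1e6, -1e6])              # arbitrarily large raw values
    mean, var, epsilon = 0.0, 1.0, 1e-8           # illustrative running statistics
    normalized = np.clip((raw_obs - mean) / np.sqrt(var + epsilon), -clip_obs, clip_obs)
    assert np.max(normalized) <= clip_obs         # same bound as the test above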
37  torchy_baselines/common/running_mean_std.py  Normal file
@@ -0,0 +1,37 @@
import numpy as np


class RunningMeanStd(object):
    def __init__(self, epsilon=1e-4, shape=()):
        """
        Calculates the running mean and std of a data stream
        https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm

        :param epsilon: (float) helps with arithmetic issues
        :param shape: (tuple) the shape of the data stream's output
        """
        self.mean = np.zeros(shape, 'float64')
        self.var = np.ones(shape, 'float64')
        self.count = epsilon

    def update(self, arr):
        batch_mean = np.mean(arr, axis=0)
        batch_var = np.var(arr, axis=0)
        batch_count = arr.shape[0]
        self.update_from_moments(batch_mean, batch_var, batch_count)

    def update_from_moments(self, batch_mean, batch_var, batch_count):
        delta = batch_mean - self.mean
        tot_count = self.count + batch_count

        new_mean = self.mean + delta * batch_count / tot_count
        m_a = self.var * self.count
        m_b = batch_var * batch_count
        m_2 = m_a + m_b + np.square(delta) * self.count * batch_count / (self.count + batch_count)
        new_var = m_2 / (self.count + batch_count)

        new_count = batch_count + self.count

        self.mean = new_mean
        self.var = new_var
        self.count = new_count
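A quick numeric check of the parallel-update formula above, mirroring test_runningmeanstd (epsilon=0 so the prior carries no weight); assumes RunningMeanStd from this module is imported:

    import numpy as np

    rms = RunningMeanStd(epsilon=0.0, shape=())
    x_1, x_2 = np.random.randn(4), np.random.randn(6)
    rms.update(x_1)
    rms.update(x_2)

    x_cat = np.concatenate([x_1, x_2])
    assert np.allclose(rms.mean, x_cat.mean())    # merged moments match the concatenated batch
    assert np.allclose(rms.var, x_cat.var())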
7  torchy_baselines/common/vec_env/__init__.py  Normal file
@@ -0,0 +1,7 @@
# flake8: noqa F401
from torchy_baselines.common.vec_env.base_vec_env import AlreadySteppingError, NotSteppingError, VecEnv, VecEnvWrapper, \
    CloudpickleWrapper
from torchy_baselines.common.vec_env.dummy_vec_env import DummyVecEnv
from torchy_baselines.common.vec_env.subproc_vec_env import SubprocVecEnv
from torchy_baselines.common.vec_env.vec_frame_stack import VecFrameStack
from torchy_baselines.common.vec_env.vec_normalize import VecNormalize
301  torchy_baselines/common/vec_env/base_vec_env.py  Normal file
@@ -0,0 +1,301 @@
from abc import ABCMeta, abstractmethod
import inspect
import pickle

import cloudpickle


class AlreadySteppingError(Exception):
    """
    Raised when an asynchronous step is running while
    step_async() is called again.
    """

    def __init__(self):
        msg = 'already running an async step'
        Exception.__init__(self, msg)


class NotSteppingError(Exception):
    """
    Raised when an asynchronous step is not running but
    step_wait() is called.
    """

    def __init__(self):
        msg = 'not running an async step'
        Exception.__init__(self, msg)


class VecEnv(object):
    """
    An abstract asynchronous, vectorized environment.

    :param num_envs: (int) the number of environments
    :param observation_space: (Gym Space) the observation space
    :param action_space: (Gym Space) the action space
    """
    metadata = {
        'render.modes': ['human', 'rgb_array']
    }

    __metaclass__ = ABCMeta

    def __init__(self, num_envs, observation_space, action_space):
        self.num_envs = num_envs
        self.observation_space = observation_space
        self.action_space = action_space

    @abstractmethod
    def reset(self):
        """
        Reset all the environments and return an array of
        observations, or a tuple of observation arrays.

        If step_async is still doing work, that work will
        be cancelled and step_wait() should not be called
        until step_async() is invoked again.

        :return: ([int] or [float]) observation
        """
        pass

    @abstractmethod
    def step_async(self, actions):
        """
        Tell all the environments to start taking a step
        with the given actions.
        Call step_wait() to get the results of the step.

        You should not call this if a step_async run is
        already pending.
        """
        pass

    @abstractmethod
    def step_wait(self):
        """
        Wait for the step taken with step_async().

        :return: ([int] or [float], [float], [bool], dict) observation, reward, done, information
        """
        pass

    @abstractmethod
    def close(self):
        """
        Clean up the environment's resources.
        """
        pass

    @abstractmethod
    def get_attr(self, attr_name, indices=None):
        """
        Return attribute from vectorized environment.

        :param attr_name: (str) The name of the attribute whose value to return
        :param indices: (list,int) Indices of envs to get attribute from
        :return: (list) List of values of 'attr_name' in all environments
        """
        pass

    @abstractmethod
    def set_attr(self, attr_name, value, indices=None):
        """
        Set attribute inside vectorized environments.

        :param attr_name: (str) The name of attribute to assign new value
        :param value: (obj) Value to assign to `attr_name`
        :param indices: (list,int) Indices of envs to assign value
        :return: (NoneType)
        """
        pass

    @abstractmethod
    def env_method(self, method_name, *method_args, indices=None, **method_kwargs):
        """
        Call instance methods of vectorized environments.

        :param method_name: (str) The name of the environment method to invoke.
        :param indices: (list,int) Indices of envs whose method to call
        :param method_args: (tuple) Any positional arguments to provide in the call
        :param method_kwargs: (dict) Any keyword arguments to provide in the call
        :return: (list) List of items returned by the environment's method call
        """
        pass

    def step(self, actions):
        """
        Step the environments with the given action

        :param actions: ([int] or [float]) the action
        :return: ([int] or [float], [float], [bool], dict) observation, reward, done, information
        """
        self.step_async(actions)
        return self.step_wait()

    def get_images(self):
        """
        Return RGB images from each environment
        """
        raise NotImplementedError

    def render(self, *args, **kwargs):
        """
        Gym environment rendering

        :param mode: (str) the rendering type
        """
        raise NotImplementedError()

    @property
    def unwrapped(self):
        if isinstance(self, VecEnvWrapper):
            return self.venv.unwrapped
        else:
            return self

    def getattr_depth_check(self, name, already_found):
        """Check if an attribute reference is being hidden in a recursive call to __getattr__

        :param name: (str) name of attribute to check for
        :param already_found: (bool) whether this attribute has already been found in a wrapper
        :return: (str or None) name of module whose attribute is being shadowed, if any.
        """
        if hasattr(self, name) and already_found:
            return "{0}.{1}".format(type(self).__module__, type(self).__name__)
        else:
            return None

    def _get_indices(self, indices):
        """
        Convert a flexibly-typed reference to environment indices to an implied list of indices.

        :param indices: (None,int,Iterable) refers to indices of envs.
        :return: (list) the implied list of indices.
        """
        if indices is None:
            indices = range(self.num_envs)
        elif isinstance(indices, int):
            indices = [indices]
        return indices


class VecEnvWrapper(VecEnv):
    """
    Vectorized environment base class

    :param venv: (VecEnv) the vectorized environment to wrap
    :param observation_space: (Gym Space) the observation space (can be None to load from venv)
    :param action_space: (Gym Space) the action space (can be None to load from venv)
    """

    def __init__(self, venv, observation_space=None, action_space=None):
        self.venv = venv
        VecEnv.__init__(self, num_envs=venv.num_envs, observation_space=observation_space or venv.observation_space,
                        action_space=action_space or venv.action_space)
        self.class_attributes = dict(inspect.getmembers(self.__class__))

    def step_async(self, actions):
        self.venv.step_async(actions)

    @abstractmethod
    def reset(self):
        pass

    @abstractmethod
    def step_wait(self):
        pass

    def close(self):
        return self.venv.close()

    def render(self, *args, **kwargs):
        return self.venv.render(*args, **kwargs)

    def get_images(self):
        return self.venv.get_images()

    def get_attr(self, attr_name, indices=None):
        return self.venv.get_attr(attr_name, indices)

    def set_attr(self, attr_name, value, indices=None):
        return self.venv.set_attr(attr_name, value, indices)

    def env_method(self, method_name, *method_args, indices=None, **method_kwargs):
        return self.venv.env_method(method_name, *method_args, indices=indices, **method_kwargs)

    def __getattr__(self, name):
        """Find attribute from wrapped venv(s) if this wrapper does not have it.
        Useful for accessing attributes from venvs which are wrapped with multiple wrappers
        which have unique attributes of interest.
        """
        blocked_class = self.getattr_depth_check(name, already_found=False)
        if blocked_class is not None:
            own_class = "{0}.{1}".format(type(self).__module__, type(self).__name__)
            format_str = ("Error: Recursive attribute lookup for {0} from {1} is "
                          "ambiguous and hides attribute from {2}")
            raise AttributeError(format_str.format(name, own_class, blocked_class))

        return self.getattr_recursive(name)

    def _get_all_attributes(self):
        """Get all (inherited) instance and class attributes

        :return: (dict<str, object>) all_attributes
        """
        all_attributes = self.__dict__.copy()
        all_attributes.update(self.class_attributes)
        return all_attributes

    def getattr_recursive(self, name):
        """Recursively check wrappers to find attribute.

        :param name: (str) name of attribute to look for
        :return: (object) attribute
        """
        all_attributes = self._get_all_attributes()
        if name in all_attributes:  # attribute is present in this wrapper
            attr = getattr(self, name)
        elif hasattr(self.venv, 'getattr_recursive'):
            # Attribute not present, child is wrapper. Call getattr_recursive rather than getattr
            # to avoid a duplicate call to getattr_depth_check.
            attr = self.venv.getattr_recursive(name)
        else:  # attribute not present, child is an unwrapped VecEnv
            attr = getattr(self.venv, name)

        return attr

    def getattr_depth_check(self, name, already_found):
        """See base class.

        :return: (str or None) name of module whose attribute is being shadowed, if any.
        """
        all_attributes = self._get_all_attributes()
        if name in all_attributes and already_found:
            # this venv's attribute is being hidden because of a higher venv.
            shadowed_wrapper_class = "{0}.{1}".format(type(self).__module__, type(self).__name__)
        elif name in all_attributes and not already_found:
            # we have found the first reference to the attribute. Now check for duplicates.
            shadowed_wrapper_class = self.venv.getattr_depth_check(name, True)
        else:
            # this wrapper does not have the attribute. Keep searching.
            shadowed_wrapper_class = self.venv.getattr_depth_check(name, already_found)

        return shadowed_wrapper_class


class CloudpickleWrapper(object):
    def __init__(self, var):
        """
        Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle)

        :param var: (Any) the variable you wish to wrap for pickling with cloudpickle
        """
        self.var = var

    def __getstate__(self):
        return cloudpickle.dumps(self.var)

    def __setstate__(self, obs):
        self.var = pickle.loads(obs)
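A minimal sketch of the recursive attribute lookup and shadowing check implemented by __getattr__/getattr_depth_check above. The NamedWrapper class and the 'CartPole-v1' env id are illustrative only, mirroring what the test file exercises:

    import gym

    from torchy_baselines.common.vec_env import DummyVecEnv, VecNormalize


    class NamedWrapper(VecNormalize):
        def __init__(self, venv):
            VecNormalize.__init__(self, venv)
            self.tag = 'named'


    def make_env():
        return gym.make('CartPole-v1')


    venv = VecNormalize(NamedWrapper(NamedWrapper(DummyVecEnv([make_env]))))
    print(venv.envs)      # not defined on any wrapper: getattr_recursive reaches the DummyVecEnv
    try:
        venv.tag          # defined on two inner wrappers: getattr_depth_check flags the ambiguity
    except AttributeError as error:
        print(error)      # the message names the wrapper whose attribute is shadowed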
107  torchy_baselines/common/vec_env/dummy_vec_env.py  Normal file
@@ -0,0 +1,107 @@
from collections import OrderedDict
import numpy as np

from torchy_baselines.common.vec_env import VecEnv
from torchy_baselines.common.vec_env.util import copy_obs_dict, dict_to_obs, obs_space_info


class DummyVecEnv(VecEnv):
    """
    Creates a simple vectorized wrapper for multiple environments, calling each environment in sequence on the current
    Python process. This is useful for computationally simple environments such as ``cartpole-v1``, as the overhead of
    multiprocessing or multithreading outweighs the environment computation time. This can also be used for RL methods
    that require a vectorized environment, but where you want a single environment to train with.

    :param env_fns: ([Gym Environment]) the list of environments to vectorize
    """

    def __init__(self, env_fns):
        self.envs = [fn() for fn in env_fns]
        env = self.envs[0]
        VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space)
        obs_space = env.observation_space
        self.keys, shapes, dtypes = obs_space_info(obs_space)

        self.buf_obs = OrderedDict([
            (k, np.zeros((self.num_envs,) + tuple(shapes[k]), dtype=dtypes[k]))
            for k in self.keys])
        self.buf_dones = np.zeros((self.num_envs,), dtype=np.bool)
        self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32)
        self.buf_infos = [{} for _ in range(self.num_envs)]
        self.actions = None
        self.metadata = env.metadata

    def step_async(self, actions):
        self.actions = actions

    def step_wait(self):
        for env_idx in range(self.num_envs):
            obs, self.buf_rews[env_idx], self.buf_dones[env_idx], self.buf_infos[env_idx] =\
                self.envs[env_idx].step(self.actions[env_idx])
            if self.buf_dones[env_idx]:
                # save final observation where user can get it, then reset
                self.buf_infos[env_idx]['terminal_observation'] = obs
                obs = self.envs[env_idx].reset()
            self._save_obs(env_idx, obs)
        return (self._obs_from_buf(), np.copy(self.buf_rews), np.copy(self.buf_dones),
                self.buf_infos.copy())

    def reset(self):
        for env_idx in range(self.num_envs):
            obs = self.envs[env_idx].reset()
            self._save_obs(env_idx, obs)
        return self._obs_from_buf()

    def seed(self, seed, indices=None):
        """
        :param seed: (int or [int])
        :param indices: ([int])
        """
        indices = self._get_indices(indices)
        if not hasattr(seed, '__len__'):
            seed = [seed] * len(indices)
        assert len(seed) == len(indices)
        return [self.envs[i].seed(seed[i]) for i in indices]

    def close(self):
        for env in self.envs:
            env.close()

    def get_images(self):
        return [env.render(mode='rgb_array') for env in self.envs]

    def render(self, *args, **kwargs):
        if self.num_envs == 1:
            return self.envs[0].render(*args, **kwargs)
        else:
            return super().render(*args, **kwargs)

    def _save_obs(self, env_idx, obs):
        for key in self.keys:
            if key is None:
                self.buf_obs[key][env_idx] = obs
            else:
                self.buf_obs[key][env_idx] = obs[key]

    def _obs_from_buf(self):
        return dict_to_obs(self.observation_space, copy_obs_dict(self.buf_obs))

    def get_attr(self, attr_name, indices=None):
        """Return attribute from vectorized environment (see base class)."""
        target_envs = self._get_target_envs(indices)
        return [getattr(env_i, attr_name) for env_i in target_envs]

    def set_attr(self, attr_name, value, indices=None):
        """Set attribute inside vectorized environments (see base class)."""
        target_envs = self._get_target_envs(indices)
        for env_i in target_envs:
            setattr(env_i, attr_name, value)

    def env_method(self, method_name, *method_args, indices=None, **method_kwargs):
        """Call instance methods of vectorized environments."""
        target_envs = self._get_target_envs(indices)
        return [getattr(env_i, method_name)(*method_args, **method_kwargs) for env_i in target_envs]

    def _get_target_envs(self, indices):
        indices = self._get_indices(indices)
        return [self.envs[i] for i in indices]
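A short usage sketch of the auto-reset behaviour in step_wait above: when an environment reports done, its last observation is stashed under 'terminal_observation' and the buffer already contains the observation of the fresh episode. The 'CartPole-v1' env id is illustrative:

    import gym
    import numpy as np

    from torchy_baselines.common.vec_env import DummyVecEnv

    vec_env = DummyVecEnv([lambda: gym.make('CartPole-v1') for _ in range(2)])
    obs = vec_env.reset()                          # shape: (num_envs,) + single-env obs shape
    for _ in range(200):
        actions = np.array([vec_env.action_space.sample() for _ in range(vec_env.num_envs)])
        obs, rewards, dones, infos = vec_env.step(actions)
        for done, info in zip(dones, infos):
            if done:                               # env was auto-reset; last obs kept in the info dict
                last_obs = info['terminal_observation']
    vec_env.close()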
232  torchy_baselines/common/vec_env/subproc_vec_env.py  Normal file
@@ -0,0 +1,232 @@
import multiprocessing
from collections import OrderedDict

import gym
import numpy as np

from torchy_baselines.common.vec_env import VecEnv, CloudpickleWrapper


def _worker(remote, parent_remote, env_fn_wrapper):
    parent_remote.close()
    env = env_fn_wrapper.var()
    while True:
        try:
            cmd, data = remote.recv()
            if cmd == 'step':
                observation, reward, done, info = env.step(data)
                if done:
                    # save final observation where user can get it, then reset
                    info['terminal_observation'] = observation
                    observation = env.reset()
                remote.send((observation, reward, done, info))
            elif cmd == 'reset':
                observation = env.reset()
                remote.send(observation)
            elif cmd == 'render':
                remote.send(env.render(*data[0], **data[1]))
            elif cmd == 'close':
                remote.close()
                break
            elif cmd == 'get_spaces':
                remote.send((env.observation_space, env.action_space))
            elif cmd == 'env_method':
                method = getattr(env, data[0])
                remote.send(method(*data[1], **data[2]))
            elif cmd == 'get_attr':
                remote.send(getattr(env, data))
            elif cmd == 'set_attr':
                remote.send(setattr(env, data[0], data[1]))
            else:
                raise NotImplementedError
        except EOFError:
            break


def tile_images(img_nhwc):
    """
    Tile N images into one big PxQ image
    (P,Q) are chosen to be as close as possible, and if N
    is square, then P=Q.

    :param img_nhwc: (list) list or array of images, ndim=4 once turned into array. img nhwc
        n = batch index, h = height, w = width, c = channel
    :return: (numpy float) img_HWc, ndim=3
    """
    img_nhwc = np.asarray(img_nhwc)
    n_images, height, width, n_channels = img_nhwc.shape
    # new_height was named H before
    new_height = int(np.ceil(np.sqrt(n_images)))
    # new_width was named W before
    new_width = int(np.ceil(float(n_images) / new_height))
    img_nhwc = np.array(list(img_nhwc) + [img_nhwc[0] * 0 for _ in range(n_images, new_height * new_width)])
    # img_HWhwc
    out_image = img_nhwc.reshape(new_height, new_width, height, width, n_channels)
    # img_HhWwc
    out_image = out_image.transpose(0, 2, 1, 3, 4)
    # img_Hh_Ww_c
    out_image = out_image.reshape(new_height * height, new_width * width, n_channels)
    return out_image


class SubprocVecEnv(VecEnv):
    """
    Creates a multiprocess vectorized wrapper for multiple environments, distributing each environment to its own
    process, allowing significant speed up when the environment is computationally complex.

    For performance reasons, if your environment is not IO bound, the number of environments should not exceed the
    number of logical cores on your CPU.

    .. warning::

        Only 'forkserver' and 'spawn' start methods are thread-safe,
        which is important when TensorFlow sessions or other non thread-safe
        libraries are used in the parent (see issue #217). However, compared to
        'fork' they incur a small start-up cost and have restrictions on
        global variables. With those methods, users must wrap the code in an
        ``if __name__ == "__main__":`` block.
        For more information, see the multiprocessing documentation.

    :param env_fns: ([Gym Environment]) Environments to run in subprocesses
    :param start_method: (str) method used to start the subprocesses.
        Must be one of the methods returned by multiprocessing.get_all_start_methods().
        Defaults to 'forkserver' on available platforms, and 'spawn' otherwise.
    """

    def __init__(self, env_fns, start_method=None):
        self.waiting = False
        self.closed = False
        n_envs = len(env_fns)

        if start_method is None:
            # Fork is not a thread safe method (see issue #217)
            # but is more user friendly (does not require to wrap the code in
            # a `if __name__ == "__main__":`)
            forkserver_available = 'forkserver' in multiprocessing.get_all_start_methods()
            start_method = 'forkserver' if forkserver_available else 'spawn'
        ctx = multiprocessing.get_context(start_method)

        self.remotes, self.work_remotes = zip(*[ctx.Pipe() for _ in range(n_envs)])
        self.processes = []
        for work_remote, remote, env_fn in zip(self.work_remotes, self.remotes, env_fns):
            args = (work_remote, remote, CloudpickleWrapper(env_fn))
            # daemon=True: if the main process crashes, we should not cause things to hang
            process = ctx.Process(target=_worker, args=args, daemon=True)
            process.start()
            self.processes.append(process)
            work_remote.close()

        self.remotes[0].send(('get_spaces', None))
        observation_space, action_space = self.remotes[0].recv()
        VecEnv.__init__(self, len(env_fns), observation_space, action_space)

    def step_async(self, actions):
        for remote, action in zip(self.remotes, actions):
            remote.send(('step', action))
        self.waiting = True

    def step_wait(self):
        results = [remote.recv() for remote in self.remotes]
        self.waiting = False
        obs, rews, dones, infos = zip(*results)
        return _flatten_obs(obs, self.observation_space), np.stack(rews), np.stack(dones), infos

    def reset(self):
        for remote in self.remotes:
            remote.send(('reset', None))
        obs = [remote.recv() for remote in self.remotes]
        return _flatten_obs(obs, self.observation_space)

    def close(self):
        if self.closed:
            return
        if self.waiting:
            for remote in self.remotes:
                remote.recv()
        for remote in self.remotes:
            remote.send(('close', None))
        for process in self.processes:
            process.join()
        self.closed = True

    def render(self, mode='human', *args, **kwargs):
        for pipe in self.remotes:
            # gather images from subprocesses
            # `mode` will be taken into account later
            pipe.send(('render', (args, {'mode': 'rgb_array', **kwargs})))
        imgs = [pipe.recv() for pipe in self.remotes]
        # Create a big image by tiling images from subprocesses
        bigimg = tile_images(imgs)
        if mode == 'human':
            import cv2
            cv2.imshow('vecenv', bigimg[:, :, ::-1])
            cv2.waitKey(1)
        elif mode == 'rgb_array':
            return bigimg
        else:
            raise NotImplementedError

    def get_images(self):
        for pipe in self.remotes:
            pipe.send(('render', ((), {'mode': 'rgb_array'})))
        imgs = [pipe.recv() for pipe in self.remotes]
        return imgs

    def get_attr(self, attr_name, indices=None):
        """Return attribute from vectorized environment (see base class)."""
        target_remotes = self._get_target_remotes(indices)
        for remote in target_remotes:
            remote.send(('get_attr', attr_name))
        return [remote.recv() for remote in target_remotes]

    def set_attr(self, attr_name, value, indices=None):
        """Set attribute inside vectorized environments (see base class)."""
        target_remotes = self._get_target_remotes(indices)
        for remote in target_remotes:
            remote.send(('set_attr', (attr_name, value)))
        for remote in target_remotes:
            remote.recv()

    def env_method(self, method_name, *method_args, indices=None, **method_kwargs):
        """Call instance methods of vectorized environments."""
        target_remotes = self._get_target_remotes(indices)
        for remote in target_remotes:
            remote.send(('env_method', (method_name, method_args, method_kwargs)))
        return [remote.recv() for remote in target_remotes]

    def _get_target_remotes(self, indices):
        """
        Get the connection object needed to communicate with the wanted
        envs that are in subprocesses.

        :param indices: (None,int,Iterable) refers to indices of envs.
        :return: ([multiprocessing.Connection]) Connection object to communicate between processes.
        """
        indices = self._get_indices(indices)
        return [self.remotes[i] for i in indices]


def _flatten_obs(obs, space):
    """
    Flatten observations, depending on the observation space.

    :param obs: (list<X> or tuple<X> where X is dict<ndarray>, tuple<ndarray> or ndarray) observations.
        A list or tuple of observations, one per environment.
        Each environment observation may be a NumPy array, or a dict or tuple of NumPy arrays.
    :return (OrderedDict<ndarray>, tuple<ndarray> or ndarray) flattened observations.
        A flattened NumPy array or an OrderedDict or tuple of flattened numpy arrays.
        Each NumPy array has the environment index as its first axis.
    """
    assert isinstance(obs, (list, tuple)), "expected list or tuple of observations per environment"
    assert len(obs) > 0, "need observations from at least one environment"

    if isinstance(space, gym.spaces.Dict):
        assert isinstance(space.spaces, OrderedDict), "Dict space must have ordered subspaces"
        assert isinstance(obs[0], dict), "non-dict observation for environment with Dict observation space"
        return OrderedDict([(k, np.stack([o[k] for o in obs])) for k in space.spaces.keys()])
    elif isinstance(space, gym.spaces.Tuple):
        assert isinstance(obs[0], tuple), "non-tuple observation for environment with Tuple observation space"
        obs_len = len(space.spaces)
        return tuple((np.stack([o[i] for o in obs]) for i in range(obs_len)))
    else:
        return np.stack(obs)
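A minimal usage sketch of the start-method warning in the SubprocVecEnv docstring above: with 'spawn' or 'forkserver' the construction must live under an ``if __name__ == "__main__":`` guard. The 'Pendulum-v0' env id is illustrative:

    import gym

    from torchy_baselines.common.vec_env import SubprocVecEnv


    def make_env():
        return gym.make('Pendulum-v0')


    if __name__ == '__main__':                     # required with the 'spawn'/'forkserver' start methods
        vec_env = SubprocVecEnv([make_env for _ in range(4)], start_method='spawn')
        obs = vec_env.reset()                      # each env runs in its own process
        print(vec_env.get_attr('action_space'))    # attribute fetched through the worker pipes
        vec_env.close()                            # sends 'close' to every worker and joins them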
73  torchy_baselines/common/vec_env/util.py  Normal file
@@ -0,0 +1,73 @@
"""
Helpers for dealing with vectorized environments.
"""

from collections import OrderedDict

import gym
import numpy as np


def copy_obs_dict(obs):
    """
    Deep-copy a dict of numpy arrays.

    :param obs: (OrderedDict<ndarray>): a dict of numpy arrays.
    :return (OrderedDict<ndarray>) a dict of copied numpy arrays.
    """
    assert isinstance(obs, OrderedDict), "unexpected type for observations '{}'".format(type(obs))
    return OrderedDict([(k, np.copy(v)) for k, v in obs.items()])


def dict_to_obs(space, obs_dict):
    """
    Convert an internal representation raw_obs into the appropriate type
    specified by space.

    :param space: (gym.spaces.Space) an observation space.
    :param obs_dict: (OrderedDict<ndarray>) a dict of numpy arrays.
    :return (ndarray, tuple<ndarray> or dict<ndarray>): returns an observation
        of the same type as space. If space is Dict, function is identity;
        if space is Tuple, converts dict to Tuple; otherwise, space is
        unstructured and returns the value raw_obs[None].
    """
    if isinstance(space, gym.spaces.Dict):
        return obs_dict
    elif isinstance(space, gym.spaces.Tuple):
        assert len(obs_dict) == len(space.spaces), "size of observation does not match size of observation space"
        return tuple((obs_dict[i] for i in range(len(space.spaces))))
    else:
        assert set(obs_dict.keys()) == {None}, "multiple observation keys for unstructured observation space"
        return obs_dict[None]


def obs_space_info(obs_space):
    """
    Get dict-structured information about a gym.Space.

    Dict spaces are represented directly by their dict of subspaces.
    Tuple spaces are converted into a dict with keys indexing into the tuple.
    Unstructured spaces are represented by {None: obs_space}.

    :param obs_space: (gym.spaces.Space) an observation space
    :return (tuple) A tuple (keys, shapes, dtypes):
        keys: a list of dict keys.
        shapes: a dict mapping keys to shapes.
        dtypes: a dict mapping keys to dtypes.
    """
    if isinstance(obs_space, gym.spaces.Dict):
        assert isinstance(obs_space.spaces, OrderedDict), "Dict space must have ordered subspaces"
        subspaces = obs_space.spaces
    elif isinstance(obs_space, gym.spaces.Tuple):
        subspaces = {i: space for i, space in enumerate(obs_space.spaces)}
    else:
        assert not hasattr(obs_space, 'spaces'), "Unsupported structured space '{}'".format(type(obs_space))
        subspaces = {None: obs_space}
    keys = []
    shapes = {}
    dtypes = {}
    for key, box in subspaces.items():
        keys.append(key)
        shapes[key] = box.shape
        dtypes[key] = box.dtype
    return keys, shapes, dtypes
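A small sketch of the key layout produced by obs_space_info above, assuming the function is imported from this module; the space names are illustrative:

    from collections import OrderedDict

    import gym
    import numpy as np

    box = gym.spaces.Box(low=np.zeros(2), high=np.ones(2))
    dict_space = gym.spaces.Dict(OrderedDict([('pos', box), ('cmd', gym.spaces.Discrete(3))]))

    keys, shapes, dtypes = obs_space_info(box)         # unstructured: keys == [None], shapes[None] == (2,)
    keys, shapes, dtypes = obs_space_info(dict_space)  # Dict: keys == ['pos', 'cmd'], shapes['cmd'] == ()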
55  torchy_baselines/common/vec_env/vec_frame_stack.py  Normal file
@@ -0,0 +1,55 @@
import warnings

import numpy as np
from gym import spaces

from torchy_baselines.common.vec_env import VecEnvWrapper


class VecFrameStack(VecEnvWrapper):
    """
    Frame stacking wrapper for vectorized environment

    :param venv: (VecEnv) the vectorized environment to wrap
    :param n_stack: (int) Number of frames to stack
    """

    def __init__(self, venv, n_stack):
        self.venv = venv
        self.n_stack = n_stack
        wrapped_obs_space = venv.observation_space
        low = np.repeat(wrapped_obs_space.low, self.n_stack, axis=-1)
        high = np.repeat(wrapped_obs_space.high, self.n_stack, axis=-1)
        self.stackedobs = np.zeros((venv.num_envs,) + low.shape, low.dtype)
        observation_space = spaces.Box(low=low, high=high, dtype=venv.observation_space.dtype)
        VecEnvWrapper.__init__(self, venv, observation_space=observation_space)

    def step_wait(self):
        observations, rewards, dones, infos = self.venv.step_wait()
        last_ax_size = observations.shape[-1]
        self.stackedobs = np.roll(self.stackedobs, shift=-last_ax_size, axis=-1)
        for i, done in enumerate(dones):
            if done:
                if 'terminal_observation' in infos[i]:
                    old_terminal = infos[i]['terminal_observation']
                    new_terminal = np.concatenate(
                        (self.stackedobs[i, ..., :-last_ax_size], old_terminal), axis=-1)
                    infos[i]['terminal_observation'] = new_terminal
                else:
                    warnings.warn(
                        "VecFrameStack wrapping a VecEnv without terminal_observation info")
                self.stackedobs[i] = 0
        self.stackedobs[..., -observations.shape[-1]:] = observations
        return self.stackedobs, rewards, dones, infos

    def reset(self):
        """
        Reset all environments
        """
        obs = self.venv.reset()
        self.stackedobs[...] = 0
        self.stackedobs[..., -obs.shape[-1]:] = obs
        return self.stackedobs

    def close(self):
        self.venv.close()
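The roll-and-overwrite logic in step_wait above keeps the newest frame in the last slot of the stacking axis. A tiny standalone sketch of the same array manipulation (one env, obs_dim=3, n_stack=2):

    import numpy as np

    stackedobs = np.zeros((1, 6))                          # (num_envs, obs_dim * n_stack)
    new_obs = np.array([[1., 2., 3.]])
    stackedobs = np.roll(stackedobs, shift=-3, axis=-1)    # shift out the oldest frame
    stackedobs[..., -3:] = new_obs                         # write the newest frame into the last slot
    # stackedobs is now [[0., 0., 0., 1., 2., 3.]]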
105  torchy_baselines/common/vec_env/vec_normalize.py  Normal file
@@ -0,0 +1,105 @@
import pickle

import numpy as np

from torchy_baselines.common.vec_env import VecEnvWrapper
from torchy_baselines.common.running_mean_std import RunningMeanStd


class VecNormalize(VecEnvWrapper):
    """
    A moving average, normalizing wrapper for vectorized environment.
    It also supports saving/loading the moving averages.

    :param venv: (VecEnv) the vectorized environment to wrap
    :param training: (bool) Whether to update or not the moving average
    :param norm_obs: (bool) Whether to normalize observation or not (default: True)
    :param norm_reward: (bool) Whether to normalize rewards or not (default: True)
    :param clip_obs: (float) Max absolute value for observation
    :param clip_reward: (float) Max absolute value for discounted reward
    :param gamma: (float) discount factor
    :param epsilon: (float) To avoid division by zero
    """

    def __init__(self, venv, training=True, norm_obs=True, norm_reward=True,
                 clip_obs=10., clip_reward=10., gamma=0.99, epsilon=1e-8):
        VecEnvWrapper.__init__(self, venv)
        self.obs_rms = RunningMeanStd(shape=self.observation_space.shape)
        self.ret_rms = RunningMeanStd(shape=())
        self.clip_obs = clip_obs
        self.clip_reward = clip_reward
        # Returns: discounted rewards
        self.ret = np.zeros(self.num_envs)
        self.gamma = gamma
        self.epsilon = epsilon
        self.training = training
        self.norm_obs = norm_obs
        self.norm_reward = norm_reward
        self.old_obs = np.array([])

    def step_wait(self):
        """
        Apply sequence of actions to sequence of environments
        actions -> (observations, rewards, news)

        where 'news' is a boolean vector indicating whether each element is new.
        """
        obs, rews, news, infos = self.venv.step_wait()
        self.ret = self.ret * self.gamma + rews
        self.old_obs = obs
        obs = self._normalize_observation(obs)
        if self.norm_reward:
            if self.training:
                self.ret_rms.update(self.ret)
            rews = np.clip(rews / np.sqrt(self.ret_rms.var + self.epsilon), -self.clip_reward, self.clip_reward)
        self.ret[news] = 0
        return obs, rews, news, infos

    def _normalize_observation(self, obs):
        """
        :param obs: (numpy tensor)
        """
        if self.norm_obs:
            if self.training:
                self.obs_rms.update(obs)
            obs = np.clip((obs - self.obs_rms.mean) / np.sqrt(self.obs_rms.var + self.epsilon), -self.clip_obs,
                          self.clip_obs)
            return obs
        else:
            return obs

    def get_original_obs(self):
        """
        Returns the unnormalized observation

        :return: (numpy float)
        """
        return self.old_obs

    def reset(self):
        """
        Reset all environments
        """
        obs = self.venv.reset()
        if len(np.array(obs).shape) == 1:  # for when num_cpu is 1
            self.old_obs = [obs]
        else:
            self.old_obs = obs
        self.ret = np.zeros(self.num_envs)
        return self._normalize_observation(obs)

    def save_running_average(self, path):
        """
        :param path: (str) path to log dir
        """
        for rms, name in zip([self.obs_rms, self.ret_rms], ['obs_rms', 'ret_rms']):
            with open("{}/{}.pkl".format(path, name), 'wb') as file_handler:
                pickle.dump(rms, file_handler)

    def load_running_average(self, path):
        """
        :param path: (str) path to log dir
        """
        for name in ['obs_rms', 'ret_rms']:
            with open("{}/{}.pkl".format(path, name), 'rb') as file_handler:
                setattr(self, name, pickle.load(file_handler))
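A worked example of the normalization applied in step_wait and _normalize_observation above, with illustrative running statistics (the same clip-and-scale formula, outside the wrapper):

    import numpy as np

    epsilon, clip_obs, clip_reward = 1e-8, 10.0, 10.0

    obs = np.array([2.0, -3.0])
    obs_mean, obs_var = np.array([0.5, 0.5]), np.array([4.0, 4.0])
    norm_obs = np.clip((obs - obs_mean) / np.sqrt(obs_var + epsilon), -clip_obs, clip_obs)

    reward, ret_var = 6.0, 9.0                   # ret_var: running variance of the discounted return
    norm_reward = np.clip(reward / np.sqrt(ret_var + epsilon), -clip_reward, clip_reward)
    # norm_obs is approximately [0.75, -1.75]; norm_reward is approximately 2.0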