mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-16 21:10:08 +00:00
333 lines
11 KiB
Python
333 lines
11 KiB
Python
import collections
|
|
import functools
|
|
import itertools
|
|
import multiprocessing
|
|
|
|
import pytest
|
|
import gym
|
|
import numpy as np
|
|
|
|
from torchy_baselines.common.vec_env import DummyVecEnv, SubprocVecEnv, VecNormalize, VecFrameStack
|
|
|
|
N_ENVS = 3
|
|
VEC_ENV_CLASSES = [DummyVecEnv, SubprocVecEnv]
|
|
VEC_ENV_WRAPPERS = [None, VecNormalize, VecFrameStack]
|
|
|
|
|
|
class CustomGymEnv(gym.Env):
|
|
def __init__(self, space):
|
|
"""
|
|
Custom gym environment for testing purposes
|
|
"""
|
|
self.action_space = space
|
|
self.observation_space = space
|
|
self.current_step = 0
|
|
self.ep_length = 4
|
|
|
|
def reset(self):
|
|
self.current_step = 0
|
|
self._choose_next_state()
|
|
return self.state
|
|
|
|
def step(self, action):
|
|
reward = 1
|
|
self._choose_next_state()
|
|
self.current_step += 1
|
|
done = self.current_step >= self.ep_length
|
|
return self.state, reward, done, {}
|
|
|
|
def _choose_next_state(self):
|
|
self.state = self.observation_space.sample()
|
|
|
|
def render(self, mode='human'):
|
|
pass
|
|
|
|
@staticmethod
|
|
def custom_method(dim_0=1, dim_1=1):
|
|
"""
|
|
Dummy method to test call to custom method
|
|
from VecEnv
|
|
|
|
:param dim_0: (int)
|
|
:param dim_1: (int)
|
|
:return: (np.ndarray)
|
|
"""
|
|
return np.ones((dim_0, dim_1))
|
|
|
|
|
|
@pytest.mark.parametrize('vec_env_class', VEC_ENV_CLASSES)
|
|
@pytest.mark.parametrize('vec_env_wrapper', VEC_ENV_WRAPPERS)
|
|
def test_vecenv_custom_calls(vec_env_class, vec_env_wrapper):
|
|
"""Test access to methods/attributes of vectorized environments"""
|
|
|
|
def make_env():
|
|
return CustomGymEnv(gym.spaces.Box(low=np.zeros(2), high=np.ones(2)))
|
|
|
|
vec_env = vec_env_class([make_env for _ in range(N_ENVS)])
|
|
|
|
if vec_env_wrapper is not None:
|
|
if vec_env_wrapper == VecFrameStack:
|
|
vec_env = vec_env_wrapper(vec_env, n_stack=2)
|
|
else:
|
|
vec_env = vec_env_wrapper(vec_env)
|
|
|
|
env_method_results = vec_env.env_method('custom_method', 1, indices=None, dim_1=2)
|
|
setattr_results = []
|
|
# Set current_step to an arbitrary value
|
|
for env_idx in range(N_ENVS):
|
|
setattr_results.append(vec_env.set_attr('current_step', env_idx, indices=env_idx))
|
|
# Retrieve the value for each environment
|
|
getattr_results = vec_env.get_attr('current_step')
|
|
|
|
assert len(env_method_results) == N_ENVS
|
|
assert len(setattr_results) == N_ENVS
|
|
assert len(getattr_results) == N_ENVS
|
|
|
|
for env_idx in range(N_ENVS):
|
|
assert (env_method_results[env_idx] == np.ones((1, 2))).all()
|
|
assert setattr_results[env_idx] is None
|
|
assert getattr_results[env_idx] == env_idx
|
|
|
|
# Call env_method on a subset of the VecEnv
|
|
env_method_subset = vec_env.env_method('custom_method', 1, indices=[0, 2], dim_1=3)
|
|
assert (env_method_subset[0] == np.ones((1, 3))).all()
|
|
assert (env_method_subset[1] == np.ones((1, 3))).all()
|
|
assert len(env_method_subset) == 2
|
|
|
|
# Test to change value for all the environments
|
|
setattr_result = vec_env.set_attr('current_step', 42, indices=None)
|
|
getattr_result = vec_env.get_attr('current_step')
|
|
assert setattr_result is None
|
|
assert getattr_result == [42 for _ in range(N_ENVS)]
|
|
|
|
# Additional tests for setattr that does not affect all the environments
|
|
vec_env.reset()
|
|
setattr_result = vec_env.set_attr('current_step', 12, indices=[0, 1])
|
|
getattr_result = vec_env.get_attr('current_step')
|
|
getattr_result_subset = vec_env.get_attr('current_step', indices=[0, 1])
|
|
assert setattr_result is None
|
|
assert getattr_result == [12 for _ in range(2)] + [0 for _ in range(N_ENVS - 2)]
|
|
assert getattr_result_subset == [12, 12]
|
|
assert vec_env.get_attr('current_step', indices=[0, 2]) == [12, 0]
|
|
|
|
vec_env.reset()
|
|
# Change value only for first and last environment
|
|
setattr_result = vec_env.set_attr('current_step', 12, indices=[0, -1])
|
|
getattr_result = vec_env.get_attr('current_step')
|
|
assert setattr_result is None
|
|
assert getattr_result == [12] + [0 for _ in range(N_ENVS - 2)] + [12]
|
|
assert vec_env.get_attr('current_step', indices=[-1]) == [12]
|
|
|
|
vec_env.close()
|
|
|
|
|
|
class StepEnv(gym.Env):
|
|
def __init__(self, max_steps):
|
|
"""Gym environment for testing that terminal observation is inserted
|
|
correctly."""
|
|
self.action_space = gym.spaces.Discrete(2)
|
|
self.observation_space = gym.spaces.Box(np.array([0]), np.array([999]),
|
|
dtype='int')
|
|
self.max_steps = max_steps
|
|
self.current_step = 0
|
|
|
|
def reset(self):
|
|
self.current_step = 0
|
|
return np.array([self.current_step], dtype='int')
|
|
|
|
def step(self, action):
|
|
prev_step = self.current_step
|
|
self.current_step += 1
|
|
done = self.current_step >= self.max_steps
|
|
return np.array([prev_step], dtype='int'), 0.0, done, {}
|
|
|
|
|
|
@pytest.mark.parametrize('vec_env_class', VEC_ENV_CLASSES)
|
|
@pytest.mark.parametrize('vec_env_wrapper', VEC_ENV_WRAPPERS)
|
|
def test_vecenv_terminal_obs(vec_env_class, vec_env_wrapper):
|
|
"""Test that 'terminal_observation' gets added to info dict upon
|
|
termination."""
|
|
step_nums = [i + 5 for i in range(N_ENVS)]
|
|
vec_env = vec_env_class([functools.partial(StepEnv, n) for n in step_nums])
|
|
|
|
if vec_env_wrapper is not None:
|
|
if vec_env_wrapper == VecFrameStack:
|
|
vec_env = vec_env_wrapper(vec_env, n_stack=2)
|
|
else:
|
|
vec_env = vec_env_wrapper(vec_env)
|
|
|
|
zero_acts = np.zeros((N_ENVS,), dtype='int')
|
|
prev_obs_b = vec_env.reset()
|
|
for step_num in range(1, max(step_nums) + 1):
|
|
obs_b, _, done_b, info_b = vec_env.step(zero_acts)
|
|
assert len(obs_b) == N_ENVS
|
|
assert len(done_b) == N_ENVS
|
|
assert len(info_b) == N_ENVS
|
|
env_iter = zip(prev_obs_b, obs_b, done_b, info_b, step_nums)
|
|
for prev_obs, obs, done, info, final_step_num in env_iter:
|
|
assert done == (step_num == final_step_num)
|
|
if not done:
|
|
assert 'terminal_observation' not in info
|
|
else:
|
|
terminal_obs = info['terminal_observation']
|
|
|
|
# do some rough ordering checks that should work for all
|
|
# wrappers, including VecNormalize
|
|
assert np.all(prev_obs < terminal_obs)
|
|
assert np.all(obs < prev_obs)
|
|
|
|
if not isinstance(vec_env, VecNormalize):
|
|
# more precise tests that we can't do with VecNormalize
|
|
# (which changes observation values)
|
|
assert np.all(prev_obs + 1 == terminal_obs)
|
|
assert np.all(obs == 0)
|
|
|
|
prev_obs_b = obs_b
|
|
|
|
vec_env.close()
|
|
|
|
|
|
SPACES = collections.OrderedDict([
|
|
('discrete', gym.spaces.Discrete(2)),
|
|
('multidiscrete', gym.spaces.MultiDiscrete([2, 3])),
|
|
('multibinary', gym.spaces.MultiBinary(3)),
|
|
('continuous', gym.spaces.Box(low=np.zeros(2), high=np.ones(2))),
|
|
])
|
|
|
|
|
|
def check_vecenv_spaces(vec_env_class, space, obs_assert):
|
|
"""Helper method to check observation spaces in vectorized environments."""
|
|
|
|
def make_env():
|
|
return CustomGymEnv(space)
|
|
|
|
vec_env = vec_env_class([make_env for _ in range(N_ENVS)])
|
|
obs = vec_env.reset()
|
|
obs_assert(obs)
|
|
|
|
dones = [False] * N_ENVS
|
|
while not any(dones):
|
|
actions = [vec_env.action_space.sample() for _ in range(N_ENVS)]
|
|
obs, _rews, dones, _infos = vec_env.step(actions)
|
|
obs_assert(obs)
|
|
vec_env.close()
|
|
|
|
|
|
def check_vecenv_obs(obs, space):
|
|
"""Helper method to check observations from multiple environments each belong to
|
|
the appropriate observation space."""
|
|
assert obs.shape[0] == N_ENVS
|
|
for value in obs:
|
|
assert space.contains(value)
|
|
|
|
|
|
@pytest.mark.parametrize('vec_env_class,space', itertools.product(VEC_ENV_CLASSES, SPACES.values()))
|
|
def test_vecenv_single_space(vec_env_class, space):
|
|
def obs_assert(obs):
|
|
return check_vecenv_obs(obs, space)
|
|
|
|
check_vecenv_spaces(vec_env_class, space, obs_assert)
|
|
|
|
|
|
class _UnorderedDictSpace(gym.spaces.Dict):
|
|
"""Like DictSpace, but returns an unordered dict when sampling."""
|
|
|
|
def sample(self):
|
|
return dict(super().sample())
|
|
|
|
|
|
@pytest.mark.parametrize('vec_env_class', VEC_ENV_CLASSES)
|
|
def test_vecenv_dict_spaces(vec_env_class):
|
|
"""Test dictionary observation spaces with vectorized environments."""
|
|
space = gym.spaces.Dict(SPACES)
|
|
|
|
def obs_assert(obs):
|
|
assert isinstance(obs, collections.OrderedDict)
|
|
assert obs.keys() == space.spaces.keys()
|
|
for key, values in obs.items():
|
|
check_vecenv_obs(values, space.spaces[key])
|
|
|
|
check_vecenv_spaces(vec_env_class, space, obs_assert)
|
|
|
|
unordered_space = _UnorderedDictSpace(SPACES)
|
|
# Check that vec_env_class can accept unordered dict observations (and convert to OrderedDict)
|
|
check_vecenv_spaces(vec_env_class, unordered_space, obs_assert)
|
|
|
|
|
|
@pytest.mark.parametrize('vec_env_class', VEC_ENV_CLASSES)
|
|
def test_vecenv_tuple_spaces(vec_env_class):
|
|
"""Test tuple observation spaces with vectorized environments."""
|
|
space = gym.spaces.Tuple(tuple(SPACES.values()))
|
|
|
|
def obs_assert(obs):
|
|
assert isinstance(obs, tuple)
|
|
assert len(obs) == len(space.spaces)
|
|
for values, inner_space in zip(obs, space.spaces):
|
|
check_vecenv_obs(values, inner_space)
|
|
|
|
return check_vecenv_spaces(vec_env_class, space, obs_assert)
|
|
|
|
|
|
def test_subproc_start_method():
|
|
start_methods = [None]
|
|
# Only test thread-safe methods. Others may deadlock tests! (gh/428)
|
|
safe_methods = {'forkserver', 'spawn'}
|
|
available_methods = multiprocessing.get_all_start_methods()
|
|
start_methods += list(safe_methods.intersection(available_methods))
|
|
space = gym.spaces.Discrete(2)
|
|
|
|
def obs_assert(obs):
|
|
return check_vecenv_obs(obs, space)
|
|
|
|
for start_method in start_methods:
|
|
vec_env_class = functools.partial(SubprocVecEnv, start_method=start_method)
|
|
check_vecenv_spaces(vec_env_class, space, obs_assert)
|
|
|
|
with pytest.raises(ValueError, match="cannot find context for 'illegal_method'"):
|
|
vec_env_class = functools.partial(SubprocVecEnv, start_method='illegal_method')
|
|
check_vecenv_spaces(vec_env_class, space, obs_assert)
|
|
|
|
|
|
class CustomWrapperA(VecNormalize):
|
|
def __init__(self, venv):
|
|
VecNormalize.__init__(self, venv)
|
|
self.var_a = 'a'
|
|
|
|
|
|
class CustomWrapperB(VecNormalize):
|
|
def __init__(self, venv):
|
|
VecNormalize.__init__(self, venv)
|
|
self.var_b = 'b'
|
|
|
|
def func_b(self):
|
|
return self.var_b
|
|
|
|
def name_test(self):
|
|
return self.__class__
|
|
|
|
|
|
class CustomWrapperBB(CustomWrapperB):
|
|
def __init__(self, venv):
|
|
CustomWrapperB.__init__(self, venv)
|
|
self.var_bb = 'bb'
|
|
|
|
|
|
def test_vecenv_wrapper_getattr():
|
|
def make_env():
|
|
return CustomGymEnv(gym.spaces.Box(low=np.zeros(2), high=np.ones(2)))
|
|
|
|
vec_env = DummyVecEnv([make_env for _ in range(N_ENVS)])
|
|
wrapped = CustomWrapperA(CustomWrapperBB(vec_env))
|
|
assert wrapped.var_a == 'a'
|
|
assert wrapped.var_b == 'b'
|
|
assert wrapped.var_bb == 'bb'
|
|
assert wrapped.func_b() == 'b'
|
|
assert wrapped.name_test() == CustomWrapperBB
|
|
|
|
double_wrapped = CustomWrapperA(CustomWrapperB(wrapped))
|
|
dummy = double_wrapped.var_a # should not raise as it is directly defined here
|
|
with pytest.raises(AttributeError): # should raise due to ambiguity
|
|
dummy = double_wrapped.var_b
|
|
with pytest.raises(AttributeError): # should raise as does not exist
|
|
dummy = double_wrapped.nonexistent_attribute
|
|
del dummy # keep linter happy
|