mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-14 20:58:03 +00:00
* Add test for channel-first environments * Add support for channel-first envs, including more tests * Update changelog * Run black * Run black, again * Improve NatureCNN error message * Update image checks and FrameStack wrapper * Update tests * Update docs * Run isort * Reformat * Fixes: avoid breaking changes for non-image env * Add additional checks * Update docstring Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
417 lines
14 KiB
Python
417 lines
14 KiB
Python
import collections
|
|
import functools
|
|
import itertools
|
|
import multiprocessing
|
|
|
|
import gym
|
|
import numpy as np
|
|
import pytest
|
|
|
|
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecFrameStack, VecNormalize
|
|
|
|
# Number of sub-environments created inside every vectorized environment under test.
N_ENVS = 3
# Vectorized environment implementations the parametrized tests are run against.
VEC_ENV_CLASSES = [DummyVecEnv, SubprocVecEnv]
# Wrappers optionally applied on top of the vectorized environment (None means "no wrapper").
VEC_ENV_WRAPPERS = [None, VecNormalize, VecFrameStack]
|
|
|
|
|
|
class CustomGymEnv(gym.Env):
    """
    Minimal gym environment used by the vectorized-environment tests.

    The same space serves as both action and observation space, observations
    are freshly sampled from it on every reset/step, and an episode is
    reported as done once ``ep_length`` steps have been taken.
    """

    def __init__(self, space):
        """
        Custom gym environment for testing purposes

        :param space: (gym.Space) space used for both actions and observations
        """
        self.observation_space = space
        self.action_space = space
        self.ep_length = 4
        self.current_step = 0

    def reset(self):
        """
        Restart the step counter and draw a new random observation.

        :return: the freshly sampled observation
        """
        self.current_step = 0
        self._choose_next_state()
        return self.state

    def step(self, action):
        """
        Draw a new random observation and advance the step counter.

        The action is ignored; the reward is always 1.

        :param action: unused
        :return: (observation, reward, done, info) with an empty info dict
        """
        self._choose_next_state()
        self.current_step += 1
        episode_over = self.current_step >= self.ep_length
        return self.state, 1, episode_over, {}

    def _choose_next_state(self):
        """Store a random sample of the observation space in ``self.state``."""
        self.state = self.observation_space.sample()

    def render(self, mode="human"):
        """
        Return a blank 4x4x3 array in "rgb_array" mode, ``None`` in any other mode.

        :param mode: (str)
        :return: (np.ndarray or None)
        """
        return np.zeros((4, 4, 3)) if mode == "rgb_array" else None

    def seed(self, seed=None):
        """Accept a seed but do nothing with it."""
        return None

    @staticmethod
    def custom_method(dim_0=1, dim_1=1):
        """
        Dummy method to test call to custom method
        from VecEnv

        :param dim_0: (int)
        :param dim_1: (int)
        :return: (np.ndarray) array of ones with shape (dim_0, dim_1)
        """
        shape = (dim_0, dim_1)
        return np.ones(shape)
|
|
|
|
|
|
@pytest.mark.parametrize("vec_env_class", VEC_ENV_CLASSES)
@pytest.mark.parametrize("vec_env_wrapper", VEC_ENV_WRAPPERS)
def test_vecenv_custom_calls(vec_env_class, vec_env_wrapper):
    """Test access to methods/attributes of vectorized environments"""

    def make_env():
        unit_box = gym.spaces.Box(low=np.zeros(2), high=np.ones(2))
        return CustomGymEnv(unit_box)

    vec_env = vec_env_class([make_env] * N_ENVS)

    if vec_env_wrapper is not None:
        # VecFrameStack is the only wrapper here that requires an extra argument
        wrapper_kwargs = {"n_stack": 2} if vec_env_wrapper == VecFrameStack else {}
        vec_env = vec_env_wrapper(vec_env, **wrapper_kwargs)

    # Seeding must be forwarded without error
    vec_env.seed(0)
    # Only the "rgb_array" mode is rendered: the "human" mode would need a X server
    vec_env.render(mode="rgb_array")

    method_outputs = vec_env.env_method("custom_method", 1, indices=None, dim_1=2)
    # Give each environment its own index as current_step, one environment at a time
    set_outputs = [vec_env.set_attr("current_step", idx, indices=idx) for idx in range(N_ENVS)]
    # Then read the value back from every environment
    get_outputs = vec_env.get_attr("current_step")

    for outputs in (method_outputs, set_outputs, get_outputs):
        assert len(outputs) == N_ENVS

    expected_ones = np.ones((1, 2))
    per_env_outputs = zip(method_outputs, set_outputs, get_outputs)
    for idx, (method_out, set_out, get_out) in enumerate(per_env_outputs):
        assert (method_out == expected_ones).all()
        assert set_out is None
        assert get_out == idx

    # env_method restricted to a subset of the sub-environments
    subset_outputs = vec_env.env_method("custom_method", 1, indices=[0, 2], dim_1=3)
    assert (subset_outputs[0] == np.ones((1, 3))).all()
    assert (subset_outputs[1] == np.ones((1, 3))).all()
    assert len(subset_outputs) == 2

    # set_attr on every environment at once
    set_output = vec_env.set_attr("current_step", 42, indices=None)
    current_steps = vec_env.get_attr("current_step")
    assert set_output is None
    assert current_steps == [42] * N_ENVS

    # set_attr on the first two environments only, the others keep their reset value
    vec_env.reset()
    set_output = vec_env.set_attr("current_step", 12, indices=[0, 1])
    current_steps = vec_env.get_attr("current_step")
    first_two_steps = vec_env.get_attr("current_step", indices=[0, 1])
    assert set_output is None
    assert current_steps == [12, 12] + [0] * (N_ENVS - 2)
    assert first_two_steps == [12, 12]
    assert vec_env.get_attr("current_step", indices=[0, 2]) == [12, 0]

    # set_attr on the first and last environment, the last one through a negative index
    vec_env.reset()
    set_output = vec_env.set_attr("current_step", 12, indices=[0, -1])
    current_steps = vec_env.get_attr("current_step")
    assert set_output is None
    assert current_steps == [12] + [0] * (N_ENVS - 2) + [12]
    assert vec_env.get_attr("current_step", indices=[-1]) == [12]

    vec_env.close()
|
|
|
|
|
|
class StepEnv(gym.Env):
    """Environment whose observation is the number of steps already taken in the episode."""

    def __init__(self, max_steps):
        """Gym environment for testing that terminal observation is inserted
        correctly.

        :param max_steps: (int) number of steps after which the episode is done
        """
        self.current_step = 0
        self.max_steps = max_steps
        self.observation_space = gym.spaces.Box(low=np.array([0]), high=np.array([999]), dtype="int")
        self.action_space = gym.spaces.Discrete(2)

    def reset(self):
        """Reset the step counter and return it as a one-element integer array."""
        self.current_step = 0
        return np.array([self.current_step], dtype="int")

    def step(self, action):
        """Advance the counter by one; the observation is the counter value *before* the step.

        :param action: unused
        :return: (observation, reward, done, info) with a reward of 0.0 and an empty info dict
        """
        # Build the observation first so that it holds the pre-increment counter
        observation = np.array([self.current_step], dtype="int")
        self.current_step += 1
        return observation, 0.0, self.current_step >= self.max_steps, {}
|
|
|
|
|
|
@pytest.mark.parametrize("vec_env_class", VEC_ENV_CLASSES)
@pytest.mark.parametrize("vec_env_wrapper", VEC_ENV_WRAPPERS)
def test_vecenv_terminal_obs(vec_env_class, vec_env_wrapper):
    """Test that 'terminal_observation' gets added to info dict upon
    termination."""
    # One episode length per sub-environment: 5, 6, ...
    episode_lengths = list(range(5, 5 + N_ENVS))
    vec_env = vec_env_class([functools.partial(StepEnv, length) for length in episode_lengths])

    if vec_env_wrapper is not None:
        # VecFrameStack is the only wrapper here that requires an extra argument
        wrapper_kwargs = {"n_stack": 2} if vec_env_wrapper == VecFrameStack else {}
        vec_env = vec_env_wrapper(vec_env, **wrapper_kwargs)

    # VecNormalize changes observation values, so exact-value checks only apply without it
    exact_values_expected = not isinstance(vec_env, VecNormalize)

    noop_actions = np.zeros((N_ENVS,), dtype="int")
    last_obs_batch = vec_env.reset()
    for step_num in range(1, max(episode_lengths) + 1):
        obs_batch, _, done_batch, info_batch = vec_env.step(noop_actions)
        for batch in (obs_batch, done_batch, info_batch):
            assert len(batch) == N_ENVS

        per_env = zip(last_obs_batch, obs_batch, done_batch, info_batch, episode_lengths)
        for last_obs, new_obs, done, info, episode_length in per_env:
            # An environment must be done exactly on the last step of its episode
            assert done == (step_num == episode_length)
            if not done:
                assert "terminal_observation" not in info
                continue

            terminal_obs = info["terminal_observation"]

            # Rough ordering checks that hold for every wrapper, VecNormalize included
            assert np.all(last_obs < terminal_obs)
            assert np.all(new_obs < last_obs)

            if exact_values_expected:
                # More precise checks on the raw observation values
                assert np.all(last_obs + 1 == terminal_obs)
                assert np.all(new_obs == 0)

        last_obs_batch = obs_batch

    vec_env.close()
|
|
|
|
|
|
# Elementary spaces, keyed by a short name, that the space-related tests iterate over.
# Insertion order is kept identical to the declaration order below.
SPACES = collections.OrderedDict()
SPACES["discrete"] = gym.spaces.Discrete(2)
SPACES["multidiscrete"] = gym.spaces.MultiDiscrete([2, 3])
SPACES["multibinary"] = gym.spaces.MultiBinary(3)
SPACES["continuous"] = gym.spaces.Box(low=np.zeros(2), high=np.ones(2))
|
|
|
|
|
|
def check_vecenv_spaces(vec_env_class, space, obs_assert):
    """Helper method to check observation spaces in vectorized environments.

    Builds ``N_ENVS`` instances of ``CustomGymEnv`` on top of ``space``, resets
    them, then steps with randomly sampled actions until at least one
    sub-environment reports ``done``. Every batch of observations (the one
    returned by ``reset`` and the one returned by each ``step``) is handed to
    ``obs_assert``.

    :param vec_env_class: callable that takes a list of environment factories
        and returns the vectorized environment
    :param space: (gym.Space) space given to each ``CustomGymEnv``
    :param obs_assert: callable invoked with each batch of observations
    """

    def make_env():
        return CustomGymEnv(space)

    vec_env = vec_env_class([make_env for _ in range(N_ENVS)])
    # Close the vectorized environment even when an assertion or a step fails,
    # otherwise the resources it holds (e.g. worker processes) are never released.
    try:
        obs_assert(vec_env.reset())

        dones = [False] * N_ENVS
        while not any(dones):
            actions = [vec_env.action_space.sample() for _ in range(N_ENVS)]
            obs, _rews, dones, _infos = vec_env.step(actions)
            obs_assert(obs)
    finally:
        vec_env.close()
|
|
|
|
|
|
def check_vecenv_obs(obs, space):
    """Helper method to check observations from multiple environments each belong to
    the appropriate observation space.

    :param obs: batch of observations whose first dimension indexes the environments
    :param space: space every individual observation must be contained in
    """
    n_observations = obs.shape[0]
    assert n_observations == N_ENVS
    assert all(space.contains(single_obs) for single_obs in obs)
|
|
|
|
|
|
@pytest.mark.parametrize("vec_env_class,space", itertools.product(VEC_ENV_CLASSES, SPACES.values()))
def test_vecenv_single_space(vec_env_class, space):
    """Test each elementary space, one at a time, with every vectorized environment class."""
    # Bind the expected space so the checker only needs the observations
    obs_assert = functools.partial(check_vecenv_obs, space=space)
    check_vecenv_spaces(vec_env_class, space, obs_assert)
|
|
|
|
|
|
class _UnorderedDictSpace(gym.spaces.Dict):
    """Like DictSpace, but returns an unordered dict when sampling."""

    def sample(self):
        """Sample from the parent space and convert the result to a plain ``dict``."""
        parent_sample = super().sample()
        return dict(parent_sample)
|
|
|
|
|
|
@pytest.mark.parametrize("vec_env_class", VEC_ENV_CLASSES)
def test_vecenv_dict_spaces(vec_env_class):
    """Test dictionary observation spaces with vectorized environments."""
    dict_space = gym.spaces.Dict(SPACES)

    def obs_assert(obs):
        # The batch must be an OrderedDict with exactly the keys of the space ...
        assert isinstance(obs, collections.OrderedDict)
        assert obs.keys() == dict_space.spaces.keys()
        # ... and each entry must be a valid batch for the matching sub-space
        for name, batch in obs.items():
            check_vecenv_obs(batch, dict_space.spaces[name])

    check_vecenv_spaces(vec_env_class, dict_space, obs_assert)

    # Check that vec_env_class can accept unordered dict observations (and convert to OrderedDict)
    check_vecenv_spaces(vec_env_class, _UnorderedDictSpace(SPACES), obs_assert)
|
|
|
|
|
|
@pytest.mark.parametrize("vec_env_class", VEC_ENV_CLASSES)
def test_vecenv_tuple_spaces(vec_env_class):
    """Test tuple observation spaces with vectorized environments."""
    tuple_space = gym.spaces.Tuple(tuple(SPACES.values()))

    def obs_assert(obs):
        # One batch per sub-space, delivered as a tuple
        assert isinstance(obs, tuple)
        assert len(obs) == len(tuple_space.spaces)
        for batch, sub_space in zip(obs, tuple_space.spaces):
            check_vecenv_obs(batch, sub_space)

    check_vecenv_spaces(vec_env_class, tuple_space, obs_assert)
|
|
|
|
|
|
def test_subproc_start_method():
    """Run SubprocVecEnv with the default and each supported start method, and reject an unknown one."""
    # Only test thread-safe methods. Others may deadlock tests! (gh/428)
    # Note: adding unsafe `fork` method as we are now using PyTorch
    candidate_methods = {"forkserver", "spawn", "fork"}
    supported_here = candidate_methods.intersection(multiprocessing.get_all_start_methods())
    # None is always tried first; it lets SubprocVecEnv pick its own default
    start_methods = [None] + list(supported_here)

    space = gym.spaces.Discrete(2)

    def obs_assert(obs):
        return check_vecenv_obs(obs, space)

    def subproc_factory(method):
        return functools.partial(SubprocVecEnv, start_method=method)

    for start_method in start_methods:
        check_vecenv_spaces(subproc_factory(start_method), space, obs_assert)

    with pytest.raises(ValueError, match="cannot find context for 'illegal_method'"):
        check_vecenv_spaces(subproc_factory("illegal_method"), space, obs_assert)
|
|
|
|
|
|
class CustomWrapperA(VecNormalize):
    """VecNormalize subclass that only adds the attribute ``var_a``."""

    def __init__(self, venv):
        super().__init__(venv)
        # Marker attribute looked up by the attribute-forwarding test
        self.var_a = "a"
|
|
|
|
|
|
class CustomWrapperB(VecNormalize):
    """VecNormalize subclass adding the attribute ``var_b`` and two small methods."""

    def __init__(self, venv):
        super().__init__(venv)
        # Marker attribute looked up by the attribute-forwarding test
        self.var_b = "b"

    def func_b(self):
        """Return the value stored in ``var_b``."""
        return self.var_b

    def name_test(self):
        """Return the class of the instance the method is bound to."""
        return type(self)
|
|
|
|
|
|
class CustomWrapperBB(CustomWrapperB):
    """CustomWrapperB subclass that additionally defines the attribute ``var_bb``."""

    def __init__(self, venv):
        super().__init__(venv)
        # Marker attribute looked up by the attribute-forwarding test
        self.var_bb = "bb"
|
|
|
|
|
|
def test_vecenv_wrapper_getattr():
    """Test attribute and method lookup through a stack of VecEnv wrappers."""

    def make_env():
        unit_box = gym.spaces.Box(low=np.zeros(2), high=np.ones(2))
        return CustomGymEnv(unit_box)

    base_env = DummyVecEnv([make_env] * N_ENVS)
    wrapped = CustomWrapperA(CustomWrapperBB(base_env))

    # Attributes defined on the outer wrapper and on the wrapper below it are all reachable
    for attr_name, expected in (("var_a", "a"), ("var_b", "b"), ("var_bb", "bb")):
        assert getattr(wrapped, attr_name) == expected
    assert wrapped.func_b() == "b"
    assert wrapped.name_test() == CustomWrapperBB

    double_wrapped = CustomWrapperA(CustomWrapperB(wrapped))
    _ = double_wrapped.var_a  # should not raise as it is directly defined here
    with pytest.raises(AttributeError):  # should raise due to ambiguity
        _ = double_wrapped.var_b
    with pytest.raises(AttributeError):  # should raise as does not exist
        _ = double_wrapped.nonexistent_attribute
|
|
|
|
|
|
def test_framestack_vecenv():
    """Test that framestack environment stacks on desired axis"""

    image_space_shape = [12, 8, 3]
    transposed_image_space_shape = image_space_shape[::-1]

    # CustomGymEnv uses the same space for actions and observations,
    # hence the image-shaped action batches
    zero_acts = np.zeros([N_ENVS] + image_space_shape)
    transposed_zero_acts = np.zeros([N_ENVS] + transposed_image_space_shape)

    def build_image_env(shape):
        pixel_space = gym.spaces.Box(
            low=np.zeros(shape),
            high=np.ones(shape) * 255,
            dtype=np.uint8,
        )
        return CustomGymEnv(pixel_space)

    make_image_env = functools.partial(build_image_env, image_space_shape)
    make_transposed_image_env = functools.partial(build_image_env, transposed_image_space_shape)

    def make_non_image_env():
        return CustomGymEnv(gym.spaces.Box(low=np.zeros((2,)), high=np.ones((2,))))

    def stacked_obs_after_one_step(env_fn, actions, **stack_kwargs):
        # Build, frame-stack (2 frames), step once, close, and hand back the observations
        stacked_env = VecFrameStack(DummyVecEnv([env_fn] * N_ENVS), n_stack=2, **stack_kwargs)
        step_obs, _, _, _ = stacked_env.step(actions)
        stacked_env.close()
        return step_obs

    # Should be stacked on the last dimension
    obs = stacked_obs_after_one_step(make_image_env, zero_acts)
    assert obs.shape[-1] == (image_space_shape[-1] * 2)

    # Try automatic stacking on first dimension now
    obs = stacked_obs_after_one_step(make_transposed_image_env, transposed_zero_acts)
    # Should be stacked on the first dimension (note the transposing in make_transposed_image_env)
    assert obs.shape[1] == (image_space_shape[-1] * 2)

    # Try forcing dimensions
    obs = stacked_obs_after_one_step(make_image_env, zero_acts, channels_order="last")
    # Should be stacked on the last dimension
    assert obs.shape[-1] == (image_space_shape[-1] * 2)

    obs = stacked_obs_after_one_step(make_image_env, zero_acts, channels_order="first")
    # Should be stacked on the first dimension
    assert obs.shape[1] == (image_space_shape[0] * 2)

    # Test invalid channels_order
    vec_env = DummyVecEnv([make_image_env] * N_ENVS)
    with pytest.raises(AssertionError):
        vec_env = VecFrameStack(vec_env, n_stack=2, channels_order="not_valid")

    # Test that it works with non-image envs when no channels_order is given
    vec_env = DummyVecEnv([make_non_image_env] * N_ENVS)
    vec_env = VecFrameStack(vec_env, n_stack=2)
|