diff --git a/stable_baselines3/common/vec_env/base_vec_env.py b/stable_baselines3/common/vec_env/base_vec_env.py index ce03391..adf2ec0 100644 --- a/stable_baselines3/common/vec_env/base_vec_env.py +++ b/stable_baselines3/common/vec_env/base_vec_env.py @@ -9,7 +9,8 @@ import cloudpickle from stable_baselines3.common import logger -def tile_images(img_nhwc: Sequence[np.ndarray]) -> np.ndarray: + +def tile_images(img_nhwc: Sequence[np.ndarray]) -> np.ndarray: # pragma: no cover """ Tile N images into one big PxQ image (P,Q) are chosen to be as close as possible, and if N diff --git a/tests/test_utils.py b/tests/test_utils.py index 9386cb3..5c029ec 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -3,6 +3,7 @@ import shutil import pytest import gym +import numpy as np from stable_baselines3 import A2C from stable_baselines3.common.monitor import Monitor @@ -43,14 +44,23 @@ def test_make_vec_env(env_id, n_envs, wrapper_kwargs): assert env.num_envs == n_envs + obs = env.reset() + + new_obs, reward, _, _ = env.step([env.action_space.sample() for _ in range(n_envs)]) + + assert obs.shape == new_obs.shape + + # Wrapped into DummyVecEnv wrapped_atari_env = env.envs[0] if wrapper_kwargs is not None: - # Wrapped into DummyVecEnv + Monitor + assert obs.shape == (n_envs, 60, 60, 1) assert wrapped_atari_env.observation_space.shape == (60, 60, 1) assert wrapped_atari_env.clip_reward == False else: + assert obs.shape == (n_envs, 84, 84, 1) assert wrapped_atari_env.observation_space.shape == (84, 84, 1) assert wrapped_atari_env.clip_reward == True + assert np.max(np.abs(reward)) < 1.0