Update atari test

This commit is contained in:
Antonin RAFFIN 2020-05-07 16:36:48 +02:00
parent e6ff4bbd6c
commit aa0ff8a59b
2 changed files with 13 additions and 2 deletions

View file

@ -9,7 +9,8 @@ import cloudpickle
from stable_baselines3.common import logger
def tile_images(img_nhwc: Sequence[np.ndarray]) -> np.ndarray:
def tile_images(img_nhwc: Sequence[np.ndarray]) -> np.ndarray: # pragma: no cover
"""
Tile N images into one big PxQ image
(P,Q) are chosen to be as close as possible, and if N

View file

@ -3,6 +3,7 @@ import shutil
import pytest
import gym
import numpy as np
from stable_baselines3 import A2C
from stable_baselines3.common.monitor import Monitor
@ -43,14 +44,23 @@ def test_make_vec_env(env_id, n_envs, wrapper_kwargs):
assert env.num_envs == n_envs
obs = env.reset()
new_obs, reward, _, _ = env.step([env.action_space.sample() for _ in range(n_envs)])
assert obs.shape == new_obs.shape
# Wrapped into DummyVecEnv
wrapped_atari_env = env.envs[0]
if wrapper_kwargs is not None:
# Wrapped into DummyVecEnv + Monitor
assert obs.shape == (n_envs, 60, 60, 1)
assert wrapped_atari_env.observation_space.shape == (60, 60, 1)
assert wrapped_atari_env.clip_reward == False
else:
assert obs.shape == (n_envs, 84, 84, 1)
assert wrapped_atari_env.observation_space.shape == (84, 84, 1)
assert wrapped_atari_env.clip_reward == True
assert np.max(np.abs(reward)) < 1.0