mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-16 21:10:08 +00:00
Update atari test
This commit is contained in:
parent
e6ff4bbd6c
commit
aa0ff8a59b
2 changed files with 13 additions and 2 deletions
|
|
@ -9,7 +9,8 @@ import cloudpickle
|
|||
from stable_baselines3.common import logger
|
||||
|
||||
|
||||
def tile_images(img_nhwc: Sequence[np.ndarray]) -> np.ndarray:
|
||||
|
||||
def tile_images(img_nhwc: Sequence[np.ndarray]) -> np.ndarray: # pragma: no cover
|
||||
"""
|
||||
Tile N images into one big PxQ image
|
||||
(P,Q) are chosen to be as close as possible, and if N
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ import shutil
|
|||
|
||||
import pytest
|
||||
import gym
|
||||
import numpy as np
|
||||
|
||||
from stable_baselines3 import A2C
|
||||
from stable_baselines3.common.monitor import Monitor
|
||||
|
|
@ -43,14 +44,23 @@ def test_make_vec_env(env_id, n_envs, wrapper_kwargs):
|
|||
|
||||
assert env.num_envs == n_envs
|
||||
|
||||
obs = env.reset()
|
||||
|
||||
new_obs, reward, _, _ = env.step([env.action_space.sample() for _ in range(n_envs)])
|
||||
|
||||
assert obs.shape == new_obs.shape
|
||||
|
||||
# Wrapped into DummyVecEnv
|
||||
wrapped_atari_env = env.envs[0]
|
||||
if wrapper_kwargs is not None:
|
||||
# Wrapped into DummyVecEnv + Monitor
|
||||
assert obs.shape == (n_envs, 60, 60, 1)
|
||||
assert wrapped_atari_env.observation_space.shape == (60, 60, 1)
|
||||
assert wrapped_atari_env.clip_reward == False
|
||||
else:
|
||||
assert obs.shape == (n_envs, 84, 84, 1)
|
||||
assert wrapped_atari_env.observation_space.shape == (84, 84, 1)
|
||||
assert wrapped_atari_env.clip_reward == True
|
||||
assert np.max(np.abs(reward)) < 1.0
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue