stable-baselines3/tests/test_vec_check_nan.py
Kallinteris Andreas 9c338f917a
vec_envs fix seed() causing a reset (#1486)
* `dummy_vec_env` fix `seed()` causing a reset

* rename `seed`

* fixes

* bug fix

* fix seed return type

* Cleanup seeding, add test and remove compat wrapper

* Update env checker and tests

* Add deterministic test for make_vec_env

---------

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
2023-05-20 10:30:54 +02:00

59 lines
1.4 KiB
Python

import gymnasium as gym
import numpy as np
import pytest
from gymnasium import spaces
from stable_baselines3.common.vec_env import DummyVecEnv, VecCheckNan
class NanAndInfEnv(gym.Env):
    """Custom environment that returns NaN or inf observations on demand.

    The observation produced by ``step`` depends only on the sign of the
    action: NaN when every component is strictly positive, inf when every
    component is strictly negative, and ``0.0`` otherwise.
    """

    metadata = {"render_modes": ["human"]}

    def __init__(self):
        super().__init__()
        # Unbounded one-dimensional float64 boxes for both actions and observations.
        self.action_space = spaces.Box(low=-np.inf, high=np.inf, shape=(1,), dtype=np.float64)
        self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(1,), dtype=np.float64)

    @staticmethod
    def step(action):
        """Return a NaN, inf or zero observation depending on the sign of ``action``.

        :param action: scalar or array-like action.
        :return: ``([obs], reward, terminated, truncated, info)`` where the
            reward is always ``0.0``, both flags are ``False`` and ``info``
            is an empty dict.
        """
        # Convert once; the original rebuilt the array for each comparison.
        action = np.array(action)
        if np.all(action > 0):
            obs = float("NaN")
        elif np.all(action < 0):
            obs = float("inf")
        else:
            # Float (not int) so every branch yields the same type as the
            # float64 observation space declared in ``__init__``.
            obs = 0.0
        return [obs], 0.0, False, False, {}

    @staticmethod
    def reset(seed=None, options=None):
        """Return the fixed initial observation ``[0.0]`` and an empty info dict.

        ``seed`` and ``options`` are accepted so the signature matches the
        keywords of the Gymnasium ``Env.reset`` API; both are ignored.
        """
        return [0.0], {}

    def render(self):
        """No-op: this environment has nothing to render."""
        pass
def test_check_nan():
    """Check that VecCheckNan raises ValueError on non-finite actions or observations."""
    vec_env = VecCheckNan(DummyVecEnv([NanAndInfEnv]), raise_exception=True)

    # A zero action gives a finite observation: no exception expected.
    vec_env.step([[0]])

    # NaN and inf are non-finite actions themselves; -1 and 1 are finite
    # actions that make NanAndInfEnv return inf and NaN observations.
    for bad_value in (float("NaN"), float("inf"), -1, 1):
        with pytest.raises(ValueError):
            vec_env.step([[bad_value]])

    # Array input containing zeros is neither all-positive nor all-negative,
    # so the observation stays finite.
    vec_env.step(np.array([[0, 1], [0, 1]]))

    vec_env.reset()