stable-baselines3/tests/test_vec_check_nan.py
Anssi b833207142
Add some missing tests, update VecNormalize and RolloutBuffer (#50)
* Change saving/loading normalization parameters to use single pickle file

* Remove 'use_gae' from RolloutBuffer compute_returns function

* Add some missing tests for normalizer, nan-checker and PPO clip_value_fn argument

* Update changelog

* Fix typo

* Use proper pytest.raises for catching errors in tests

* Add comment on GAE and how to obtain non-GAE behaviour

* Remove save/load_running_average from VecNormalize in favor of load/save

* Update changelog

* Update docstring

* Add accidentally removed tests for VecNormalize

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
2020-06-10 12:09:04 +02:00

58 lines
1.4 KiB
Python

import gym
from gym import spaces
import numpy as np
import pytest
from stable_baselines3.common.vec_env import DummyVecEnv, VecCheckNan
class NanAndInfEnv(gym.Env):
    """Toy environment whose observation becomes NaN or inf depending on the action.

    ``step`` emits ``[nan]`` when every action component is positive,
    ``[inf]`` when every component is negative, and ``[0]`` otherwise.
    The reward is always ``0.0``, the episode never terminates and the info
    dict is empty. ``reset`` returns ``[0.0]``.
    """

    metadata = {'render.modes': ['human']}

    def __init__(self):
        super().__init__()
        # Action and observation spaces are the same unbounded 1-D float64 box.
        unbounded_box = dict(low=-np.inf, high=np.inf, shape=(1,), dtype=np.float64)
        self.action_space = spaces.Box(**unbounded_box)
        self.observation_space = spaces.Box(**unbounded_box)

    @staticmethod
    def step(action):
        """Return ``(obs, reward, done, info)`` for ``action``.

        The sign test runs on ``np.array(action)``, so scalars, lists and
        arrays are all accepted.
        """
        arr = np.array(action)
        if np.all(arr > 0):
            observation = float('NaN')
        elif np.all(arr < 0):
            observation = float('inf')
        else:
            observation = 0
        reward, done, info = 0.0, False, {}
        return [observation], reward, done, info

    @staticmethod
    def reset():
        """Return the initial observation, ``[0.0]``."""
        return [0.0]

    def render(self, mode='human', close=False):
        """No-op: this environment has nothing to display."""
        return None
def test_check_nan():
    """Test that VecCheckNan raises on NaN/inf in actions or observations.

    With ``raise_exception=True`` the wrapper must raise ``ValueError`` in
    two situations:

    * for a NaN or inf action;
    * for a finite action that makes ``NanAndInfEnv`` emit a NaN
      observation (positive action) or an inf observation (negative action).

    Finite actions that yield finite observations must pass through
    unchanged.
    """
    env = DummyVecEnv([NanAndInfEnv])
    env = VecCheckNan(env, raise_exception=True)
    try:
        # A zero action yields a finite observation: no error, obs passed through.
        obs, _, _, _ = env.step([[0]])
        assert np.array_equal(obs, [[0.0]])

        # Invalid values in the actions themselves.
        with pytest.raises(ValueError):
            env.step([[float('NaN')]])
        with pytest.raises(ValueError):
            env.step([[float('inf')]])

        # Finite actions that make the environment return invalid observations.
        with pytest.raises(ValueError):
            env.step([[-1]])  # all-negative action -> inf observation
        with pytest.raises(ValueError):
            env.step([[1]])  # all-positive action -> NaN observation

        # NOTE(review): two action rows for a single sub-env; presumably
        # DummyVecEnv only consumes the first row ([0, 1] -> observation 0).
        obs, _, _, _ = env.step(np.array([[0, 1], [0, 1]]))
        assert np.array_equal(obs, [[0.0]])

        assert np.array_equal(env.reset(), [[0.0]])
    finally:
        # Release the vectorized env even when an assertion fails.
        env.close()