mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-14 20:58:03 +00:00
* Drop python 3.8 support, add python 3.12 support * Upgrade to python 3.9 syntax * Fixes for Numpy v2 * Fix doc warning
193 lines
6.1 KiB
Python
193 lines
6.1 KiB
Python
from typing import Any, Optional
|
|
|
|
import gymnasium as gym
|
|
import numpy as np
|
|
import pytest
|
|
from gymnasium import spaces
|
|
|
|
from stable_baselines3.common.env_checker import check_env
|
|
|
|
|
|
class ActionDictTestEnv(gym.Env):
|
|
metadata = {"render_modes": ["human"]}
|
|
render_mode = None
|
|
|
|
action_space = spaces.Dict({"position": spaces.Discrete(1), "velocity": spaces.Discrete(1)})
|
|
observation_space = spaces.Box(low=-1.0, high=2.0, shape=(3,), dtype=np.float32)
|
|
|
|
def step(self, action):
|
|
observation = np.array([1.0, 1.5, 0.5], dtype=self.observation_space.dtype)
|
|
reward = 1
|
|
terminated = True
|
|
truncated = False
|
|
info = {}
|
|
return observation, reward, terminated, truncated, info
|
|
|
|
def reset(self, seed=None):
|
|
return np.array([1.0, 1.5, 0.5], dtype=self.observation_space.dtype), {}
|
|
|
|
def render(self):
|
|
pass
|
|
|
|
|
|
def test_check_env_dict_action():
|
|
test_env = ActionDictTestEnv()
|
|
|
|
with pytest.warns(Warning):
|
|
check_env(env=test_env, warn=True)
|
|
|
|
|
|
class SequenceObservationEnv(gym.Env):
|
|
metadata = {"render_modes": [], "render_fps": 2}
|
|
|
|
def __init__(self, render_mode=None):
|
|
self.observation_space = spaces.Sequence(spaces.Discrete(8))
|
|
self.action_space = spaces.Discrete(4)
|
|
|
|
def reset(self, seed=None, options=None):
|
|
super().reset(seed=seed)
|
|
return self.observation_space.sample(), {}
|
|
|
|
def step(self, action):
|
|
return self.observation_space.sample(), 1.0, False, False, {}
|
|
|
|
|
|
def test_check_env_sequence_obs():
|
|
test_env = SequenceObservationEnv()
|
|
|
|
with pytest.warns(Warning, match="Sequence.*not supported"):
|
|
check_env(env=test_env, warn=True)
|
|
|
|
|
|
@pytest.mark.parametrize(
|
|
"obs_tuple",
|
|
[
|
|
# Above upper bound
|
|
(
|
|
spaces.Box(low=np.array([0.0, 0.0, 0.0]), high=np.array([2.0, 1.0, 1.0]), shape=(3,), dtype=np.float32),
|
|
np.array([1.0, 1.5, 0.5], dtype=np.float32),
|
|
r"Expected: 0\.0 <= obs\[1] <= 1\.0, actual value: 1\.5",
|
|
),
|
|
# Above upper bound (multi-dim)
|
|
(
|
|
spaces.Box(low=-1.0, high=2.0, shape=(2, 3, 3, 1), dtype=np.float32),
|
|
3.0 * np.ones((2, 3, 3, 1), dtype=np.float32),
|
|
# Note: this is one of the 18 invalid indices
|
|
r"Expected: -1\.0 <= obs\[1,2,1,0\] <= 2\.0, actual value: 3\.0",
|
|
),
|
|
# Below lower bound
|
|
(
|
|
spaces.Box(low=np.array([0.0, -10.0, 0.0]), high=np.array([2.0, 1.0, 1.0]), shape=(3,), dtype=np.float32),
|
|
np.array([-1.0, 1.5, 0.5], dtype=np.float32),
|
|
r"Expected: 0\.0 <= obs\[0] <= 2\.0, actual value: -1\.0",
|
|
),
|
|
# Below lower bound (multi-dim)
|
|
(
|
|
spaces.Box(low=-1.0, high=2.0, shape=(2, 3, 3, 1), dtype=np.float32),
|
|
-2 * np.ones((2, 3, 3, 1), dtype=np.float32),
|
|
r"18 invalid indices:",
|
|
),
|
|
# Wrong dtype
|
|
(
|
|
spaces.Box(low=-1.0, high=2.0, shape=(3,), dtype=np.float32),
|
|
np.array([1.0, 1.5, 0.5], dtype=np.float64),
|
|
r"Expected: float32, actual dtype: float64",
|
|
),
|
|
# Wrong shape
|
|
(
|
|
spaces.Box(low=-1.0, high=2.0, shape=(3,), dtype=np.float32),
|
|
np.array([[1.0, 1.5, 0.5], [1.0, 1.5, 0.5]], dtype=np.float32),
|
|
r"Expected: \(3,\), actual shape: \(2, 3\)",
|
|
),
|
|
# Wrong shape (dict obs)
|
|
(
|
|
spaces.Dict({"obs": spaces.Box(low=-1.0, high=2.0, shape=(3,), dtype=np.float32)}),
|
|
{"obs": np.array([[1.0, 1.5, 0.5], [1.0, 1.5, 0.5]], dtype=np.float32)},
|
|
r"Error while checking key=obs.*Expected: \(3,\), actual shape: \(2, 3\)",
|
|
),
|
|
# Wrong shape (multi discrete)
|
|
(
|
|
spaces.MultiDiscrete([3, 3]),
|
|
np.array([[2, 0]]),
|
|
r"Expected: \(2,\), actual shape: \(1, 2\)",
|
|
),
|
|
# Wrong shape (multi binary)
|
|
(
|
|
spaces.MultiBinary(3),
|
|
np.array([[1, 0, 0]]),
|
|
r"Expected: \(3,\), actual shape: \(1, 3\)",
|
|
),
|
|
],
|
|
)
|
|
@pytest.mark.parametrize(
|
|
# Check when it happens at reset or during step
|
|
"method",
|
|
["reset", "step"],
|
|
)
|
|
def test_check_env_detailed_error(obs_tuple, method):
|
|
"""
|
|
Check that the env checker returns more detail error
|
|
when the observation is not in the obs space.
|
|
"""
|
|
observation_space, wrong_obs, error_message = obs_tuple
|
|
good_obs = observation_space.sample()
|
|
|
|
class TestEnv(gym.Env):
|
|
action_space = spaces.Box(low=-1.0, high=1.0, shape=(3,), dtype=np.float32)
|
|
|
|
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
|
return wrong_obs if method == "reset" else good_obs, {}
|
|
|
|
def step(self, action):
|
|
obs = wrong_obs if method == "step" else good_obs
|
|
return obs, 0.0, True, False, {}
|
|
|
|
TestEnv.observation_space = observation_space
|
|
|
|
test_env = TestEnv()
|
|
with pytest.raises(AssertionError, match=error_message):
|
|
check_env(env=test_env, warn=False)
|
|
|
|
|
|
class LimitedStepsTestEnv(gym.Env):
|
|
action_space = spaces.Discrete(n=2)
|
|
observation_space = spaces.Discrete(n=2)
|
|
|
|
def __init__(self, steps_before_termination: int = 1):
|
|
super().__init__()
|
|
|
|
assert steps_before_termination >= 1
|
|
self._steps_before_termination = steps_before_termination
|
|
|
|
self._steps_called = 0
|
|
self._terminated = False
|
|
|
|
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None) -> tuple[int, dict]:
|
|
super().reset(seed=seed)
|
|
|
|
self._steps_called = 0
|
|
self._terminated = False
|
|
|
|
return 0, {}
|
|
|
|
def step(self, action: np.ndarray) -> tuple[int, float, bool, bool, dict[str, Any]]:
|
|
self._steps_called += 1
|
|
|
|
assert not self._terminated
|
|
|
|
observation = 0
|
|
reward = 0.0
|
|
self._terminated = self._steps_called >= self._steps_before_termination
|
|
truncated = False
|
|
|
|
return observation, reward, self._terminated, truncated, {}
|
|
|
|
def render(self) -> None:
|
|
pass
|
|
|
|
|
|
def test_check_env_single_step_env():
|
|
test_env = LimitedStepsTestEnv(steps_before_termination=1)
|
|
|
|
# This should not throw
|
|
check_env(env=test_env, warn=True)
|