2022-12-08 17:46:41 +00:00
|
|
|
import torch
|
2023-04-14 11:13:59 +00:00
|
|
|
from gymnasium import spaces
|
2022-12-08 17:46:41 +00:00
|
|
|
|
|
|
|
|
from stable_baselines3.common.preprocessing import get_obs_shape, preprocess_obs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_get_obs_shape_discrete():
|
|
|
|
|
assert get_obs_shape(spaces.Discrete(3)) == (1,)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_get_obs_shape_multidiscrete():
|
|
|
|
|
assert get_obs_shape(spaces.MultiDiscrete([3, 2])) == (2,)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_get_obs_shape_multibinary():
|
|
|
|
|
assert get_obs_shape(spaces.MultiBinary(3)) == (3,)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_get_obs_shape_multidimensional_multibinary():
|
|
|
|
|
assert get_obs_shape(spaces.MultiBinary([3, 2])) == (3, 2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_get_obs_shape_box():
|
|
|
|
|
assert get_obs_shape(spaces.Box(-2, 2, shape=(3,))) == (3,)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_get_obs_shape_multidimensional_box():
|
|
|
|
|
assert get_obs_shape(spaces.Box(-2, 2, shape=(3, 2))) == (3, 2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_preprocess_obs_discrete():
|
|
|
|
|
actual = preprocess_obs(torch.tensor([2], dtype=torch.long), spaces.Discrete(3))
|
|
|
|
|
expected = torch.tensor([[0.0, 0.0, 1.0]], dtype=torch.float32)
|
|
|
|
|
torch.testing.assert_close(actual, expected)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_preprocess_obs_multidiscrete():
|
|
|
|
|
actual = preprocess_obs(torch.tensor([[2, 0]], dtype=torch.long), spaces.MultiDiscrete([3, 2]))
|
|
|
|
|
expected = torch.tensor([[0.0, 0.0, 1.0, 1.0, 0.0]], dtype=torch.float32)
|
|
|
|
|
torch.testing.assert_close(actual, expected)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_preprocess_obs_multibinary():
|
|
|
|
|
actual = preprocess_obs(torch.tensor([[1, 0, 1]], dtype=torch.long), spaces.MultiBinary(3))
|
|
|
|
|
expected = torch.tensor([[1.0, 0.0, 1.0]], dtype=torch.float32)
|
|
|
|
|
torch.testing.assert_close(actual, expected)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_preprocess_obs_multidimensional_multibinary():
|
|
|
|
|
actual = preprocess_obs(torch.tensor([[[1, 0], [1, 1], [0, 1]]], dtype=torch.long), spaces.MultiBinary([3, 2]))
|
|
|
|
|
expected = torch.tensor([[[1.0, 0.0], [1.0, 1.0], [0.0, 1.0]]], dtype=torch.float32)
|
|
|
|
|
torch.testing.assert_close(actual, expected)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_preprocess_obs_box():
|
|
|
|
|
actual = preprocess_obs(torch.tensor([[1.5, 0.3, -1.8]], dtype=torch.float32), spaces.Box(-2, 2, shape=(3,)))
|
|
|
|
|
expected = torch.tensor([[1.5, 0.3, -1.8]], dtype=torch.float32)
|
|
|
|
|
torch.testing.assert_close(actual, expected)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_preprocess_obs_multidimensional_box():
|
|
|
|
|
actual = preprocess_obs(
|
|
|
|
|
torch.tensor([[[1.5, 0.3, -1.8], [0.1, -0.6, -1.4]]], dtype=torch.float32), spaces.Box(-2, 2, shape=(3, 2))
|
|
|
|
|
)
|
|
|
|
|
expected = torch.tensor([[[1.5, 0.3, -1.8], [0.1, -0.6, -1.4]]], dtype=torch.float32)
|
|
|
|
|
torch.testing.assert_close(actual, expected)
|