From ba77dd7c6180c0ec9a47dfa98291c2103e6750df Mon Sep 17 00:00:00 2001 From: Tobias Rohrer Date: Mon, 24 Jul 2023 16:38:03 +0200 Subject: [PATCH] Fix to use float64 actions for off policy algorithms (#1572) * Added test cases where off policy algorithms fail with float64 action space * casting observations and actions to `np.float32` to unify behaviour between `ReplayBuffer` and `RolloutBuffer`. Fixing issue #1145 * reformatted using black * making test more restrictive by checking model's action is float64 * added changelog entry * undo cast of observations as `preprocessing.preprocess_obs()` casts them to float32 anyway. * - Casting to float32 only if action.dtype is float64 - Added cast to `DictReplayBuffer` as well * Added tests for multiple variations of continuous action types and observation spaces * applied reformatting by `make commit-checks` * Added typing and comment referring to description in merge request * Apply linter for single element slice * Rename helper and refactor tests * Update changelog and docstring --------- Co-authored-by: Antonin RAFFIN --- docs/misc/changelog.rst | 4 +- stable_baselines3/common/buffers.py | 23 +++++- stable_baselines3/dqn/dqn.py | 2 +- stable_baselines3/version.txt | 2 +- tests/test_spaces.py | 109 +++++++++++++++++++--------- 5 files changed, 101 insertions(+), 39 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index e16e4b2..039a094 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -4,7 +4,7 @@ Changelog ========== -Release 2.1.0a2 (WIP) +Release 2.1.0a3 (WIP) -------------------------- Breaking Changes: @@ -26,6 +26,7 @@ New Features: Bug Fixes: ^^^^^^^^^^ - Relaxed check in logger, that was causing issue on Windows with colorama +- Fixed off-policy algorithms with continuous float64 actions (see #1145) (@tobirohrer) Deprecations: ^^^^^^^^^^^^^ @@ -34,6 +35,7 @@ Others: ^^^^^^^ - Updated GitHub issue templates - Fix typo in gym patch error message (@lukashass) +- 
Refactor ``test_spaces.py`` tests Documentation: ^^^^^^^^^^^^^^ diff --git a/stable_baselines3/common/buffers.py b/stable_baselines3/common/buffers.py index fe633e1..576e10a 100644 --- a/stable_baselines3/common/buffers.py +++ b/stable_baselines3/common/buffers.py @@ -207,7 +207,9 @@ class ReplayBuffer(BaseBuffer): else: self.next_observations = np.zeros((self.buffer_size, self.n_envs, *self.obs_shape), dtype=observation_space.dtype) - self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=action_space.dtype) + self.actions = np.zeros( + (self.buffer_size, self.n_envs, self.action_dim), dtype=self._maybe_cast_dtype(action_space.dtype) + ) self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) self.dones = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) @@ -311,6 +313,21 @@ class ReplayBuffer(BaseBuffer): ) return ReplayBufferSamples(*tuple(map(self.to_torch, data))) + @staticmethod + def _maybe_cast_dtype(dtype: np.typing.DTypeLike) -> np.typing.DTypeLike: + """ + Cast `np.float64` action datatype to `np.float32`, + keep other dtypes unchanged. + See GH#1572 for more information. + + :param dtype: The original action space dtype + :return: ``np.float32`` if the dtype was float64, + the original dtype otherwise. 
+ """ + if dtype == np.float64: + return np.float32 + return dtype + class RolloutBuffer(BaseBuffer): """ @@ -543,7 +560,9 @@ class DictReplayBuffer(ReplayBuffer): for key, _obs_shape in self.obs_shape.items() } - self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=action_space.dtype) + self.actions = np.zeros( + (self.buffer_size, self.n_envs, self.action_dim), dtype=self._maybe_cast_dtype(action_space.dtype) + ) self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) self.dones = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) diff --git a/stable_baselines3/dqn/dqn.py b/stable_baselines3/dqn/dqn.py index 6b44254..42e3d0d 100644 --- a/stable_baselines3/dqn/dqn.py +++ b/stable_baselines3/dqn/dqn.py @@ -245,7 +245,7 @@ class DQN(OffPolicyAlgorithm): if not deterministic and np.random.rand() < self.exploration_rate: if self.policy.is_vectorized_observation(observation): if isinstance(observation, dict): - n_batch = observation[list(observation.keys())[0]].shape[0] + n_batch = observation[next(iter(observation.keys()))].shape[0] else: n_batch = observation.shape[0] action = np.array([self.action_space.sample() for _ in range(n_batch)]) diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 55c98c9..a4a6a87 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -2.1.0a2 +2.1.0a3 diff --git a/tests/test_spaces.py b/tests/test_spaces.py index fb70d0a..e4a9339 100644 --- a/tests/test_spaces.py +++ b/tests/test_spaces.py @@ -1,63 +1,67 @@ +from dataclasses import dataclass from typing import Dict, Optional import gymnasium as gym import numpy as np import pytest from gymnasium import spaces +from gymnasium.spaces.space import Space from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3 from stable_baselines3.common.env_checker import check_env from stable_baselines3.common.env_util import make_vec_env from stable_baselines3.common.evaluation import 
evaluate_policy +BOX_SPACE_FLOAT64 = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float64) +BOX_SPACE_FLOAT32 = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32) -class DummyMultiDiscreteSpace(gym.Env): - def __init__(self, nvec): - super().__init__() - self.observation_space = spaces.MultiDiscrete(nvec) - self.action_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32) + +@dataclass +class DummyEnv(gym.Env): + observation_space: Space + action_space: Space + + def step(self, action): + return self.observation_space.sample(), 0.0, False, False, {} def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None): if seed is not None: super().reset(seed=seed) return self.observation_space.sample(), {} - def step(self, action): - return self.observation_space.sample(), 0.0, False, False, {} - -class DummyMultiBinary(gym.Env): - def __init__(self, n): - super().__init__() - self.observation_space = spaces.MultiBinary(n) - self.action_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32) - - def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None): - if seed is not None: - super().reset(seed=seed) - return self.observation_space.sample(), {} - - def step(self, action): - return self.observation_space.sample(), 0.0, False, False, {} - - -class DummyMultidimensionalAction(gym.Env): +class DummyMultidimensionalAction(DummyEnv): def __init__(self): - super().__init__() - self.observation_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32) - self.action_space = spaces.Box(low=-1, high=1, shape=(2, 2), dtype=np.float32) + super().__init__( + BOX_SPACE_FLOAT32, + spaces.Box(low=-1, high=1, shape=(2, 2), dtype=np.float32), + ) - def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None): - if seed is not None: - super().reset(seed=seed) - return self.observation_space.sample(), {} - def step(self, action): - return self.observation_space.sample(), 0.0, False, False, {} +class 
DummyMultiBinary(DummyEnv): + def __init__(self, n): + super().__init__( + spaces.MultiBinary(n), + BOX_SPACE_FLOAT32, + ) + + +class DummyMultiDiscreteSpace(DummyEnv): + def __init__(self, nvec): + super().__init__( + spaces.MultiDiscrete(nvec), + BOX_SPACE_FLOAT32, + ) @pytest.mark.parametrize( - "env", [DummyMultiDiscreteSpace([4, 3]), DummyMultiBinary(8), DummyMultiBinary((3, 2)), DummyMultidimensionalAction()] + "env", + [ + DummyMultiDiscreteSpace([4, 3]), + DummyMultiBinary(8), + DummyMultiBinary((3, 2)), + DummyMultidimensionalAction(), + ], ) def test_env(env): # Check the env used for testing @@ -127,3 +131,40 @@ def test_discrete_obs_space(model_class, env): else: kwargs = dict(n_steps=256) model_class("MlpPolicy", env, **kwargs).learn(256) + + +@pytest.mark.parametrize("model_class", [SAC, TD3, PPO, DDPG, A2C]) +@pytest.mark.parametrize( + "obs_space", + [ + BOX_SPACE_FLOAT32, + BOX_SPACE_FLOAT64, + spaces.Dict({"a": BOX_SPACE_FLOAT32, "b": BOX_SPACE_FLOAT32}), + spaces.Dict({"a": BOX_SPACE_FLOAT32, "b": BOX_SPACE_FLOAT64}), + ], +) +@pytest.mark.parametrize( + "action_space", + [ + BOX_SPACE_FLOAT32, + BOX_SPACE_FLOAT64, + ], +) +def test_float64_action_space(model_class, obs_space, action_space): + env = DummyEnv(obs_space, action_space) + env = gym.wrappers.TimeLimit(env, max_episode_steps=200) + if isinstance(env.observation_space, spaces.Dict): + policy = "MultiInputPolicy" + else: + policy = "MlpPolicy" + + if model_class in [PPO, A2C]: + kwargs = dict(n_steps=64, policy_kwargs=dict(net_arch=[12])) + else: + kwargs = dict(learning_starts=60, policy_kwargs=dict(net_arch=[12])) + + model = model_class(policy, env, **kwargs) + model.learn(64) + initial_obs, _ = env.reset() + action, _ = model.predict(initial_obs, deterministic=False) + assert action.dtype == env.action_space.dtype