mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-31 23:28:05 +00:00
Fix to use float64 actions for off policy algorithms (#1572)
* Added test cases where off policy algorithms fail with float64 actionspace * casting observations and actions to `np.float32` to unify behaviour between `ReplayBuffer` and `RolloutBuffer`. Fixing issue #1145 * reformatted using black * making test more restrictive by checking models action is float64 * added changelog entry * undo cast of observations as `preprocessing.preprocess_obs()` casts them to float32 anyways. * - Casting to float32 only, if action.dtype is float64 - Added cast to `DictReplayBuffer` as well * Added tests for multiple variations of continuous action types and observation spaces * applied reformatting by `make commit-checks` * Added typing and comment referring to description in merge request * Apply linter for single element slice * Rename helper and refactor tests * Update changelog and docstring --------- Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
This commit is contained in:
parent
72c124d907
commit
ba77dd7c61
5 changed files with 101 additions and 39 deletions
|
|
@ -4,7 +4,7 @@ Changelog
|
|||
==========
|
||||
|
||||
|
||||
Release 2.1.0a2 (WIP)
|
||||
Release 2.1.0a3 (WIP)
|
||||
--------------------------
|
||||
|
||||
Breaking Changes:
|
||||
|
|
@ -26,6 +26,7 @@ New Features:
|
|||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
- Relaxed check in logger, that was causing issue on Windows with colorama
|
||||
- Fixed off-policy algorithms with continuous float64 actions (see #1145) (@tobirohrer)
|
||||
|
||||
Deprecations:
|
||||
^^^^^^^^^^^^^
|
||||
|
|
@ -34,6 +35,7 @@ Others:
|
|||
^^^^^^^
|
||||
- Updated GitHub issue templates
|
||||
- Fix typo in gym patch error message (@lukashass)
|
||||
- Refactor ``test_spaces.py`` tests
|
||||
|
||||
Documentation:
|
||||
^^^^^^^^^^^^^^
|
||||
|
|
|
|||
|
|
@ -207,7 +207,9 @@ class ReplayBuffer(BaseBuffer):
|
|||
else:
|
||||
self.next_observations = np.zeros((self.buffer_size, self.n_envs, *self.obs_shape), dtype=observation_space.dtype)
|
||||
|
||||
self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=action_space.dtype)
|
||||
self.actions = np.zeros(
|
||||
(self.buffer_size, self.n_envs, self.action_dim), dtype=self._maybe_cast_dtype(action_space.dtype)
|
||||
)
|
||||
|
||||
self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
|
||||
self.dones = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
|
||||
|
|
@ -311,6 +313,21 @@ class ReplayBuffer(BaseBuffer):
|
|||
)
|
||||
return ReplayBufferSamples(*tuple(map(self.to_torch, data)))
|
||||
|
||||
@staticmethod
|
||||
def _maybe_cast_dtype(dtype: np.typing.DTypeLike) -> np.typing.DTypeLike:
|
||||
"""
|
||||
Cast `np.float64` action datatype to `np.float32`,
|
||||
keep the others dtype unchanged.
|
||||
See GH#1572 for more information.
|
||||
|
||||
:param dtype: The original action space dtype
|
||||
:return: ``np.float32`` if the dtype was float64,
|
||||
the original dtype otherwise.
|
||||
"""
|
||||
if dtype == np.float64:
|
||||
return np.float32
|
||||
return dtype
|
||||
|
||||
|
||||
class RolloutBuffer(BaseBuffer):
|
||||
"""
|
||||
|
|
@ -543,7 +560,9 @@ class DictReplayBuffer(ReplayBuffer):
|
|||
for key, _obs_shape in self.obs_shape.items()
|
||||
}
|
||||
|
||||
self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=action_space.dtype)
|
||||
self.actions = np.zeros(
|
||||
(self.buffer_size, self.n_envs, self.action_dim), dtype=self._maybe_cast_dtype(action_space.dtype)
|
||||
)
|
||||
self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
|
||||
self.dones = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
|
||||
|
||||
|
|
|
|||
|
|
@ -245,7 +245,7 @@ class DQN(OffPolicyAlgorithm):
|
|||
if not deterministic and np.random.rand() < self.exploration_rate:
|
||||
if self.policy.is_vectorized_observation(observation):
|
||||
if isinstance(observation, dict):
|
||||
n_batch = observation[list(observation.keys())[0]].shape[0]
|
||||
n_batch = observation[next(iter(observation.keys()))].shape[0]
|
||||
else:
|
||||
n_batch = observation.shape[0]
|
||||
action = np.array([self.action_space.sample() for _ in range(n_batch)])
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
2.1.0a2
|
||||
2.1.0a3
|
||||
|
|
|
|||
|
|
@ -1,63 +1,67 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional
|
||||
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
import pytest
|
||||
from gymnasium import spaces
|
||||
from gymnasium.spaces.space import Space
|
||||
|
||||
from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3
|
||||
from stable_baselines3.common.env_checker import check_env
|
||||
from stable_baselines3.common.env_util import make_vec_env
|
||||
from stable_baselines3.common.evaluation import evaluate_policy
|
||||
|
||||
BOX_SPACE_FLOAT64 = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float64)
|
||||
BOX_SPACE_FLOAT32 = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
|
||||
|
||||
class DummyMultiDiscreteSpace(gym.Env):
|
||||
def __init__(self, nvec):
|
||||
super().__init__()
|
||||
self.observation_space = spaces.MultiDiscrete(nvec)
|
||||
self.action_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
|
||||
|
||||
@dataclass
class DummyEnv(gym.Env):
    """
    Minimal environment used for testing: observations are sampled
    from the observation space and the reward is always zero.
    """

    # Space the observations are sampled from
    observation_space: Space
    # Space of the actions (the action itself is ignored by ``step``)
    action_space: Space

    def step(self, action):
        """
        Ignore the action and return a freshly sampled observation.

        :param action: Unused
        :return: observation, reward, terminated, truncated, info
        """
        observation = self.observation_space.sample()
        reward = 0.0
        terminated = truncated = False
        return observation, reward, terminated, truncated, {}

    def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None):
        """
        Sample an initial observation.

        :param seed: When given, it is forwarded to the parent ``reset()``
        :param options: Unused
        :return: observation, info
        """
        # The parent reset is only invoked when a seed is explicitly provided
        if seed is not None:
            super().reset(seed=seed)
        observation = self.observation_space.sample()
        return observation, {}
|
||||
|
||||
def step(self, action):
|
||||
return self.observation_space.sample(), 0.0, False, False, {}
|
||||
|
||||
|
||||
class DummyMultiBinary(gym.Env):
|
||||
def __init__(self, n):
|
||||
super().__init__()
|
||||
self.observation_space = spaces.MultiBinary(n)
|
||||
self.action_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
|
||||
|
||||
def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None):
|
||||
if seed is not None:
|
||||
super().reset(seed=seed)
|
||||
return self.observation_space.sample(), {}
|
||||
|
||||
def step(self, action):
|
||||
return self.observation_space.sample(), 0.0, False, False, {}
|
||||
|
||||
|
||||
class DummyMultidimensionalAction(gym.Env):
|
||||
class DummyMultidimensionalAction(DummyEnv):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.observation_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
|
||||
self.action_space = spaces.Box(low=-1, high=1, shape=(2, 2), dtype=np.float32)
|
||||
super().__init__(
|
||||
BOX_SPACE_FLOAT32,
|
||||
spaces.Box(low=-1, high=1, shape=(2, 2), dtype=np.float32),
|
||||
)
|
||||
|
||||
def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None):
|
||||
if seed is not None:
|
||||
super().reset(seed=seed)
|
||||
return self.observation_space.sample(), {}
|
||||
|
||||
def step(self, action):
|
||||
return self.observation_space.sample(), 0.0, False, False, {}
|
||||
class DummyMultiBinary(DummyEnv):
    """
    Test env with a ``MultiBinary`` observation space
    and a float32 ``Box`` action space.

    :param n: Argument forwarded to ``spaces.MultiBinary``
    """

    def __init__(self, n):
        binary_obs_space = spaces.MultiBinary(n)
        super().__init__(
            observation_space=binary_obs_space,
            action_space=BOX_SPACE_FLOAT32,
        )
|
||||
|
||||
|
||||
class DummyMultiDiscreteSpace(DummyEnv):
    """
    Test env with a ``MultiDiscrete`` observation space
    and a float32 ``Box`` action space.

    :param nvec: Argument forwarded to ``spaces.MultiDiscrete``
    """

    def __init__(self, nvec):
        multi_discrete_obs_space = spaces.MultiDiscrete(nvec)
        super().__init__(
            observation_space=multi_discrete_obs_space,
            action_space=BOX_SPACE_FLOAT32,
        )
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"env", [DummyMultiDiscreteSpace([4, 3]), DummyMultiBinary(8), DummyMultiBinary((3, 2)), DummyMultidimensionalAction()]
|
||||
"env",
|
||||
[
|
||||
DummyMultiDiscreteSpace([4, 3]),
|
||||
DummyMultiBinary(8),
|
||||
DummyMultiBinary((3, 2)),
|
||||
DummyMultidimensionalAction(),
|
||||
],
|
||||
)
|
||||
def test_env(env):
|
||||
# Check the env used for testing
|
||||
|
|
@ -127,3 +131,40 @@ def test_discrete_obs_space(model_class, env):
|
|||
else:
|
||||
kwargs = dict(n_steps=256)
|
||||
model_class("MlpPolicy", env, **kwargs).learn(256)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", [SAC, TD3, PPO, DDPG, A2C])
@pytest.mark.parametrize(
    "obs_space",
    [
        BOX_SPACE_FLOAT32,
        BOX_SPACE_FLOAT64,
        spaces.Dict({"a": BOX_SPACE_FLOAT32, "b": BOX_SPACE_FLOAT32}),
        spaces.Dict({"a": BOX_SPACE_FLOAT32, "b": BOX_SPACE_FLOAT64}),
    ],
)
@pytest.mark.parametrize("action_space", [BOX_SPACE_FLOAT32, BOX_SPACE_FLOAT64])
def test_float64_action_space(model_class, obs_space, action_space):
    """
    Train each algorithm briefly on float32/float64 action spaces and check
    that the predicted action has the same dtype as the env action space.
    """
    env = gym.wrappers.TimeLimit(DummyEnv(obs_space, action_space), max_episode_steps=200)

    # Dict observations require the multi-input policy
    uses_dict_obs = isinstance(env.observation_space, spaces.Dict)
    policy = "MultiInputPolicy" if uses_dict_obs else "MlpPolicy"

    # Same small network for every algorithm, plus one algorithm-specific setting
    kwargs = {"policy_kwargs": {"net_arch": [12]}}
    if model_class in (PPO, A2C):
        kwargs["n_steps"] = 64
    else:
        kwargs["learning_starts"] = 60

    model = model_class(policy, env, **kwargs)
    model.learn(64)

    initial_obs, _ = env.reset()
    action, _ = model.predict(initial_obs, deterministic=False)
    assert action.dtype == env.action_space.dtype
|
||||
|
|
|
|||
Loading…
Reference in a new issue