Fix to use float64 actions for off policy algorithms (#1572)

* Added test cases where off policy algorithms fail with float64 actionspace

* casting observations and actions to `np.float32` to unify behaviour between `ReplayBuffer` and `RolloutBuffer`. Fixing issue #1145

* reformatted using black

* making test more restrictive by checking models action is float64

* added changelog entry

* undo cast of observations as `preprocessing.preprocess_obs()` casts them to float32 anyways.

* - Casting to float32 only, if action.dtype is float64
- Added cast to `DictReplayBuffer` as well

* Added tests for multiple variations of continuous action types and observation spaces

* applied reformatting by `make commit-checks`

* Added typing and comment referring to description in merge request

* Apply linter for single element slice

* Rename helper and refactor tests

* Update changelog and docstring

---------

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
This commit is contained in:
Tobias Rohrer 2023-07-24 16:38:03 +02:00 committed by GitHub
parent 72c124d907
commit ba77dd7c61
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 101 additions and 39 deletions

View file

@ -4,7 +4,7 @@ Changelog
==========
Release 2.1.0a2 (WIP)
Release 2.1.0a3 (WIP)
--------------------------
Breaking Changes:
@ -26,6 +26,7 @@ New Features:
Bug Fixes:
^^^^^^^^^^
- Relaxed check in logger, that was causing issue on Windows with colorama
- Fixed off-policy algorithms with continuous float64 actions (see #1145) (@tobirohrer)
Deprecations:
^^^^^^^^^^^^^
@ -34,6 +35,7 @@ Others:
^^^^^^^
- Updated GitHub issue templates
- Fix typo in gym patch error message (@lukashass)
- Refactor ``test_spaces.py`` tests
Documentation:
^^^^^^^^^^^^^^

View file

@ -207,7 +207,9 @@ class ReplayBuffer(BaseBuffer):
else:
self.next_observations = np.zeros((self.buffer_size, self.n_envs, *self.obs_shape), dtype=observation_space.dtype)
self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=action_space.dtype)
self.actions = np.zeros(
(self.buffer_size, self.n_envs, self.action_dim), dtype=self._maybe_cast_dtype(action_space.dtype)
)
self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
self.dones = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
@ -311,6 +313,21 @@ class ReplayBuffer(BaseBuffer):
)
return ReplayBufferSamples(*tuple(map(self.to_torch, data)))
@staticmethod
def _maybe_cast_dtype(dtype: np.typing.DTypeLike) -> np.typing.DTypeLike:
"""
Cast `np.float64` action datatype to `np.float32`,
keep the others dtype unchanged.
See GH#1572 for more information.
:param dtype: The original action space dtype
:return: ``np.float32`` if the dtype was float64,
the original dtype otherwise.
"""
if dtype == np.float64:
return np.float32
return dtype
class RolloutBuffer(BaseBuffer):
"""
@ -543,7 +560,9 @@ class DictReplayBuffer(ReplayBuffer):
for key, _obs_shape in self.obs_shape.items()
}
self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=action_space.dtype)
self.actions = np.zeros(
(self.buffer_size, self.n_envs, self.action_dim), dtype=self._maybe_cast_dtype(action_space.dtype)
)
self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
self.dones = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)

View file

@ -245,7 +245,7 @@ class DQN(OffPolicyAlgorithm):
if not deterministic and np.random.rand() < self.exploration_rate:
if self.policy.is_vectorized_observation(observation):
if isinstance(observation, dict):
n_batch = observation[list(observation.keys())[0]].shape[0]
n_batch = observation[next(iter(observation.keys()))].shape[0]
else:
n_batch = observation.shape[0]
action = np.array([self.action_space.sample() for _ in range(n_batch)])

View file

@ -1 +1 @@
2.1.0a2
2.1.0a3

View file

@ -1,63 +1,67 @@
from dataclasses import dataclass
from typing import Dict, Optional
import gymnasium as gym
import numpy as np
import pytest
from gymnasium import spaces
from gymnasium.spaces.space import Space
from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.evaluation import evaluate_policy
BOX_SPACE_FLOAT64 = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float64)
BOX_SPACE_FLOAT32 = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
class DummyMultiDiscreteSpace(gym.Env):
def __init__(self, nvec):
super().__init__()
self.observation_space = spaces.MultiDiscrete(nvec)
self.action_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
@dataclass
class DummyEnv(gym.Env):
observation_space: Space
action_space: Space
def step(self, action):
return self.observation_space.sample(), 0.0, False, False, {}
def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None):
if seed is not None:
super().reset(seed=seed)
return self.observation_space.sample(), {}
def step(self, action):
return self.observation_space.sample(), 0.0, False, False, {}
class DummyMultiBinary(gym.Env):
def __init__(self, n):
super().__init__()
self.observation_space = spaces.MultiBinary(n)
self.action_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None):
if seed is not None:
super().reset(seed=seed)
return self.observation_space.sample(), {}
def step(self, action):
return self.observation_space.sample(), 0.0, False, False, {}
class DummyMultidimensionalAction(gym.Env):
class DummyMultidimensionalAction(DummyEnv):
def __init__(self):
super().__init__()
self.observation_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
self.action_space = spaces.Box(low=-1, high=1, shape=(2, 2), dtype=np.float32)
super().__init__(
BOX_SPACE_FLOAT32,
spaces.Box(low=-1, high=1, shape=(2, 2), dtype=np.float32),
)
def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None):
if seed is not None:
super().reset(seed=seed)
return self.observation_space.sample(), {}
def step(self, action):
return self.observation_space.sample(), 0.0, False, False, {}
class DummyMultiBinary(DummyEnv):
def __init__(self, n):
super().__init__(
spaces.MultiBinary(n),
BOX_SPACE_FLOAT32,
)
class DummyMultiDiscreteSpace(DummyEnv):
def __init__(self, nvec):
super().__init__(
spaces.MultiDiscrete(nvec),
BOX_SPACE_FLOAT32,
)
@pytest.mark.parametrize(
"env", [DummyMultiDiscreteSpace([4, 3]), DummyMultiBinary(8), DummyMultiBinary((3, 2)), DummyMultidimensionalAction()]
"env",
[
DummyMultiDiscreteSpace([4, 3]),
DummyMultiBinary(8),
DummyMultiBinary((3, 2)),
DummyMultidimensionalAction(),
],
)
def test_env(env):
# Check the env used for testing
@ -127,3 +131,40 @@ def test_discrete_obs_space(model_class, env):
else:
kwargs = dict(n_steps=256)
model_class("MlpPolicy", env, **kwargs).learn(256)
@pytest.mark.parametrize("model_class", [SAC, TD3, PPO, DDPG, A2C])
@pytest.mark.parametrize(
"obs_space",
[
BOX_SPACE_FLOAT32,
BOX_SPACE_FLOAT64,
spaces.Dict({"a": BOX_SPACE_FLOAT32, "b": BOX_SPACE_FLOAT32}),
spaces.Dict({"a": BOX_SPACE_FLOAT32, "b": BOX_SPACE_FLOAT64}),
],
)
@pytest.mark.parametrize(
"action_space",
[
BOX_SPACE_FLOAT32,
BOX_SPACE_FLOAT64,
],
)
def test_float64_action_space(model_class, obs_space, action_space):
env = DummyEnv(obs_space, action_space)
env = gym.wrappers.TimeLimit(env, max_episode_steps=200)
if isinstance(env.observation_space, spaces.Dict):
policy = "MultiInputPolicy"
else:
policy = "MlpPolicy"
if model_class in [PPO, A2C]:
kwargs = dict(n_steps=64, policy_kwargs=dict(net_arch=[12]))
else:
kwargs = dict(learning_starts=60, policy_kwargs=dict(net_arch=[12]))
model = model_class(policy, env, **kwargs)
model.learn(64)
initial_obs, _ = env.reset()
action, _ = model.predict(initial_obs, deterministic=False)
assert action.dtype == env.action_space.dtype