Add a dtype argument (defaulting to float32) to the action noise classes (#1301)

* Fixed noise to return float32

* Updated changelog

* Fixed test to use NumPy arrays instead of Python floats

* Sorted imports for tests

* Added dtype to constructor

* Removed dtype parameter for VectorizedActionNoise

* __init__ -> None; Capitalize and period in docstring when needed; fix dtype type hint; dtype in docstring

* fix dtype type hint

* Update version

* Clarify changelog [skip ci]

* empty commit to run ci

* Update docs/misc/changelog.rst

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
This commit is contained in:
Sidney Tio 2023-02-07 20:42:14 +08:00 committed by GitHub
parent 2e4a45020e
commit 489b1fdaf2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 30 additions and 20 deletions

View file

@ -27,6 +27,7 @@ New Features:
Bug Fixes: Bug Fixes:
^^^^^^^^^^ ^^^^^^^^^^
- Fixed Atari wrapper that missed the reset condition (@luizapozzobon) - Fixed Atari wrapper that missed the reset condition (@luizapozzobon)
- Added the argument ``dtype`` (defaults to ``float32``) to the action noise classes for consistency with gym actions (@sidney-tio)
- Fixed PPO train/n_updates metric not accounting for early stopping (@adamfrly) - Fixed PPO train/n_updates metric not accounting for early stopping (@adamfrly)
Deprecations: Deprecations:

View file

@ -3,6 +3,7 @@ from abc import ABC, abstractmethod
from typing import Iterable, List, Optional from typing import Iterable, List, Optional
import numpy as np import numpy as np
from numpy.typing import DTypeLike
class ActionNoise(ABC): class ActionNoise(ABC):
@ -15,7 +16,7 @@ class ActionNoise(ABC):
def reset(self) -> None: def reset(self) -> None:
""" """
call end of episode reset for the noise Call end of episode reset for the noise
""" """
pass pass
@ -26,19 +27,21 @@ class ActionNoise(ABC):
class NormalActionNoise(ActionNoise): class NormalActionNoise(ActionNoise):
""" """
A Gaussian action noise A Gaussian action noise.
:param mean: the mean value of the noise :param mean: Mean value of the noise
:param sigma: the scale of the noise (std here) :param sigma: Scale of the noise (std here)
:param dtype: Type of the output noise
""" """
def __init__(self, mean: np.ndarray, sigma: np.ndarray): def __init__(self, mean: np.ndarray, sigma: np.ndarray, dtype: DTypeLike = np.float32) -> None:
self._mu = mean self._mu = mean
self._sigma = sigma self._sigma = sigma
self._dtype = dtype
super().__init__() super().__init__()
def __call__(self) -> np.ndarray: def __call__(self) -> np.ndarray:
return np.random.normal(self._mu, self._sigma) return np.random.normal(self._mu, self._sigma).astype(self._dtype)
def __repr__(self) -> str: def __repr__(self) -> str:
return f"NormalActionNoise(mu={self._mu}, sigma={self._sigma})" return f"NormalActionNoise(mu={self._mu}, sigma={self._sigma})"
@ -50,11 +53,12 @@ class OrnsteinUhlenbeckActionNoise(ActionNoise):
Based on http://math.stackexchange.com/questions/1287634/implementing-ornstein-uhlenbeck-in-matlab Based on http://math.stackexchange.com/questions/1287634/implementing-ornstein-uhlenbeck-in-matlab
:param mean: the mean of the noise :param mean: Mean of the noise
:param sigma: the scale of the noise :param sigma: Scale of the noise
:param theta: the rate of mean reversion :param theta: Rate of mean reversion
:param dt: the timestep for the noise :param dt: Timestep for the noise
:param initial_noise: the initial value for the noise output, (if None: 0) :param initial_noise: Initial value for the noise output, (if None: 0)
:param dtype: Type of the output noise
""" """
def __init__( def __init__(
@ -64,11 +68,13 @@ class OrnsteinUhlenbeckActionNoise(ActionNoise):
theta: float = 0.15, theta: float = 0.15,
dt: float = 1e-2, dt: float = 1e-2,
initial_noise: Optional[np.ndarray] = None, initial_noise: Optional[np.ndarray] = None,
): dtype: DTypeLike = np.float32,
) -> None:
self._theta = theta self._theta = theta
self._mu = mean self._mu = mean
self._sigma = sigma self._sigma = sigma
self._dt = dt self._dt = dt
self._dtype = dtype
self.initial_noise = initial_noise self.initial_noise = initial_noise
self.noise_prev = np.zeros_like(self._mu) self.noise_prev = np.zeros_like(self._mu)
self.reset() self.reset()
@ -81,7 +87,7 @@ class OrnsteinUhlenbeckActionNoise(ActionNoise):
+ self._sigma * np.sqrt(self._dt) * np.random.normal(size=self._mu.shape) + self._sigma * np.sqrt(self._dt) * np.random.normal(size=self._mu.shape)
) )
self.noise_prev = noise self.noise_prev = noise
return noise return noise.astype(self._dtype)
def reset(self) -> None: def reset(self) -> None:
""" """
@ -97,11 +103,11 @@ class VectorizedActionNoise(ActionNoise):
""" """
A Vectorized action noise for parallel environments. A Vectorized action noise for parallel environments.
:param base_noise: ActionNoise The noise generator to use :param base_noise: Noise generator to use
:param n_envs: The number of parallel environments :param n_envs: Number of parallel environments
""" """
def __init__(self, base_noise: ActionNoise, n_envs: int): def __init__(self, base_noise: ActionNoise, n_envs: int) -> None:
try: try:
self.n_envs = int(n_envs) self.n_envs = int(n_envs)
assert self.n_envs > 0 assert self.n_envs > 0
@ -113,9 +119,9 @@ class VectorizedActionNoise(ActionNoise):
def reset(self, indices: Optional[Iterable[int]] = None) -> None: def reset(self, indices: Optional[Iterable[int]] = None) -> None:
""" """
Reset all the noise processes, or those listed in indices Reset all the noise processes, or those listed in indices.
:param indices: Optional[Iterable[int]] The indices to reset. Default: None. :param indices: The indices to reset. Default: None.
If the parameter is None, then all processes are reset to their initial position. If the parameter is None, then all processes are reset to their initial position.
""" """
if indices is None: if indices is None:
@ -129,7 +135,7 @@ class VectorizedActionNoise(ActionNoise):
def __call__(self) -> np.ndarray: def __call__(self) -> np.ndarray:
""" """
Generate and stack the action noise from each noise object Generate and stack the action noise from each noise object.
""" """
noise = np.stack([noise() for noise in self.noises]) noise = np.stack([noise() for noise in self.noises])
return noise return noise

View file

@ -1,3 +1,4 @@
import numpy as np
import pytest import pytest
from stable_baselines3 import A2C, DQN, PPO, SAC, TD3 from stable_baselines3 import A2C, DQN, PPO, SAC, TD3
@ -15,7 +16,9 @@ def test_deterministic_training_common(algo):
kwargs = {"policy_kwargs": dict(net_arch=[64])} kwargs = {"policy_kwargs": dict(net_arch=[64])}
env_id = "Pendulum-v1" env_id = "Pendulum-v1"
if algo in [TD3, SAC]: if algo in [TD3, SAC]:
kwargs.update({"action_noise": NormalActionNoise(0.0, 0.1), "learning_starts": 100, "train_freq": 4}) kwargs.update(
{"action_noise": NormalActionNoise(np.zeros(1), 0.1 * np.ones(1)), "learning_starts": 100, "train_freq": 4}
)
else: else:
if algo == DQN: if algo == DQN:
env_id = "CartPole-v1" env_id = "CartPole-v1"