diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 9d24bbe..6fafd99 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -27,6 +27,7 @@ New Features: Bug Fixes: ^^^^^^^^^^ - Fixed Atari wrapper that missed the reset condition (@luizapozzobon) +- Added the argument ``dtype`` (default to ``float32``) to the noise for consistency with gym action (@sidney-tio) - Fixed PPO train/n_updates metric not accounting for early stopping (@adamfrly) Deprecations: diff --git a/stable_baselines3/common/noise.py b/stable_baselines3/common/noise.py index 5e8632d..944408f 100644 --- a/stable_baselines3/common/noise.py +++ b/stable_baselines3/common/noise.py @@ -3,6 +3,7 @@ from abc import ABC, abstractmethod from typing import Iterable, List, Optional import numpy as np +from numpy.typing import DTypeLike class ActionNoise(ABC): @@ -15,7 +16,7 @@ class ActionNoise(ABC): def reset(self) -> None: """ - call end of episode reset for the noise + Call end of episode reset for the noise """ pass @@ -26,19 +27,21 @@ class ActionNoise(ABC): class NormalActionNoise(ActionNoise): """ - A Gaussian action noise + A Gaussian action noise. - :param mean: the mean value of the noise - :param sigma: the scale of the noise (std here) + :param mean: Mean value of the noise + :param sigma: Scale of the noise (std here) + :param dtype: Type of the output noise """ - def __init__(self, mean: np.ndarray, sigma: np.ndarray): + def __init__(self, mean: np.ndarray, sigma: np.ndarray, dtype: DTypeLike = np.float32) -> None: self._mu = mean self._sigma = sigma + self._dtype = dtype super().__init__() def __call__(self) -> np.ndarray: - return np.random.normal(self._mu, self._sigma) + return np.random.normal(self._mu, self._sigma).astype(self._dtype) def __repr__(self) -> str: return f"NormalActionNoise(mu={self._mu}, sigma={self._sigma})" @@ -50,11 +53,12 @@ class OrnsteinUhlenbeckActionNoise(ActionNoise): Based on http://math.stackexchange.com/questions/1287634/implementing-ornstein-uhlenbeck-in-matlab - :param mean: the mean of the noise - :param sigma: the scale of the noise - :param theta: the rate of mean reversion - :param dt: the timestep for the noise - :param initial_noise: the initial value for the noise output, (if None: 0) + :param mean: Mean of the noise + :param sigma: Scale of the noise + :param theta: Rate of mean reversion + :param dt: Timestep for the noise + :param initial_noise: Initial value for the noise output, (if None: 0) + :param dtype: Type of the output noise """ def __init__( @@ -64,11 +68,13 @@ class OrnsteinUhlenbeckActionNoise(ActionNoise): theta: float = 0.15, dt: float = 1e-2, initial_noise: Optional[np.ndarray] = None, - ): + dtype: DTypeLike = np.float32, + ) -> None: self._theta = theta self._mu = mean self._sigma = sigma self._dt = dt + self._dtype = dtype self.initial_noise = initial_noise self.noise_prev = np.zeros_like(self._mu) self.reset() @@ -81,7 +87,7 @@ class OrnsteinUhlenbeckActionNoise(ActionNoise): + self._sigma * np.sqrt(self._dt) * np.random.normal(size=self._mu.shape) ) self.noise_prev = noise - return noise + return noise.astype(self._dtype) def reset(self) -> None: """ @@ -97,11 +103,11 @@ class VectorizedActionNoise(ActionNoise): """ A Vectorized action noise for parallel environments. - :param base_noise: ActionNoise The noise generator to use - :param n_envs: The number of parallel environments + :param base_noise: Noise generator to use + :param n_envs: Number of parallel environments """ - def __init__(self, base_noise: ActionNoise, n_envs: int): + def __init__(self, base_noise: ActionNoise, n_envs: int) -> None: try: self.n_envs = int(n_envs) assert self.n_envs > 0 @@ -113,9 +119,9 @@ class VectorizedActionNoise(ActionNoise): def reset(self, indices: Optional[Iterable[int]] = None) -> None: """ - Reset all the noise processes, or those listed in indices + Reset all the noise processes, or those listed in indices. - :param indices: Optional[Iterable[int]] The indices to reset. Default: None. + :param indices: The indices to reset. Default: None. If the parameter is None, then all processes are reset to their initial position. """ if indices is None: @@ -129,7 +135,7 @@ class VectorizedActionNoise(ActionNoise): def __call__(self) -> np.ndarray: """ - Generate and stack the action noise from each noise object + Generate and stack the action noise from each noise object. """ noise = np.stack([noise() for noise in self.noises]) return noise diff --git a/tests/test_deterministic.py b/tests/test_deterministic.py index 4c92d26..c165e48 100644 --- a/tests/test_deterministic.py +++ b/tests/test_deterministic.py @@ -1,3 +1,4 @@ +import numpy as np import pytest from stable_baselines3 import A2C, DQN, PPO, SAC, TD3 @@ -15,7 +16,9 @@ def test_deterministic_training_common(algo): kwargs = {"policy_kwargs": dict(net_arch=[64])} env_id = "Pendulum-v1" if algo in [TD3, SAC]: - kwargs.update({"action_noise": NormalActionNoise(0.0, 0.1), "learning_starts": 100, "train_freq": 4}) + kwargs.update( + {"action_noise": NormalActionNoise(np.zeros(1), 0.1 * np.ones(1)), "learning_starts": 100, "train_freq": 4} + ) else: if algo == DQN: env_id = "CartPole-v1"