Add the argument dtype (default to float32) to the noise (#1301)

* Fixed noise to return float32

* Updated changelog

* Fixed test to use numpy arrays instead of python floats

* Sorted imports for tests

* Added dtype to constructor

* Removed dtype parameter for VectorizedActionNoise

* __init__ -> None; Capitalize and period in docstring when needed; fix dtype type hint; dtype in docstring

* fix dtype type hint

* Update version

* Clarify changelog [skip ci]

* empty commit to run ci

* Update docs/misc/changelog.rst

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
This commit is contained in:
Sidney Tio 2023-02-07 20:42:14 +08:00 committed by GitHub
parent 2e4a45020e
commit 489b1fdaf2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 30 additions and 20 deletions

View file

@ -27,6 +27,7 @@ New Features:
Bug Fixes:
^^^^^^^^^^
- Fixed Atari wrapper that missed the reset condition (@luizapozzobon)
- Added the argument ``dtype`` (default to ``float32``) to the noise for consistency with gym action (@sidney-tio)
- Fixed PPO train/n_updates metric not accounting for early stopping (@adamfrly)
Deprecations:

View file

@ -3,6 +3,7 @@ from abc import ABC, abstractmethod
from typing import Iterable, List, Optional
import numpy as np
from numpy.typing import DTypeLike
class ActionNoise(ABC):
@ -15,7 +16,7 @@ class ActionNoise(ABC):
def reset(self) -> None:
"""
call end of episode reset for the noise
Call end of episode reset for the noise
"""
pass
@ -26,19 +27,21 @@ class ActionNoise(ABC):
class NormalActionNoise(ActionNoise):
"""
A Gaussian action noise
A Gaussian action noise.
:param mean: the mean value of the noise
:param sigma: the scale of the noise (std here)
:param mean: Mean value of the noise
:param sigma: Scale of the noise (std here)
:param dtype: Type of the output noise
"""
def __init__(self, mean: np.ndarray, sigma: np.ndarray):
def __init__(self, mean: np.ndarray, sigma: np.ndarray, dtype: DTypeLike = np.float32) -> None:
self._mu = mean
self._sigma = sigma
self._dtype = dtype
super().__init__()
def __call__(self) -> np.ndarray:
return np.random.normal(self._mu, self._sigma)
return np.random.normal(self._mu, self._sigma).astype(self._dtype)
def __repr__(self) -> str:
return f"NormalActionNoise(mu={self._mu}, sigma={self._sigma})"
@ -50,11 +53,12 @@ class OrnsteinUhlenbeckActionNoise(ActionNoise):
Based on http://math.stackexchange.com/questions/1287634/implementing-ornstein-uhlenbeck-in-matlab
:param mean: the mean of the noise
:param sigma: the scale of the noise
:param theta: the rate of mean reversion
:param dt: the timestep for the noise
:param initial_noise: the initial value for the noise output, (if None: 0)
:param mean: Mean of the noise
:param sigma: Scale of the noise
:param theta: Rate of mean reversion
:param dt: Timestep for the noise
:param initial_noise: Initial value for the noise output, (if None: 0)
:param dtype: Type of the output noise
"""
def __init__(
@ -64,11 +68,13 @@ class OrnsteinUhlenbeckActionNoise(ActionNoise):
theta: float = 0.15,
dt: float = 1e-2,
initial_noise: Optional[np.ndarray] = None,
):
dtype: DTypeLike = np.float32,
) -> None:
self._theta = theta
self._mu = mean
self._sigma = sigma
self._dt = dt
self._dtype = dtype
self.initial_noise = initial_noise
self.noise_prev = np.zeros_like(self._mu)
self.reset()
@ -81,7 +87,7 @@ class OrnsteinUhlenbeckActionNoise(ActionNoise):
+ self._sigma * np.sqrt(self._dt) * np.random.normal(size=self._mu.shape)
)
self.noise_prev = noise
return noise
return noise.astype(self._dtype)
def reset(self) -> None:
"""
@ -97,11 +103,11 @@ class VectorizedActionNoise(ActionNoise):
"""
A Vectorized action noise for parallel environments.
:param base_noise: ActionNoise The noise generator to use
:param n_envs: The number of parallel environments
:param base_noise: Noise generator to use
:param n_envs: Number of parallel environments
"""
def __init__(self, base_noise: ActionNoise, n_envs: int):
def __init__(self, base_noise: ActionNoise, n_envs: int) -> None:
try:
self.n_envs = int(n_envs)
assert self.n_envs > 0
@ -113,9 +119,9 @@ class VectorizedActionNoise(ActionNoise):
def reset(self, indices: Optional[Iterable[int]] = None) -> None:
"""
Reset all the noise processes, or those listed in indices
Reset all the noise processes, or those listed in indices.
:param indices: Optional[Iterable[int]] The indices to reset. Default: None.
:param indices: The indices to reset. Default: None.
If the parameter is None, then all processes are reset to their initial position.
"""
if indices is None:
@ -129,7 +135,7 @@ class VectorizedActionNoise(ActionNoise):
def __call__(self) -> np.ndarray:
"""
Generate and stack the action noise from each noise object
Generate and stack the action noise from each noise object.
"""
noise = np.stack([noise() for noise in self.noises])
return noise

View file

@ -1,3 +1,4 @@
import numpy as np
import pytest
from stable_baselines3 import A2C, DQN, PPO, SAC, TD3
@ -15,7 +16,9 @@ def test_deterministic_training_common(algo):
kwargs = {"policy_kwargs": dict(net_arch=[64])}
env_id = "Pendulum-v1"
if algo in [TD3, SAC]:
kwargs.update({"action_noise": NormalActionNoise(0.0, 0.1), "learning_starts": 100, "train_freq": 4})
kwargs.update(
{"action_noise": NormalActionNoise(np.zeros(1), 0.1 * np.ones(1)), "learning_starts": 100, "train_freq": 4}
)
else:
if algo == DQN:
env_id = "CartPole-v1"