mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-14 20:58:03 +00:00
Add the argument dtype (default to float32) to the noise (#1301)
* Fixed noise to return float32
* Updated changelog
* Fixed test to use numpy arrays instead of python floats
* Sorted imports for tests
* Added dtype to constructor
* Removed dtype parameter for VectorizedActionNoise
* __init__ -> None; Capitalize and period in docstring when needed; fix dtype type hint; dtype in docstring
* fix dtype type hint
* Update version
* Clarify changelog [skip ci]
* empty commit to run ci
* Update docs/misc/changelog.rst

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
This commit is contained in:
parent
2e4a45020e
commit
489b1fdaf2
3 changed files with 30 additions and 20 deletions
|
|
@ -27,6 +27,7 @@ New Features:
|
|||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
- Fixed Atari wrapper that missed the reset condition (@luizapozzobon)
|
||||
- Added the ``dtype`` argument (defaults to ``float32``) to ``NormalActionNoise`` and ``OrnsteinUhlenbeckActionNoise`` for consistency with gym actions (@sidney-tio)
|
||||
- Fixed PPO train/n_updates metric not accounting for early stopping (@adamfrly)
|
||||
|
||||
Deprecations:
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ from abc import ABC, abstractmethod
|
|||
from typing import Iterable, List, Optional
|
||||
|
||||
import numpy as np
|
||||
from numpy.typing import DTypeLike
|
||||
|
||||
|
||||
class ActionNoise(ABC):
|
||||
|
|
@ -15,7 +16,7 @@ class ActionNoise(ABC):
|
|||
|
||||
def reset(self) -> None:
|
||||
"""
|
||||
call end of episode reset for the noise
|
||||
Call end of episode reset for the noise
|
||||
"""
|
||||
pass
|
||||
|
||||
|
|
@ -26,19 +27,21 @@ class ActionNoise(ABC):
|
|||
|
||||
class NormalActionNoise(ActionNoise):
|
||||
"""
|
||||
A Gaussian action noise
|
||||
A Gaussian action noise.
|
||||
|
||||
:param mean: the mean value of the noise
|
||||
:param sigma: the scale of the noise (std here)
|
||||
:param mean: Mean value of the noise
|
||||
:param sigma: Scale of the noise (std here)
|
||||
:param dtype: Type of the output noise
|
||||
"""
|
||||
|
||||
def __init__(self, mean: np.ndarray, sigma: np.ndarray):
|
||||
def __init__(self, mean: np.ndarray, sigma: np.ndarray, dtype: DTypeLike = np.float32) -> None:
|
||||
self._mu = mean
|
||||
self._sigma = sigma
|
||||
self._dtype = dtype
|
||||
super().__init__()
|
||||
|
||||
def __call__(self) -> np.ndarray:
|
||||
return np.random.normal(self._mu, self._sigma)
|
||||
return np.random.normal(self._mu, self._sigma).astype(self._dtype)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"NormalActionNoise(mu={self._mu}, sigma={self._sigma})"
|
||||
|
|
@ -50,11 +53,12 @@ class OrnsteinUhlenbeckActionNoise(ActionNoise):
|
|||
|
||||
Based on http://math.stackexchange.com/questions/1287634/implementing-ornstein-uhlenbeck-in-matlab
|
||||
|
||||
:param mean: the mean of the noise
|
||||
:param sigma: the scale of the noise
|
||||
:param theta: the rate of mean reversion
|
||||
:param dt: the timestep for the noise
|
||||
:param initial_noise: the initial value for the noise output, (if None: 0)
|
||||
:param mean: Mean of the noise
|
||||
:param sigma: Scale of the noise
|
||||
:param theta: Rate of mean reversion
|
||||
:param dt: Timestep for the noise
|
||||
:param initial_noise: Initial value for the noise output (if None: 0)
|
||||
:param dtype: Type of the output noise
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
|
@ -64,11 +68,13 @@ class OrnsteinUhlenbeckActionNoise(ActionNoise):
|
|||
theta: float = 0.15,
|
||||
dt: float = 1e-2,
|
||||
initial_noise: Optional[np.ndarray] = None,
|
||||
):
|
||||
dtype: DTypeLike = np.float32,
|
||||
) -> None:
|
||||
self._theta = theta
|
||||
self._mu = mean
|
||||
self._sigma = sigma
|
||||
self._dt = dt
|
||||
self._dtype = dtype
|
||||
self.initial_noise = initial_noise
|
||||
self.noise_prev = np.zeros_like(self._mu)
|
||||
self.reset()
|
||||
|
|
@ -81,7 +87,7 @@ class OrnsteinUhlenbeckActionNoise(ActionNoise):
|
|||
+ self._sigma * np.sqrt(self._dt) * np.random.normal(size=self._mu.shape)
|
||||
)
|
||||
self.noise_prev = noise
|
||||
return noise
|
||||
return noise.astype(self._dtype)
|
||||
|
||||
def reset(self) -> None:
|
||||
"""
|
||||
|
|
@ -97,11 +103,11 @@ class VectorizedActionNoise(ActionNoise):
|
|||
"""
|
||||
A Vectorized action noise for parallel environments.
|
||||
|
||||
:param base_noise: ActionNoise The noise generator to use
|
||||
:param n_envs: The number of parallel environments
|
||||
:param base_noise: Noise generator to use
|
||||
:param n_envs: Number of parallel environments
|
||||
"""
|
||||
|
||||
def __init__(self, base_noise: ActionNoise, n_envs: int):
|
||||
def __init__(self, base_noise: ActionNoise, n_envs: int) -> None:
|
||||
try:
|
||||
self.n_envs = int(n_envs)
|
||||
assert self.n_envs > 0
|
||||
|
|
@ -113,9 +119,9 @@ class VectorizedActionNoise(ActionNoise):
|
|||
|
||||
def reset(self, indices: Optional[Iterable[int]] = None) -> None:
|
||||
"""
|
||||
Reset all the noise processes, or those listed in indices
|
||||
Reset all the noise processes, or those listed in indices.
|
||||
|
||||
:param indices: Optional[Iterable[int]] The indices to reset. Default: None.
|
||||
:param indices: The indices to reset. Default: None.
|
||||
If the parameter is None, then all processes are reset to their initial position.
|
||||
"""
|
||||
if indices is None:
|
||||
|
|
@ -129,7 +135,7 @@ class VectorizedActionNoise(ActionNoise):
|
|||
|
||||
def __call__(self) -> np.ndarray:
|
||||
"""
|
||||
Generate and stack the action noise from each noise object
|
||||
Generate and stack the action noise from each noise object.
|
||||
"""
|
||||
noise = np.stack([noise() for noise in self.noises])
|
||||
return noise
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from stable_baselines3 import A2C, DQN, PPO, SAC, TD3
|
||||
|
|
@ -15,7 +16,9 @@ def test_deterministic_training_common(algo):
|
|||
kwargs = {"policy_kwargs": dict(net_arch=[64])}
|
||||
env_id = "Pendulum-v1"
|
||||
if algo in [TD3, SAC]:
|
||||
kwargs.update({"action_noise": NormalActionNoise(0.0, 0.1), "learning_starts": 100, "train_freq": 4})
|
||||
kwargs.update(
|
||||
{"action_noise": NormalActionNoise(np.zeros(1), 0.1 * np.ones(1)), "learning_starts": 100, "train_freq": 4}
|
||||
)
|
||||
else:
|
||||
if algo == DQN:
|
||||
env_id = "CartPole-v1"
|
||||
|
|
|
|||
Loading…
Reference in a new issue