From db5366fb51e5f7a8401f8adf30057ef7db0783f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Thu, 24 Feb 2022 15:51:01 +0100 Subject: [PATCH] `None` as default value for `env` in `HerReplayBuffer.sample` + `DQN` batch size typing fix (#790) * `env` to `None` by default in `HerReplayBuffer.sample` (#788) * Fix DQN batch_size typing * Fix changelog Co-authored-by: Antonin RAFFIN Co-authored-by: Antonin Raffin --- docs/misc/changelog.rst | 4 +++- stable_baselines3/dqn/dqn.py | 2 +- stable_baselines3/her/her_replay_buffer.py | 6 +----- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index c5e07cf..ab0255a 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -29,6 +29,9 @@ Bug Fixes: with very long keys.) - Routing all the ``nn.Module`` calls through implicit rather than explict forward as per pytorch guidelines (@manuel-delverme) - Fixed a bug in ``VecNormalize`` where error occurs when ``norm_obs`` is set to False for environment with dictionary observation (@buoyancy99) +- Set default ``env`` argument to ``None`` in ``HerReplayBuffer.sample`` (@qgallouedec) +- Fix ``batch_size`` typing in ``DQN`` (@qgallouedec) +- Fixed sample normalization in ``DictReplayBuffer`` (@qgallouedec) Deprecations: ^^^^^^^^^^^^^ @@ -88,7 +91,6 @@ Bug Fixes: - Fixed evaluation script for recurrent policies (experimental feature in SB3 contrib) - Fixed a bug where the observation would be incorrectly detected as non-vectorized instead of throwing an error - The env checker now properly checks and warns about potential issues for continuous action spaces when the boundaries are too small or when the dtype is not float32 -- Fixed sample normalization in ``DictReplayBuffer`` (@qgallouedec) - Fixed a bug in ``VecFrameStack`` with channel first image envs, where the terminal observation would be wrongly created. Deprecations: diff --git a/stable_baselines3/dqn/dqn.py b/stable_baselines3/dqn/dqn.py index 668d729..a7aec6b 100644 --- a/stable_baselines3/dqn/dqn.py +++ b/stable_baselines3/dqn/dqn.py @@ -66,7 +66,7 @@ class DQN(OffPolicyAlgorithm): learning_rate: Union[float, Schedule] = 1e-4, buffer_size: int = 1_000_000, # 1e6 learning_starts: int = 50000, - batch_size: Optional[int] = 32, + batch_size: int = 32, tau: float = 1.0, gamma: float = 0.99, train_freq: Union[int, Tuple[int, str]] = 4, diff --git a/stable_baselines3/her/her_replay_buffer.py b/stable_baselines3/her/her_replay_buffer.py index 6a790d6..9a41477 100644 --- a/stable_baselines3/her/her_replay_buffer.py +++ b/stable_baselines3/her/her_replay_buffer.py @@ -192,11 +192,7 @@ class HerReplayBuffer(DictReplayBuffer): """ raise NotImplementedError() - def sample( - self, - batch_size: int, - env: Optional[VecNormalize], - ) -> DictReplayBufferSamples: + def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> DictReplayBufferSamples: """ Sample function for online sampling of HER transition, this replaces the "regular" replay buffer ``sample()``