None as default value for env in HerReplayBuffer.sample + DQN batch size typing fix (#790)

* `env` to `None` by default in `HerReplayBuffer.sample` (#788)

* Fix DQN batch_size typing

* Fix changelog

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
Co-authored-by: Antonin Raffin <antonin.raffin@dlr.de>
This commit is contained in:
Quentin Gallouédec 2022-02-24 15:51:01 +01:00 committed by GitHub
parent 13fcb12471
commit db5366fb51
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 5 additions and 7 deletions

View file

@ -29,6 +29,9 @@ Bug Fixes:
with very long keys.)
- Routing all the ``nn.Module`` calls through implicit rather than explict forward as per pytorch guidelines (@manuel-delverme)
- Fixed a bug in ``VecNormalize`` where error occurs when ``norm_obs`` is set to False for environment with dictionary observation (@buoyancy99)
- Set default ``env`` argument to ``None`` in ``HerReplayBuffer.sample`` (@qgallouedec)
- Fix ``batch_size`` typing in ``DQN`` (@qgallouedec)
- Fixed sample normalization in ``DictReplayBuffer`` (@qgallouedec)
Deprecations:
^^^^^^^^^^^^^
@ -88,7 +91,6 @@ Bug Fixes:
- Fixed evaluation script for recurrent policies (experimental feature in SB3 contrib)
- Fixed a bug where the observation would be incorrectly detected as non-vectorized instead of throwing an error
- The env checker now properly checks and warns about potential issues for continuous action spaces when the boundaries are too small or when the dtype is not float32
- Fixed sample normalization in ``DictReplayBuffer`` (@qgallouedec)
- Fixed a bug in ``VecFrameStack`` with channel first image envs, where the terminal observation would be wrongly created.
Deprecations:

View file

@ -66,7 +66,7 @@ class DQN(OffPolicyAlgorithm):
learning_rate: Union[float, Schedule] = 1e-4,
buffer_size: int = 1_000_000, # 1e6
learning_starts: int = 50000,
batch_size: Optional[int] = 32,
batch_size: int = 32,
tau: float = 1.0,
gamma: float = 0.99,
train_freq: Union[int, Tuple[int, str]] = 4,

View file

@ -192,11 +192,7 @@ class HerReplayBuffer(DictReplayBuffer):
"""
raise NotImplementedError()
def sample(
self,
batch_size: int,
env: Optional[VecNormalize],
) -> DictReplayBufferSamples:
def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> DictReplayBufferSamples:
"""
Sample function for online sampling of HER transition,
this replaces the "regular" replay buffer ``sample()``