Mirror of https://github.com/saymrwulf/stable-baselines3.git (synced 2026-06-27 03:11:57 +00:00)
None as default value for env in HerReplayBuffer.sample + DQN batch size typing fix (#790)
* `env` to `None` by default in `HerReplayBuffer.sample` (#788)
* Fix DQN batch_size typing
* Fix changelog

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
Co-authored-by: Antonin Raffin <antonin.raffin@dlr.de>
parent 13fcb12471
commit db5366fb51
3 changed files with 5 additions and 7 deletions
@@ -29,6 +29,9 @@ Bug Fixes:
   with very long keys.)
 - Routing all the ``nn.Module`` calls through implicit rather than explict forward as per pytorch guidelines (@manuel-delverme)
 - Fixed a bug in ``VecNormalize`` where error occurs when ``norm_obs`` is set to False for environment with dictionary observation (@buoyancy99)
+- Set default ``env`` argument to ``None`` in ``HerReplayBuffer.sample`` (@qgallouedec)
+- Fix ``batch_size`` typing in ``DQN`` (@qgallouedec)
+- Fixed sample normalization in ``DictReplayBuffer`` (@qgallouedec)
 
 Deprecations:
 ^^^^^^^^^^^^^
@@ -88,7 +91,6 @@ Bug Fixes:
 - Fixed evaluation script for recurrent policies (experimental feature in SB3 contrib)
 - Fixed a bug where the observation would be incorrectly detected as non-vectorized instead of throwing an error
 - The env checker now properly checks and warns about potential issues for continuous action spaces when the boundaries are too small or when the dtype is not float32
-- Fixed sample normalization in ``DictReplayBuffer`` (@qgallouedec)
 - Fixed a bug in ``VecFrameStack`` with channel first image envs, where the terminal observation would be wrongly created.
 
 Deprecations:

@@ -66,7 +66,7 @@ class DQN(OffPolicyAlgorithm):
         learning_rate: Union[float, Schedule] = 1e-4,
         buffer_size: int = 1_000_000,  # 1e6
         learning_starts: int = 50000,
-        batch_size: Optional[int] = 32,
+        batch_size: int = 32,
         tau: float = 1.0,
         gamma: float = 0.99,
         train_freq: Union[int, Tuple[int, str]] = 4,

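The `batch_size` change above is a typing-only fix. A minimal sketch of why `Optional[int]` was misleading, using two hypothetical stand-in functions (`old_init` and `new_init` are not part of the library): `Optional[int]` widens the accepted type to include `None`, while the default value alone is what lets a caller omit the argument.

```python
import inspect
from typing import Optional, Union, get_args, get_type_hints


def old_init(batch_size: Optional[int] = 32) -> None:
    """Hypothetical stand-in for the annotation before this commit."""


def new_init(batch_size: int = 32) -> None:
    """Hypothetical stand-in for the annotation after this commit."""


old_hint = get_type_hints(old_init)["batch_size"]
new_hint = get_type_hints(new_init)["batch_size"]

# Optional[int] is shorthand for Union[int, None], so a type checker
# would accept batch_size=None under the old annotation.
old_allows_none = type(None) in get_args(old_hint)

# The default value, not Optional, is what makes the argument omittable;
# it is identical before and after the change.
defaults = [
    inspect.signature(fn).parameters["batch_size"].default
    for fn in (old_init, new_init)
]
```

Under this reading, the new annotation only narrows what a static checker accepts; runtime behavior for callers passing an integer is unchanged.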
@@ -192,11 +192,7 @@ class HerReplayBuffer(DictReplayBuffer):
         """
         raise NotImplementedError()
 
-    def sample(
-        self,
-        batch_size: int,
-        env: Optional[VecNormalize],
-    ) -> DictReplayBufferSamples:
+    def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> DictReplayBufferSamples:
         """
         Sample function for online sampling of HER transition,
         this replaces the "regular" replay buffer ``sample()``
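The effect of the new `env` default on callers can be sketched without the library itself. The classes below (`FakeVecNormalize`, `OldBuffer`, `NewBuffer`) are hypothetical stand-ins that copy only the two `sample` signatures from the hunk above; the method bodies are invented for illustration.

```python
from typing import Dict, Optional


class FakeVecNormalize:
    """Hypothetical placeholder for a normalization wrapper."""


class OldBuffer:
    # Signature before the change: env is annotated Optional but has no default,
    # so it must always be passed explicitly.
    def sample(self, batch_size: int, env: Optional[FakeVecNormalize]) -> Dict[str, object]:
        return {"batch_size": batch_size, "normalized": env is not None}


class NewBuffer:
    # Signature after the change: env defaults to None and can be omitted.
    def sample(self, batch_size: int, env: Optional[FakeVecNormalize] = None) -> Dict[str, object]:
        return {"batch_size": batch_size, "normalized": env is not None}


try:
    OldBuffer().sample(32)  # missing required positional argument
    old_call_ok = True
except TypeError:
    old_call_ok = False

new_result = NewBuffer().sample(32)  # works, no normalization applied
explicit_result = NewBuffer().sample(32, env=FakeVecNormalize())
```

Existing callers that pass `env` explicitly are unaffected, since the parameter keeps its name and position.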