mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-29 23:07:07 +00:00
Update TD3/DDPG/DQN defaults for consistency (#1785)
* Update TD3/DDPG/DQN defaults for consistency * Update changelog
This commit is contained in:
parent
a653aec10d
commit
a9273f968e
5 changed files with 34 additions and 10 deletions
|
|
@ -3,12 +3,36 @@
|
|||
Changelog
|
||||
==========
|
||||
|
||||
|
||||
Release 2.3.0a0 (WIP)
|
||||
Release 2.3.0a1 (WIP)
|
||||
--------------------------
|
||||
|
||||
Breaking Changes:
|
||||
^^^^^^^^^^^^^^^^^
|
||||
- The default hyperparameters of ``TD3`` and ``DDPG`` have been changed to be more consistent with ``SAC``
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# SB3 < 2.3.0 default hyperparameters
|
||||
# model = TD3("MlpPolicy", env, train_freq=(1, "episode"), gradient_steps=-1, batch_size=100)
|
||||
# SB3 >= 2.3.0:
|
||||
model = TD3("MlpPolicy", env, train_freq=1, gradient_steps=1, batch_size=256)
|
||||
|
||||
.. note::
|
||||
|
||||
Two inconsistencies remain: the default network architecture for ``TD3/DDPG`` is ``[400, 300]`` instead of ``[256, 256]`` for SAC (for backward compatibility reasons, see `report on the influence of the network size <https://wandb.ai/openrlbenchmark/sbx/reports/SBX-TD3-Influence-of-policy-net--Vmlldzo2NDg1Mzk3>`_) and the default learning rate is 1e-3 instead of 3e-4 for SAC (for performance reasons, see `W&B report on the influence of the lr <https://wandb.ai/openrlbenchmark/sbx/reports/SBX-TD3-RL-Zoo-v2-3-0a0-vs-SB3-TD3-RL-Zoo-2-2-1---Vmlldzo2MjUyNTQx>`_)
|
||||
|
||||
|
||||
|
||||
- The default ``learning_starts`` parameter of ``DQN`` has been changed to be consistent with the other off-policy algorithms
|
||||
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# SB3 < 2.3.0 default hyperparameters, 50_000 corresponded to Atari default hyperparameters
|
||||
# model = DQN("MlpPolicy", env, learning_starts=50_000)
|
||||
# SB3 >= 2.3.0:
|
||||
model = DQN("MlpPolicy", env, learning_starts=100)
|
||||
|
||||
|
||||
New Features:
|
||||
^^^^^^^^^^^^^
|
||||
|
|
|
|||
|
|
@ -60,11 +60,11 @@ class DDPG(TD3):
|
|||
learning_rate: Union[float, Schedule] = 1e-3,
|
||||
buffer_size: int = 1_000_000, # 1e6
|
||||
learning_starts: int = 100,
|
||||
batch_size: int = 100,
|
||||
batch_size: int = 256,
|
||||
tau: float = 0.005,
|
||||
gamma: float = 0.99,
|
||||
train_freq: Union[int, Tuple[int, str]] = (1, "episode"),
|
||||
gradient_steps: int = -1,
|
||||
train_freq: Union[int, Tuple[int, str]] = 1,
|
||||
gradient_steps: int = 1,
|
||||
action_noise: Optional[ActionNoise] = None,
|
||||
replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
|
||||
replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
|
||||
|
|
|
|||
|
|
@ -79,7 +79,7 @@ class DQN(OffPolicyAlgorithm):
|
|||
env: Union[GymEnv, str],
|
||||
learning_rate: Union[float, Schedule] = 1e-4,
|
||||
buffer_size: int = 1_000_000, # 1e6
|
||||
learning_starts: int = 50000,
|
||||
learning_starts: int = 100,
|
||||
batch_size: int = 32,
|
||||
tau: float = 1.0,
|
||||
gamma: float = 0.99,
|
||||
|
|
|
|||
|
|
@ -83,11 +83,11 @@ class TD3(OffPolicyAlgorithm):
|
|||
learning_rate: Union[float, Schedule] = 1e-3,
|
||||
buffer_size: int = 1_000_000, # 1e6
|
||||
learning_starts: int = 100,
|
||||
batch_size: int = 100,
|
||||
batch_size: int = 256,
|
||||
tau: float = 0.005,
|
||||
gamma: float = 0.99,
|
||||
train_freq: Union[int, Tuple[int, str]] = (1, "episode"),
|
||||
gradient_steps: int = -1,
|
||||
train_freq: Union[int, Tuple[int, str]] = 1,
|
||||
gradient_steps: int = 1,
|
||||
action_noise: Optional[ActionNoise] = None,
|
||||
replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
|
||||
replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
2.3.0a0
|
||||
2.3.0a1
|
||||
|
|
|
|||
Loading…
Reference in a new issue