.. _her:

.. automodule:: stable_baselines3.her


HER
====

`Hindsight Experience Replay (HER) <https://arxiv.org/abs/1707.01495>`_

HER is an algorithm that works with off-policy methods (DQN, SAC, TD3 and DDPG for example).
HER uses the fact that even if a desired goal was not achieved, other goals may have been achieved during a rollout.
It creates "virtual" transitions by relabeling transitions (changing the desired goal) from past episodes.
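
To make the relabeling idea concrete, here is a minimal, self-contained sketch (illustrative only, not the actual ``HerReplayBuffer`` internals; the sparse reward and goal layout are made-up assumptions):

.. code-block:: python

    import numpy as np

    def compute_reward(achieved_goal, desired_goal):
        # Simplified sparse reward: 0 when the goal is reached, -1 otherwise
        return 0.0 if np.array_equal(achieved_goal, desired_goal) else -1.0

    # A stored transition whose original desired goal was never reached,
    # and a goal that was actually achieved later in the same episode
    transition = {
        "achieved_goal": np.array([1, 0, 1]),
        "desired_goal": np.array([1, 1, 1]),
    }
    goal_achieved_later = np.array([1, 0, 1])

    # Relabel: pretend the later achieved goal was the desired one all along,
    # then recompute the reward for this "virtual" transition
    virtual = dict(transition)
    virtual["desired_goal"] = goal_achieved_later
    virtual["reward"] = compute_reward(virtual["achieved_goal"], virtual["desired_goal"])

In practice, ``HerReplayBuffer`` performs this relabeling for you, using the environment's own (vectorized) ``compute_reward()``.
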
.. warning::

    Starting from Stable Baselines3 v1.1.0, ``HER`` is no longer a separate algorithm
    but a replay buffer class ``HerReplayBuffer`` that must be passed to an off-policy algorithm
    when using ``MultiInputPolicy`` (to have Dict observation support).

.. warning::

    HER requires the environment to follow the legacy `gym_robotics.GoalEnv interface <https://github.com/Farama-Foundation/Gymnasium-Robotics/blob/a35b1c1fa669428bf640a2c7101e66eb1627ac3a/gym_robotics/core.py#L8>`_.
    In short, the ``gym.Env`` must have:

    - a vectorized implementation of ``compute_reward()``
    - a dictionary observation space with three keys: ``observation``, ``achieved_goal`` and ``desired_goal``

    A minimal sketch of such an environment is shown below.

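
For reference, a minimal sketch of an environment satisfying both requirements, written against the Gymnasium API (``MyGoalEnv`` and its bit-vector goals are made up for illustration):

.. code-block:: python

    import numpy as np
    import gymnasium as gym
    from gymnasium import spaces

    class MyGoalEnv(gym.Env):
        """Toy goal-conditioned env: reach a random binary target vector."""

        def __init__(self, size: int = 3):
            self.size = size
            goal_space = spaces.MultiBinary(size)
            self.observation_space = spaces.Dict(
                {
                    "observation": spaces.MultiBinary(size),
                    "achieved_goal": goal_space,
                    "desired_goal": goal_space,
                }
            )
            self.action_space = spaces.Discrete(size)

        def compute_reward(self, achieved_goal, desired_goal, info):
            # Vectorized: must also accept batches of goals (used when relabeling)
            return -np.float32(np.any(achieved_goal != desired_goal, axis=-1))

        def reset(self, seed=None, options=None):
            super().reset(seed=seed)
            self.state = np.zeros(self.size, dtype=np.int8)
            self.goal = self.np_random.integers(0, 2, size=self.size, dtype=np.int8)
            return self._get_obs(), {}

        def step(self, action):
            # Flip one bit of the state
            self.state[action] = 1 - self.state[action]
            obs = self._get_obs()
            reward = float(self.compute_reward(obs["achieved_goal"], obs["desired_goal"], {}))
            terminated = bool(reward == 0.0)
            return obs, reward, terminated, False, {}

        def _get_obs(self):
            return {
                "observation": self.state.copy(),
                "achieved_goal": self.state.copy(),
                "desired_goal": self.goal.copy(),
            }

The important parts are the three-key ``Dict`` observation space and a ``compute_reward()`` that also works on batches of goals.
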
.. warning::

    Because it needs access to ``env.compute_reward()``,
    ``HER`` must be loaded with the env. If you just want to use the trained policy
    without instantiating the environment, we recommend saving the policy only
    (see the sketch below).

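
A minimal sketch of the policy-only workflow (shown with ``DQN`` on ``BitFlippingEnv`` for brevity; the file name and tiny training budget are arbitrary):

.. code-block:: python

    from stable_baselines3 import DQN
    from stable_baselines3.common.envs import BitFlippingEnv
    from stable_baselines3.dqn.policies import MultiInputPolicy

    env = BitFlippingEnv(n_bits=4)
    model = DQN("MultiInputPolicy", env).learn(100)

    # Save only the policy: loading it later does not require the environment
    model.policy.save("dqn_policy_only")

    # Later, e.g. at deployment time
    policy = MultiInputPolicy.load("dqn_policy_only")
    obs, _ = env.reset()
    action, _ = policy.predict(obs, deterministic=True)
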
.. note::

    Compared to other implementations, the ``future`` goal sampling strategy is inclusive:
    the current transition can be used when re-sampling.

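
As a small illustration of what "inclusive" means here (not the buffer internals), the index of the relabeling goal is drawn from a range whose lower bound is the current transition itself:

.. code-block:: python

    import numpy as np

    rng = np.random.default_rng()
    episode_length = 10  # number of transitions stored for the episode
    current = 4          # index of the transition being relabeled

    # Inclusive "future" strategy: the current transition may be selected
    goal_idx = rng.integers(current, episode_length)

    # An exclusive variant would only sample strictly later transitions:
    # rng.integers(current + 1, episode_length)
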
Notes
-----

- Original paper: https://arxiv.org/abs/1707.01495
- OpenAI paper: `Plappert et al. (2018)`_
- OpenAI blog post: https://openai.com/blog/ingredients-for-robotics-research/


.. _Plappert et al. (2018): https://arxiv.org/abs/1802.09464

Can I use?
----------

Please refer to the documentation of the model you use (DQN, QR-DQN, SAC, TQC, TD3, or DDPG) for this section.

Example
-------

This example is only to demonstrate the use of the library and its functions, and the trained agents may not solve the environments. Optimized hyperparameters can be found in the RL Zoo `repository <https://github.com/DLR-RM/rl-baselines3-zoo>`_.

.. code-block:: python

    from stable_baselines3 import HerReplayBuffer, DDPG, DQN, SAC, TD3
    from stable_baselines3.her.goal_selection_strategy import GoalSelectionStrategy
    from stable_baselines3.common.envs import BitFlippingEnv
    from stable_baselines3.common.vec_env import DummyVecEnv

    model_class = DQN  # works also with SAC, DDPG and TD3
    N_BITS = 15

    env = BitFlippingEnv(n_bits=N_BITS, continuous=model_class in [DDPG, SAC, TD3], max_steps=N_BITS)

    # Available strategies (cf paper): future, final, episode
    goal_selection_strategy = "future"  # equivalent to GoalSelectionStrategy.FUTURE

    # Initialize the model
    model = model_class(
        "MultiInputPolicy",
        env,
        replay_buffer_class=HerReplayBuffer,
        # Parameters for HER
        replay_buffer_kwargs=dict(
            n_sampled_goal=4,
            goal_selection_strategy=goal_selection_strategy,
        ),
        verbose=1,
    )

    # Train the model
    model.learn(1000)

    model.save("./her_bit_env")
    # Because it needs access to `env.compute_reward()`
    # HER must be loaded with the env
    model = model_class.load("./her_bit_env", env=env)

    obs, info = env.reset()
    for _ in range(100):
        action, _ = model.predict(obs, deterministic=True)
        obs, reward, terminated, truncated, info = env.step(action)

        if terminated or truncated:
            obs, info = env.reset()

Results
-------

This implementation was tested on the `parking env <https://github.com/eleurent/highway-env>`_
using 3 seeds.

The complete learning curves are available in the `associated PR #120 <https://github.com/DLR-RM/stable-baselines3/pull/120>`_.


How to replicate the results?
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Clone the `rl-zoo repo <https://github.com/DLR-RM/rl-baselines3-zoo>`_:

.. code-block:: bash

    git clone https://github.com/DLR-RM/rl-baselines3-zoo
    cd rl-baselines3-zoo/

Run the benchmark:

.. code-block:: bash

    python train.py --algo tqc --env parking-v0 --eval-episodes 10 --eval-freq 10000

Plot the results:

.. code-block:: bash

    python scripts/all_plots.py -a tqc -e parking-v0 -f logs/ --no-million


Parameters
----------

HER Replay Buffer
-----------------

.. autoclass:: HerReplayBuffer
  :members:
  :inherited-members:


Goal Selection Strategies
-------------------------

.. autoclass:: GoalSelectionStrategy
  :members:
  :inherited-members:
  :undoc-members:
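
The strategy can also be passed as a ``GoalSelectionStrategy`` member instead of a string; a small sketch (the ``SAC`` hyperparameters and training budget are arbitrary):

.. code-block:: python

    from stable_baselines3 import HerReplayBuffer, SAC
    from stable_baselines3.common.envs import BitFlippingEnv
    from stable_baselines3.her.goal_selection_strategy import GoalSelectionStrategy

    env = BitFlippingEnv(n_bits=10, continuous=True, max_steps=10)

    model = SAC(
        "MultiInputPolicy",
        env,
        replay_buffer_class=HerReplayBuffer,
        replay_buffer_kwargs=dict(
            n_sampled_goal=4,
            # Enum form, equivalent to the string "final"
            goal_selection_strategy=GoalSelectionStrategy.FINAL,
        ),
        learning_starts=100,
        verbose=1,
    )
    model.learn(500)
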