mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-07-03 03:59:13 +00:00
Merge branch 'master' into feat/mps-support
This commit is contained in:
commit
8d79e96e13
20 changed files with 162 additions and 16 deletions
|
|
@ -2,10 +2,12 @@ image: stablebaselines/stable-baselines3-cpu:1.4.1a0
|
|||
|
||||
type-check:
|
||||
script:
|
||||
- pip install pytype --upgrade
|
||||
- make type
|
||||
|
||||
pytest:
|
||||
script:
|
||||
- pip install tqdm rich # for progress bar
|
||||
- python --version
|
||||
# MKL_THREADING_LAYER=GNU to avoid MKL_THREADING_LAYER=INTEL incompatibility error
|
||||
- MKL_THREADING_LAYER=GNU make pytest
|
||||
|
|
|
|||
|
|
@ -225,6 +225,29 @@ It will save the best model if ``best_model_save_path`` folder is specified and
|
|||
model = SAC("MlpPolicy", "Pendulum-v1")
|
||||
model.learn(5000, callback=eval_callback)
|
||||
|
||||
.. _ProgressBarCallback:
|
||||
|
||||
ProgressBarCallback
|
||||
^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Display a progress bar with the current progress, elapsed time and estimated remaining time.
|
||||
This callback is integrated inside SB3 via the ``progress_bar`` argument of the ``learn()`` method.
|
||||
|
||||
.. note::
|
||||
|
||||
This callback requires ``tqdm`` and ``rich`` packages to be installed. This is done automatically when using ``pip install stable-baselines3[extra]``
|
||||
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from stable_baselines3 import PPO
|
||||
from stable_baselines3.common.callbacks import ProgressBarCallback
|
||||
|
||||
model = PPO("MlpPolicy", "Pendulum-v1")
|
||||
# Display progress bar using the progress bar callback
|
||||
# this is equivalent to model.learn(100_000, callback=ProgressBarCallback())
|
||||
model.learn(100_000, progress_bar=True)
|
||||
|
||||
|
||||
.. _Callbacklist:
|
||||
|
||||
|
|
|
|||
|
|
@ -75,8 +75,8 @@ In the following example, we will train, save and load a DQN model on the Lunar
|
|||
|
||||
# Instantiate the agent
|
||||
model = DQN("MlpPolicy", env, verbose=1)
|
||||
# Train the agent
|
||||
model.learn(total_timesteps=int(2e5))
|
||||
# Train the agent and display a progress bar
|
||||
model.learn(total_timesteps=int(2e5), progress_bar=True)
|
||||
# Save the agent
|
||||
model.save("dqn_lunar")
|
||||
del model # delete trained model to demonstrate loading
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ Here is a quick example of how to train and run A2C on a CartPole environment:
|
|||
env = gym.make("CartPole-v1")
|
||||
|
||||
model = A2C("MlpPolicy", env, verbose=1)
|
||||
model.learn(total_timesteps=10000)
|
||||
model.learn(total_timesteps=10_000)
|
||||
|
||||
obs = env.reset()
|
||||
for i in range(1000):
|
||||
|
|
|
|||
|
|
@ -2,6 +2,34 @@
|
|||
|
||||
Changelog
|
||||
==========
|
||||
Release 1.6.2a0 (WIP)
|
||||
---------------------------
|
||||
|
||||
Breaking Changes:
|
||||
^^^^^^^^^^^^^^^^^
|
||||
|
||||
New Features:
|
||||
^^^^^^^^^^^^^
|
||||
- Added ``progress_bar`` argument in the ``learn()`` method, displayed using TQDM and rich packages
|
||||
- Added progress bar callback
|
||||
|
||||
SB3-Contrib
|
||||
^^^^^^^^^^^
|
||||
|
||||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
- ``self.num_timesteps`` was initialized properly only after the first call to ``on_step()`` for callbacks
|
||||
|
||||
Deprecations:
|
||||
^^^^^^^^^^^^^
|
||||
|
||||
Others:
|
||||
^^^^^^^
|
||||
- Fixed type hint of the ``env_id`` parameter in ``make_vec_env`` and ``make_atari_env`` (@AlexPasqua)
|
||||
|
||||
Documentation:
|
||||
^^^^^^^^^^^^^^
|
||||
- Extended docstring of the ``wrapper_class`` parameter in ``make_vec_env`` (@AlexPasqua)
|
||||
|
||||
Release 1.6.1 (2022-09-29)
|
||||
---------------------------
|
||||
|
|
|
|||
3
setup.py
3
setup.py
|
|
@ -127,6 +127,9 @@ setup(
|
|||
"tensorboard>=2.9.1",
|
||||
# Checking memory taken by replay buffer
|
||||
"psutil",
|
||||
# For progress bar callback
|
||||
"tqdm",
|
||||
"rich",
|
||||
],
|
||||
},
|
||||
description="Pytorch version of Stable Baselines, implementations of reinforcement learning algorithms.",
|
||||
|
|
|
|||
|
|
@ -195,6 +195,7 @@ class A2C(OnPolicyAlgorithm):
|
|||
tb_log_name: str = "A2C",
|
||||
eval_log_path: Optional[str] = None,
|
||||
reset_num_timesteps: bool = True,
|
||||
progress_bar: bool = False,
|
||||
) -> A2CSelf:
|
||||
|
||||
return super().learn(
|
||||
|
|
@ -207,4 +208,5 @@ class A2C(OnPolicyAlgorithm):
|
|||
tb_log_name=tb_log_name,
|
||||
eval_log_path=eval_log_path,
|
||||
reset_num_timesteps=reset_num_timesteps,
|
||||
progress_bar=progress_bar,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ import numpy as np
|
|||
import torch as th
|
||||
|
||||
from stable_baselines3.common import utils
|
||||
from stable_baselines3.common.callbacks import BaseCallback, CallbackList, ConvertCallback, EvalCallback
|
||||
from stable_baselines3.common.callbacks import BaseCallback, CallbackList, ConvertCallback, EvalCallback, ProgressBarCallback
|
||||
from stable_baselines3.common.env_util import is_wrapped
|
||||
from stable_baselines3.common.logger import Logger
|
||||
from stable_baselines3.common.monitor import Monitor
|
||||
|
|
@ -371,6 +371,7 @@ class BaseAlgorithm(ABC):
|
|||
eval_freq: int = 10000,
|
||||
n_eval_episodes: int = 5,
|
||||
log_path: Optional[str] = None,
|
||||
progress_bar: bool = False,
|
||||
) -> BaseCallback:
|
||||
"""
|
||||
:param callback: Callback(s) called at every step with state of the algorithm.
|
||||
|
|
@ -378,6 +379,7 @@ class BaseAlgorithm(ABC):
|
|||
:param n_eval_episodes: How many episodes to play per evaluation
|
||||
:param n_eval_episodes: Number of episodes to rollout during evaluation.
|
||||
:param log_path: Path to a folder where the evaluations will be saved
|
||||
:param progress_bar: Display a progress bar using tqdm and rich.
|
||||
:return: A hybrid callback calling `callback` and performing evaluation.
|
||||
"""
|
||||
# Convert a list of callbacks into a callback
|
||||
|
|
@ -388,6 +390,10 @@ class BaseAlgorithm(ABC):
|
|||
if not isinstance(callback, BaseCallback):
|
||||
callback = ConvertCallback(callback)
|
||||
|
||||
# Add progress bar callback
|
||||
if progress_bar:
|
||||
callback = CallbackList([callback, ProgressBarCallback()])
|
||||
|
||||
# Create eval callback in charge of the evaluation
|
||||
if eval_env is not None:
|
||||
eval_callback = EvalCallback(
|
||||
|
|
@ -413,6 +419,7 @@ class BaseAlgorithm(ABC):
|
|||
log_path: Optional[str] = None,
|
||||
reset_num_timesteps: bool = True,
|
||||
tb_log_name: str = "run",
|
||||
progress_bar: bool = False,
|
||||
) -> Tuple[int, BaseCallback]:
|
||||
"""
|
||||
Initialize different variables needed for training.
|
||||
|
|
@ -425,7 +432,8 @@ class BaseAlgorithm(ABC):
|
|||
:param log_path: Path to a folder where the evaluations will be saved
|
||||
:param reset_num_timesteps: Whether to reset or not the ``num_timesteps`` attribute
|
||||
:param tb_log_name: the name of the run for tensorboard log
|
||||
:return:
|
||||
:param progress_bar: Display a progress bar using tqdm and rich.
|
||||
:return: Total timesteps and callback(s)
|
||||
"""
|
||||
self.start_time = time.time_ns()
|
||||
|
||||
|
|
@ -464,7 +472,7 @@ class BaseAlgorithm(ABC):
|
|||
self._logger = utils.configure_logger(self.verbose, self.tensorboard_log, tb_log_name, reset_num_timesteps)
|
||||
|
||||
# Create eval callback if needed
|
||||
callback = self._init_callback(callback, eval_env, eval_freq, n_eval_episodes, log_path)
|
||||
callback = self._init_callback(callback, eval_env, eval_freq, n_eval_episodes, log_path, progress_bar)
|
||||
|
||||
return total_timesteps, callback
|
||||
|
||||
|
|
@ -550,6 +558,7 @@ class BaseAlgorithm(ABC):
|
|||
n_eval_episodes: int = 5,
|
||||
eval_log_path: Optional[str] = None,
|
||||
reset_num_timesteps: bool = True,
|
||||
progress_bar: bool = False,
|
||||
) -> BaseAlgorithmSelf:
|
||||
"""
|
||||
Return a trained model.
|
||||
|
|
@ -563,6 +572,7 @@ class BaseAlgorithm(ABC):
|
|||
:param n_eval_episodes: Number of episode to evaluate the agent
|
||||
:param eval_log_path: Path to a folder where the evaluations will be saved
|
||||
:param reset_num_timesteps: whether or not to reset the current timestep number (used in logging)
|
||||
:param progress_bar: Display a progress bar using tqdm and rich.
|
||||
:return: the trained model
|
||||
"""
|
||||
|
||||
|
|
|
|||
|
|
@ -6,6 +6,17 @@ from typing import Any, Callable, Dict, List, Optional, Union
|
|||
import gym
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
from tqdm import TqdmExperimentalWarning
|
||||
|
||||
# Remove experimental warning
|
||||
warnings.filterwarnings("ignore", category=TqdmExperimentalWarning)
|
||||
from tqdm.rich import tqdm
|
||||
except ImportError:
|
||||
# Rich not installed, we only throw an error
|
||||
# if the progress bar is used
|
||||
tqdm = None
|
||||
|
||||
from stable_baselines3.common import base_class # pytype: disable=pyi-error
|
||||
from stable_baselines3.common.evaluation import evaluate_policy
|
||||
from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, sync_envs_normalization
|
||||
|
|
@ -54,6 +65,8 @@ class BaseCallback(ABC):
|
|||
# Those are reference and will be updated automatically
|
||||
self.locals = locals_
|
||||
self.globals = globals_
|
||||
# Update num_timesteps in case training was done before
|
||||
self.num_timesteps = self.model.num_timesteps
|
||||
self._on_training_start()
|
||||
|
||||
def _on_training_start(self) -> None:
|
||||
|
|
@ -82,7 +95,6 @@ class BaseCallback(ABC):
|
|||
:return: If the callback returns False, training is aborted early.
|
||||
"""
|
||||
self.n_calls += 1
|
||||
# timesteps start at zero
|
||||
self.num_timesteps = self.model.num_timesteps
|
||||
|
||||
return self._on_step()
|
||||
|
|
@ -644,3 +656,34 @@ class StopTrainingOnNoModelImprovement(BaseCallback):
|
|||
)
|
||||
|
||||
return continue_training
|
||||
|
||||
|
||||
class ProgressBarCallback(BaseCallback):
|
||||
"""
|
||||
Display a progress bar when training SB3 agent
|
||||
using tqdm and rich packages.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
if tqdm is None:
|
||||
raise ImportError(
|
||||
"You must install tqdm and rich in order to use the progress bar callback. "
|
||||
"It is included if you install stable-baselines with the extra packages: "
|
||||
"`pip install stable-baselines3[extra]`"
|
||||
)
|
||||
self.pbar = None
|
||||
|
||||
def _on_training_start(self) -> None:
|
||||
# Initialize progress bar
|
||||
# Remove timesteps that were done in previous training sessions
|
||||
self.pbar = tqdm(total=self.locals["total_timesteps"] - self.model.num_timesteps)
|
||||
|
||||
def _on_step(self) -> bool:
|
||||
# Update progress bar, we do num_envs steps per call to `env.step()`
|
||||
self.pbar.update(self.training_env.num_envs)
|
||||
return True
|
||||
|
||||
def _on_training_end(self) -> None:
|
||||
# Close progress bar
|
||||
self.pbar.close()
|
||||
|
|
|
|||
|
|
@ -36,7 +36,7 @@ def is_wrapped(env: Type[gym.Env], wrapper_class: Type[gym.Wrapper]) -> bool:
|
|||
|
||||
|
||||
def make_vec_env(
|
||||
env_id: Union[str, Type[gym.Env]],
|
||||
env_id: Union[str, Callable[..., gym.Env]],
|
||||
n_envs: int = 1,
|
||||
seed: Optional[int] = None,
|
||||
start_index: int = 0,
|
||||
|
|
@ -53,7 +53,7 @@ def make_vec_env(
|
|||
By default it uses a ``DummyVecEnv`` which is usually faster
|
||||
than a ``SubprocVecEnv``.
|
||||
|
||||
:param env_id: the environment ID or the environment class
|
||||
:param env_id: either the env ID, the env class or a callable returning an env
|
||||
:param n_envs: the number of environments you wish to have in parallel
|
||||
:param seed: the initial seed for the random number generator
|
||||
:param start_index: start rank index
|
||||
|
|
@ -62,6 +62,9 @@ def make_vec_env(
|
|||
in a Monitor wrapper to provide additional information about training.
|
||||
:param wrapper_class: Additional wrapper to use on the environment.
|
||||
This can also be a function with single argument that wraps the environment in many things.
|
||||
Note: the wrapper specified by this parameter will be applied after the ``Monitor`` wrapper.
|
||||
if some cases (e.g. with TimeLimit wrapper) this can lead to undesired behavior.
|
||||
See here for more details: https://github.com/DLR-RM/stable-baselines3/issues/894
|
||||
:param env_kwargs: Optional keyword argument to pass to the env constructor
|
||||
:param vec_env_cls: A custom ``VecEnv`` class constructor. Default: None.
|
||||
:param vec_env_kwargs: Keyword arguments to pass to the ``VecEnv`` class constructor.
|
||||
|
|
@ -106,7 +109,7 @@ def make_vec_env(
|
|||
|
||||
|
||||
def make_atari_env(
|
||||
env_id: Union[str, Type[gym.Env]],
|
||||
env_id: Union[str, Callable[..., gym.Env]],
|
||||
n_envs: int = 1,
|
||||
seed: Optional[int] = None,
|
||||
start_index: int = 0,
|
||||
|
|
@ -121,7 +124,7 @@ def make_atari_env(
|
|||
Create a wrapped, monitored VecEnv for Atari.
|
||||
It is a wrapper around ``make_vec_env`` that includes common preprocessing for Atari games.
|
||||
|
||||
:param env_id: the environment ID or the environment class
|
||||
:param env_id: either the env ID, the env class or a callable returning an env
|
||||
:param n_envs: the number of environments you wish to have in parallel
|
||||
:param seed: the initial seed for the random number generator
|
||||
:param start_index: start rank index
|
||||
|
|
|
|||
|
|
@ -18,6 +18,10 @@ try:
|
|||
except ImportError:
|
||||
SummaryWriter = None
|
||||
|
||||
try:
|
||||
from tqdm import tqdm
|
||||
except ImportError:
|
||||
tqdm = None
|
||||
|
||||
DEBUG = 10
|
||||
INFO = 20
|
||||
|
|
@ -222,7 +226,12 @@ class HumanOutputFormat(KVWriter, SeqWriter):
|
|||
val_space = " " * (val_width - len(value))
|
||||
lines.append(f"| {key}{key_space} | {value}{val_space} |")
|
||||
lines.append(dashes)
|
||||
self.file.write("\n".join(lines) + "\n")
|
||||
|
||||
if tqdm is not None and hasattr(self.file, "name") and self.file.name == "<stdout>":
|
||||
# Do not mess up with progress bar
|
||||
tqdm.write("\n".join(lines) + "\n", file=sys.stdout, end="")
|
||||
else:
|
||||
self.file.write("\n".join(lines) + "\n")
|
||||
|
||||
# Flush the output to the file
|
||||
self.file.flush()
|
||||
|
|
|
|||
|
|
@ -276,6 +276,7 @@ class OffPolicyAlgorithm(BaseAlgorithm):
|
|||
log_path: Optional[str] = None,
|
||||
reset_num_timesteps: bool = True,
|
||||
tb_log_name: str = "run",
|
||||
progress_bar: bool = False,
|
||||
) -> Tuple[int, BaseCallback]:
|
||||
"""
|
||||
cf `BaseAlgorithm`.
|
||||
|
|
@ -318,6 +319,7 @@ class OffPolicyAlgorithm(BaseAlgorithm):
|
|||
log_path,
|
||||
reset_num_timesteps,
|
||||
tb_log_name,
|
||||
progress_bar,
|
||||
)
|
||||
|
||||
def learn(
|
||||
|
|
@ -331,6 +333,7 @@ class OffPolicyAlgorithm(BaseAlgorithm):
|
|||
tb_log_name: str = "run",
|
||||
eval_log_path: Optional[str] = None,
|
||||
reset_num_timesteps: bool = True,
|
||||
progress_bar: bool = False,
|
||||
) -> OffPolicyAlgorithmSelf:
|
||||
|
||||
total_timesteps, callback = self._setup_learn(
|
||||
|
|
@ -342,6 +345,7 @@ class OffPolicyAlgorithm(BaseAlgorithm):
|
|||
eval_log_path,
|
||||
reset_num_timesteps,
|
||||
tb_log_name,
|
||||
progress_bar,
|
||||
)
|
||||
|
||||
callback.on_training_start(locals(), globals())
|
||||
|
|
|
|||
|
|
@ -237,11 +237,20 @@ class OnPolicyAlgorithm(BaseAlgorithm):
|
|||
tb_log_name: str = "OnPolicyAlgorithm",
|
||||
eval_log_path: Optional[str] = None,
|
||||
reset_num_timesteps: bool = True,
|
||||
progress_bar: bool = False,
|
||||
) -> OnPolicyAlgorithmSelf:
|
||||
iteration = 0
|
||||
|
||||
total_timesteps, callback = self._setup_learn(
|
||||
total_timesteps, eval_env, callback, eval_freq, n_eval_episodes, eval_log_path, reset_num_timesteps, tb_log_name
|
||||
total_timesteps,
|
||||
eval_env,
|
||||
callback,
|
||||
eval_freq,
|
||||
n_eval_episodes,
|
||||
eval_log_path,
|
||||
reset_num_timesteps,
|
||||
tb_log_name,
|
||||
progress_bar,
|
||||
)
|
||||
|
||||
callback.on_training_start(locals(), globals())
|
||||
|
|
|
|||
|
|
@ -127,6 +127,7 @@ class DDPG(TD3):
|
|||
tb_log_name: str = "DDPG",
|
||||
eval_log_path: Optional[str] = None,
|
||||
reset_num_timesteps: bool = True,
|
||||
progress_bar: bool = False,
|
||||
) -> DDPGSelf:
|
||||
|
||||
return super().learn(
|
||||
|
|
@ -139,4 +140,5 @@ class DDPG(TD3):
|
|||
tb_log_name=tb_log_name,
|
||||
eval_log_path=eval_log_path,
|
||||
reset_num_timesteps=reset_num_timesteps,
|
||||
progress_bar=progress_bar,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -267,6 +267,7 @@ class DQN(OffPolicyAlgorithm):
|
|||
tb_log_name: str = "DQN",
|
||||
eval_log_path: Optional[str] = None,
|
||||
reset_num_timesteps: bool = True,
|
||||
progress_bar: bool = False,
|
||||
) -> DQNSelf:
|
||||
|
||||
return super().learn(
|
||||
|
|
@ -279,6 +280,7 @@ class DQN(OffPolicyAlgorithm):
|
|||
tb_log_name=tb_log_name,
|
||||
eval_log_path=eval_log_path,
|
||||
reset_num_timesteps=reset_num_timesteps,
|
||||
progress_bar=progress_bar,
|
||||
)
|
||||
|
||||
def _excluded_save_params(self) -> List[str]:
|
||||
|
|
|
|||
|
|
@ -309,6 +309,7 @@ class PPO(OnPolicyAlgorithm):
|
|||
tb_log_name: str = "PPO",
|
||||
eval_log_path: Optional[str] = None,
|
||||
reset_num_timesteps: bool = True,
|
||||
progress_bar: bool = False,
|
||||
) -> PPOSelf:
|
||||
|
||||
return super().learn(
|
||||
|
|
@ -321,4 +322,5 @@ class PPO(OnPolicyAlgorithm):
|
|||
tb_log_name=tb_log_name,
|
||||
eval_log_path=eval_log_path,
|
||||
reset_num_timesteps=reset_num_timesteps,
|
||||
progress_bar=progress_bar,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -301,6 +301,7 @@ class SAC(OffPolicyAlgorithm):
|
|||
tb_log_name: str = "SAC",
|
||||
eval_log_path: Optional[str] = None,
|
||||
reset_num_timesteps: bool = True,
|
||||
progress_bar: bool = False,
|
||||
) -> SACSelf:
|
||||
|
||||
return super().learn(
|
||||
|
|
@ -313,6 +314,7 @@ class SAC(OffPolicyAlgorithm):
|
|||
tb_log_name=tb_log_name,
|
||||
eval_log_path=eval_log_path,
|
||||
reset_num_timesteps=reset_num_timesteps,
|
||||
progress_bar=progress_bar,
|
||||
)
|
||||
|
||||
def _excluded_save_params(self) -> List[str]:
|
||||
|
|
|
|||
|
|
@ -217,6 +217,7 @@ class TD3(OffPolicyAlgorithm):
|
|||
tb_log_name: str = "TD3",
|
||||
eval_log_path: Optional[str] = None,
|
||||
reset_num_timesteps: bool = True,
|
||||
progress_bar: bool = False,
|
||||
) -> TD3Self:
|
||||
|
||||
return super().learn(
|
||||
|
|
@ -229,6 +230,7 @@ class TD3(OffPolicyAlgorithm):
|
|||
tb_log_name=tb_log_name,
|
||||
eval_log_path=eval_log_path,
|
||||
reset_num_timesteps=reset_num_timesteps,
|
||||
progress_bar=progress_bar,
|
||||
)
|
||||
|
||||
def _excluded_save_params(self) -> List[str]:
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
1.6.1
|
||||
1.6.2a0
|
||||
|
|
|
|||
|
|
@ -71,8 +71,8 @@ def test_callbacks(tmp_path, model_class):
|
|||
assert event_callback.n_calls == model.num_timesteps
|
||||
|
||||
model.learn(500, callback=None)
|
||||
# Transform callback into a callback list automatically
|
||||
model.learn(500, callback=[checkpoint_callback, eval_callback])
|
||||
# Transform callback into a callback list automatically and use progress bar
|
||||
model.learn(500, callback=[checkpoint_callback, eval_callback], progress_bar=True)
|
||||
# Automatic wrapping, old way of doing callbacks
|
||||
model.learn(500, callback=lambda _locals, _globals: True)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue