Merge branch 'master' into feat/mps-support

This commit is contained in:
Quentin Gallouédec 2022-10-07 10:08:22 +02:00 committed by GitHub
commit 8d79e96e13
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
20 changed files with 162 additions and 16 deletions

View file

@ -2,10 +2,12 @@ image: stablebaselines/stable-baselines3-cpu:1.4.1a0
type-check:
script:
- pip install pytype --upgrade
- make type
pytest:
script:
- pip install tqdm rich # for progress bar
- python --version
# MKL_THREADING_LAYER=GNU to avoid MKL_THREADING_LAYER=INTEL incompatibility error
- MKL_THREADING_LAYER=GNU make pytest

View file

@ -225,6 +225,29 @@ It will save the best model if ``best_model_save_path`` folder is specified and
model = SAC("MlpPolicy", "Pendulum-v1")
model.learn(5000, callback=eval_callback)
.. _ProgressBarCallback:
ProgressBarCallback
^^^^^^^^^^^^^^^^^^^
Display a progress bar with the current progress, elapsed time and estimated remaining time.
This callback is integrated inside SB3 via the ``progress_bar`` argument of the ``learn()`` method.
.. note::
This callback requires ``tqdm`` and ``rich`` packages to be installed. This is done automatically when using ``pip install stable-baselines3[extra]``
.. code-block:: python
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import ProgressBarCallback
model = PPO("MlpPolicy", "Pendulum-v1")
# Display progress bar using the progress bar callback
# this is equivalent to model.learn(100_000, callback=ProgressBarCallback())
model.learn(100_000, progress_bar=True)
.. _Callbacklist:

View file

@ -75,8 +75,8 @@ In the following example, we will train, save and load a DQN model on the Lunar
# Instantiate the agent
model = DQN("MlpPolicy", env, verbose=1)
# Train the agent
model.learn(total_timesteps=int(2e5))
# Train the agent and display a progress bar
model.learn(total_timesteps=int(2e5), progress_bar=True)
# Save the agent
model.save("dqn_lunar")
del model # delete trained model to demonstrate loading

View file

@ -17,7 +17,7 @@ Here is a quick example of how to train and run A2C on a CartPole environment:
env = gym.make("CartPole-v1")
model = A2C("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10000)
model.learn(total_timesteps=10_000)
obs = env.reset()
for i in range(1000):

View file

@ -2,6 +2,34 @@
Changelog
==========
Release 1.6.2a0 (WIP)
---------------------------
Breaking Changes:
^^^^^^^^^^^^^^^^^
New Features:
^^^^^^^^^^^^^
- Added ``progress_bar`` argument in the ``learn()`` method, displayed using TQDM and rich packages
- Added progress bar callback
SB3-Contrib
^^^^^^^^^^^
Bug Fixes:
^^^^^^^^^^
- ``self.num_timesteps`` was initialized properly only after the first call to ``on_step()`` for callbacks
Deprecations:
^^^^^^^^^^^^^
Others:
^^^^^^^
- Fixed type hint of the ``env_id`` parameter in ``make_vec_env`` and ``make_atari_env`` (@AlexPasqua)
Documentation:
^^^^^^^^^^^^^^
- Extended docstring of the ``wrapper_class`` parameter in ``make_vec_env`` (@AlexPasqua)
Release 1.6.1 (2022-09-29)
---------------------------

View file

@ -127,6 +127,9 @@ setup(
"tensorboard>=2.9.1",
# Checking memory taken by replay buffer
"psutil",
# For progress bar callback
"tqdm",
"rich",
],
},
description="Pytorch version of Stable Baselines, implementations of reinforcement learning algorithms.",

View file

@ -195,6 +195,7 @@ class A2C(OnPolicyAlgorithm):
tb_log_name: str = "A2C",
eval_log_path: Optional[str] = None,
reset_num_timesteps: bool = True,
progress_bar: bool = False,
) -> A2CSelf:
return super().learn(
@ -207,4 +208,5 @@ class A2C(OnPolicyAlgorithm):
tb_log_name=tb_log_name,
eval_log_path=eval_log_path,
reset_num_timesteps=reset_num_timesteps,
progress_bar=progress_bar,
)

View file

@ -12,7 +12,7 @@ import numpy as np
import torch as th
from stable_baselines3.common import utils
from stable_baselines3.common.callbacks import BaseCallback, CallbackList, ConvertCallback, EvalCallback
from stable_baselines3.common.callbacks import BaseCallback, CallbackList, ConvertCallback, EvalCallback, ProgressBarCallback
from stable_baselines3.common.env_util import is_wrapped
from stable_baselines3.common.logger import Logger
from stable_baselines3.common.monitor import Monitor
@ -371,6 +371,7 @@ class BaseAlgorithm(ABC):
eval_freq: int = 10000,
n_eval_episodes: int = 5,
log_path: Optional[str] = None,
progress_bar: bool = False,
) -> BaseCallback:
"""
:param callback: Callback(s) called at every step with state of the algorithm.
@ -378,6 +379,7 @@ class BaseAlgorithm(ABC):
:param n_eval_episodes: How many episodes to play per evaluation
:param n_eval_episodes: Number of episodes to rollout during evaluation.
:param log_path: Path to a folder where the evaluations will be saved
:param progress_bar: Display a progress bar using tqdm and rich.
:return: A hybrid callback calling `callback` and performing evaluation.
"""
# Convert a list of callbacks into a callback
@ -388,6 +390,10 @@ class BaseAlgorithm(ABC):
if not isinstance(callback, BaseCallback):
callback = ConvertCallback(callback)
# Add progress bar callback
if progress_bar:
callback = CallbackList([callback, ProgressBarCallback()])
# Create eval callback in charge of the evaluation
if eval_env is not None:
eval_callback = EvalCallback(
@ -413,6 +419,7 @@ class BaseAlgorithm(ABC):
log_path: Optional[str] = None,
reset_num_timesteps: bool = True,
tb_log_name: str = "run",
progress_bar: bool = False,
) -> Tuple[int, BaseCallback]:
"""
Initialize different variables needed for training.
@ -425,7 +432,8 @@ class BaseAlgorithm(ABC):
:param log_path: Path to a folder where the evaluations will be saved
:param reset_num_timesteps: Whether to reset or not the ``num_timesteps`` attribute
:param tb_log_name: the name of the run for tensorboard log
:return:
:param progress_bar: Display a progress bar using tqdm and rich.
:return: Total timesteps and callback(s)
"""
self.start_time = time.time_ns()
@ -464,7 +472,7 @@ class BaseAlgorithm(ABC):
self._logger = utils.configure_logger(self.verbose, self.tensorboard_log, tb_log_name, reset_num_timesteps)
# Create eval callback if needed
callback = self._init_callback(callback, eval_env, eval_freq, n_eval_episodes, log_path)
callback = self._init_callback(callback, eval_env, eval_freq, n_eval_episodes, log_path, progress_bar)
return total_timesteps, callback
@ -550,6 +558,7 @@ class BaseAlgorithm(ABC):
n_eval_episodes: int = 5,
eval_log_path: Optional[str] = None,
reset_num_timesteps: bool = True,
progress_bar: bool = False,
) -> BaseAlgorithmSelf:
"""
Return a trained model.
@ -563,6 +572,7 @@ class BaseAlgorithm(ABC):
:param n_eval_episodes: Number of episode to evaluate the agent
:param eval_log_path: Path to a folder where the evaluations will be saved
:param reset_num_timesteps: whether or not to reset the current timestep number (used in logging)
:param progress_bar: Display a progress bar using tqdm and rich.
:return: the trained model
"""

View file

@ -6,6 +6,17 @@ from typing import Any, Callable, Dict, List, Optional, Union
import gym
import numpy as np
try:
from tqdm import TqdmExperimentalWarning
# Remove experimental warning
warnings.filterwarnings("ignore", category=TqdmExperimentalWarning)
from tqdm.rich import tqdm
except ImportError:
# Rich not installed, we only throw an error
# if the progress bar is used
tqdm = None
from stable_baselines3.common import base_class # pytype: disable=pyi-error
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, sync_envs_normalization
@ -54,6 +65,8 @@ class BaseCallback(ABC):
# Those are reference and will be updated automatically
self.locals = locals_
self.globals = globals_
# Update num_timesteps in case training was done before
self.num_timesteps = self.model.num_timesteps
self._on_training_start()
def _on_training_start(self) -> None:
@ -82,7 +95,6 @@ class BaseCallback(ABC):
:return: If the callback returns False, training is aborted early.
"""
self.n_calls += 1
# timesteps start at zero
self.num_timesteps = self.model.num_timesteps
return self._on_step()
@ -644,3 +656,34 @@ class StopTrainingOnNoModelImprovement(BaseCallback):
)
return continue_training
class ProgressBarCallback(BaseCallback):
"""
Display a progress bar when training SB3 agent
using tqdm and rich packages.
"""
def __init__(self) -> None:
super().__init__()
if tqdm is None:
raise ImportError(
"You must install tqdm and rich in order to use the progress bar callback. "
"It is included if you install stable-baselines with the extra packages: "
"`pip install stable-baselines3[extra]`"
)
self.pbar = None
def _on_training_start(self) -> None:
# Initialize progress bar
# Remove timesteps that were done in previous training sessions
self.pbar = tqdm(total=self.locals["total_timesteps"] - self.model.num_timesteps)
def _on_step(self) -> bool:
# Update progress bar, we do num_envs steps per call to `env.step()`
self.pbar.update(self.training_env.num_envs)
return True
def _on_training_end(self) -> None:
# Close progress bar
self.pbar.close()

View file

@ -36,7 +36,7 @@ def is_wrapped(env: Type[gym.Env], wrapper_class: Type[gym.Wrapper]) -> bool:
def make_vec_env(
env_id: Union[str, Type[gym.Env]],
env_id: Union[str, Callable[..., gym.Env]],
n_envs: int = 1,
seed: Optional[int] = None,
start_index: int = 0,
@ -53,7 +53,7 @@ def make_vec_env(
By default it uses a ``DummyVecEnv`` which is usually faster
than a ``SubprocVecEnv``.
:param env_id: the environment ID or the environment class
:param env_id: either the env ID, the env class or a callable returning an env
:param n_envs: the number of environments you wish to have in parallel
:param seed: the initial seed for the random number generator
:param start_index: start rank index
@ -62,6 +62,9 @@ def make_vec_env(
in a Monitor wrapper to provide additional information about training.
:param wrapper_class: Additional wrapper to use on the environment.
This can also be a function with single argument that wraps the environment in many things.
Note: the wrapper specified by this parameter will be applied after the ``Monitor`` wrapper.
if some cases (e.g. with TimeLimit wrapper) this can lead to undesired behavior.
See here for more details: https://github.com/DLR-RM/stable-baselines3/issues/894
:param env_kwargs: Optional keyword argument to pass to the env constructor
:param vec_env_cls: A custom ``VecEnv`` class constructor. Default: None.
:param vec_env_kwargs: Keyword arguments to pass to the ``VecEnv`` class constructor.
@ -106,7 +109,7 @@ def make_vec_env(
def make_atari_env(
env_id: Union[str, Type[gym.Env]],
env_id: Union[str, Callable[..., gym.Env]],
n_envs: int = 1,
seed: Optional[int] = None,
start_index: int = 0,
@ -121,7 +124,7 @@ def make_atari_env(
Create a wrapped, monitored VecEnv for Atari.
It is a wrapper around ``make_vec_env`` that includes common preprocessing for Atari games.
:param env_id: the environment ID or the environment class
:param env_id: either the env ID, the env class or a callable returning an env
:param n_envs: the number of environments you wish to have in parallel
:param seed: the initial seed for the random number generator
:param start_index: start rank index

View file

@ -18,6 +18,10 @@ try:
except ImportError:
SummaryWriter = None
try:
from tqdm import tqdm
except ImportError:
tqdm = None
DEBUG = 10
INFO = 20
@ -222,7 +226,12 @@ class HumanOutputFormat(KVWriter, SeqWriter):
val_space = " " * (val_width - len(value))
lines.append(f"| {key}{key_space} | {value}{val_space} |")
lines.append(dashes)
self.file.write("\n".join(lines) + "\n")
if tqdm is not None and hasattr(self.file, "name") and self.file.name == "<stdout>":
# Do not mess up with progress bar
tqdm.write("\n".join(lines) + "\n", file=sys.stdout, end="")
else:
self.file.write("\n".join(lines) + "\n")
# Flush the output to the file
self.file.flush()

View file

@ -276,6 +276,7 @@ class OffPolicyAlgorithm(BaseAlgorithm):
log_path: Optional[str] = None,
reset_num_timesteps: bool = True,
tb_log_name: str = "run",
progress_bar: bool = False,
) -> Tuple[int, BaseCallback]:
"""
cf `BaseAlgorithm`.
@ -318,6 +319,7 @@ class OffPolicyAlgorithm(BaseAlgorithm):
log_path,
reset_num_timesteps,
tb_log_name,
progress_bar,
)
def learn(
@ -331,6 +333,7 @@ class OffPolicyAlgorithm(BaseAlgorithm):
tb_log_name: str = "run",
eval_log_path: Optional[str] = None,
reset_num_timesteps: bool = True,
progress_bar: bool = False,
) -> OffPolicyAlgorithmSelf:
total_timesteps, callback = self._setup_learn(
@ -342,6 +345,7 @@ class OffPolicyAlgorithm(BaseAlgorithm):
eval_log_path,
reset_num_timesteps,
tb_log_name,
progress_bar,
)
callback.on_training_start(locals(), globals())

View file

@ -237,11 +237,20 @@ class OnPolicyAlgorithm(BaseAlgorithm):
tb_log_name: str = "OnPolicyAlgorithm",
eval_log_path: Optional[str] = None,
reset_num_timesteps: bool = True,
progress_bar: bool = False,
) -> OnPolicyAlgorithmSelf:
iteration = 0
total_timesteps, callback = self._setup_learn(
total_timesteps, eval_env, callback, eval_freq, n_eval_episodes, eval_log_path, reset_num_timesteps, tb_log_name
total_timesteps,
eval_env,
callback,
eval_freq,
n_eval_episodes,
eval_log_path,
reset_num_timesteps,
tb_log_name,
progress_bar,
)
callback.on_training_start(locals(), globals())

View file

@ -127,6 +127,7 @@ class DDPG(TD3):
tb_log_name: str = "DDPG",
eval_log_path: Optional[str] = None,
reset_num_timesteps: bool = True,
progress_bar: bool = False,
) -> DDPGSelf:
return super().learn(
@ -139,4 +140,5 @@ class DDPG(TD3):
tb_log_name=tb_log_name,
eval_log_path=eval_log_path,
reset_num_timesteps=reset_num_timesteps,
progress_bar=progress_bar,
)

View file

@ -267,6 +267,7 @@ class DQN(OffPolicyAlgorithm):
tb_log_name: str = "DQN",
eval_log_path: Optional[str] = None,
reset_num_timesteps: bool = True,
progress_bar: bool = False,
) -> DQNSelf:
return super().learn(
@ -279,6 +280,7 @@ class DQN(OffPolicyAlgorithm):
tb_log_name=tb_log_name,
eval_log_path=eval_log_path,
reset_num_timesteps=reset_num_timesteps,
progress_bar=progress_bar,
)
def _excluded_save_params(self) -> List[str]:

View file

@ -309,6 +309,7 @@ class PPO(OnPolicyAlgorithm):
tb_log_name: str = "PPO",
eval_log_path: Optional[str] = None,
reset_num_timesteps: bool = True,
progress_bar: bool = False,
) -> PPOSelf:
return super().learn(
@ -321,4 +322,5 @@ class PPO(OnPolicyAlgorithm):
tb_log_name=tb_log_name,
eval_log_path=eval_log_path,
reset_num_timesteps=reset_num_timesteps,
progress_bar=progress_bar,
)

View file

@ -301,6 +301,7 @@ class SAC(OffPolicyAlgorithm):
tb_log_name: str = "SAC",
eval_log_path: Optional[str] = None,
reset_num_timesteps: bool = True,
progress_bar: bool = False,
) -> SACSelf:
return super().learn(
@ -313,6 +314,7 @@ class SAC(OffPolicyAlgorithm):
tb_log_name=tb_log_name,
eval_log_path=eval_log_path,
reset_num_timesteps=reset_num_timesteps,
progress_bar=progress_bar,
)
def _excluded_save_params(self) -> List[str]:

View file

@ -217,6 +217,7 @@ class TD3(OffPolicyAlgorithm):
tb_log_name: str = "TD3",
eval_log_path: Optional[str] = None,
reset_num_timesteps: bool = True,
progress_bar: bool = False,
) -> TD3Self:
return super().learn(
@ -229,6 +230,7 @@ class TD3(OffPolicyAlgorithm):
tb_log_name=tb_log_name,
eval_log_path=eval_log_path,
reset_num_timesteps=reset_num_timesteps,
progress_bar=progress_bar,
)
def _excluded_save_params(self) -> List[str]:

View file

@ -1 +1 @@
1.6.1
1.6.2a0

View file

@ -71,8 +71,8 @@ def test_callbacks(tmp_path, model_class):
assert event_callback.n_calls == model.num_timesteps
model.learn(500, callback=None)
# Transform callback into a callback list automatically
model.learn(500, callback=[checkpoint_callback, eval_callback])
# Transform callback into a callback list automatically and use progress bar
model.learn(500, callback=[checkpoint_callback, eval_callback], progress_bar=True)
# Automatic wrapping, old way of doing callbacks
model.learn(500, callback=lambda _locals, _globals: True)