From 6a8c9ddc8b14ce342ae22ac6eea2f47fc00a15f8 Mon Sep 17 00:00:00 2001 From: Alex Pasquali Date: Thu, 6 Oct 2022 13:36:06 +0200 Subject: [PATCH 1/2] Updated type hint and extended docstring in make_vec_env and make_atari_env (#1085) * Updated type hint and extended docstring in make_vec_env The function itself was already working with callables, but it wasn't considerent in the type hint of the function's signature. Extended the description of the wrapper_class parameter with a link to a Github issue containing more details on the matter. * Updated type hint in make_atari_env The function itself was already working with callables, but it wasn't considerent in the type hint of the function's signature. * Updated docstring in make_atari_env When modifying the type hint of the parameter 'env_id' (in this commit: fda6872f73c11075901ba88f2520f6316f818d1d), I forgot to update its description in the docstrig. Doing it now. * Removed redundant type in env_id's type hint in make_vec_env and make_atari_env Callable[..., gym.Env] already includes Type[gym.Env], as pointed out here: https://github.com/DLR-RM/stable-baselines3/pull/1085#issuecomment-1269685218 Co-authored-by: Antonin RAFFIN --- docs/misc/changelog.rst | 27 +++++++++++++++++++++++++++ stable_baselines3/common/env_util.py | 11 +++++++---- 2 files changed, 34 insertions(+), 4 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 36b77e5..96b2e2b 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -2,6 +2,33 @@ Changelog ========== +Release 1.6.2 WIP +--------------------------- + +**Bug fix release** + +Breaking Changes: +^^^^^^^^^^^^^^^^^ + +New Features: +^^^^^^^^^^^^^ + +SB3-Contrib +^^^^^^^^^^^ + +Bug Fixes: +^^^^^^^^^^ + +Deprecations: +^^^^^^^^^^^^^ + +Others: +^^^^^^^ +- Fixed type hint of the ``env_id`` parameter in ``make_vec_env`` and ``make_atari_env`` (@AlexPasqua) + +Documentation: +^^^^^^^^^^^^^^ +- Extended docstring of the ``wrapper_class`` parameter in ``make_vec_env`` (@AlexPasqua) Release 1.6.1 (2022-09-29) --------------------------- diff --git a/stable_baselines3/common/env_util.py b/stable_baselines3/common/env_util.py index 520c50a..eb893c6 100644 --- a/stable_baselines3/common/env_util.py +++ b/stable_baselines3/common/env_util.py @@ -36,7 +36,7 @@ def is_wrapped(env: Type[gym.Env], wrapper_class: Type[gym.Wrapper]) -> bool: def make_vec_env( - env_id: Union[str, Type[gym.Env]], + env_id: Union[str, Callable[..., gym.Env]], n_envs: int = 1, seed: Optional[int] = None, start_index: int = 0, @@ -53,7 +53,7 @@ def make_vec_env( By default it uses a ``DummyVecEnv`` which is usually faster than a ``SubprocVecEnv``. - :param env_id: the environment ID or the environment class + :param env_id: either the env ID, the env class or a callable returning an env :param n_envs: the number of environments you wish to have in parallel :param seed: the initial seed for the random number generator :param start_index: start rank index @@ -62,6 +62,9 @@ def make_vec_env( in a Monitor wrapper to provide additional information about training. :param wrapper_class: Additional wrapper to use on the environment. This can also be a function with single argument that wraps the environment in many things. + Note: the wrapper specified by this parameter will be applied after the ``Monitor`` wrapper. + if some cases (e.g. with TimeLimit wrapper) this can lead to undesired behavior. + See here for more details: https://github.com/DLR-RM/stable-baselines3/issues/894 :param env_kwargs: Optional keyword argument to pass to the env constructor :param vec_env_cls: A custom ``VecEnv`` class constructor. Default: None. :param vec_env_kwargs: Keyword arguments to pass to the ``VecEnv`` class constructor. @@ -106,7 +109,7 @@ def make_vec_env( def make_atari_env( - env_id: Union[str, Type[gym.Env]], + env_id: Union[str, Callable[..., gym.Env]], n_envs: int = 1, seed: Optional[int] = None, start_index: int = 0, @@ -121,7 +124,7 @@ def make_atari_env( Create a wrapped, monitored VecEnv for Atari. It is a wrapper around ``make_vec_env`` that includes common preprocessing for Atari games. - :param env_id: the environment ID or the environment class + :param env_id: either the env ID, the env class or a callable returning an env :param n_envs: the number of environments you wish to have in parallel :param seed: the initial seed for the random number generator :param start_index: start rank index From 7c21b7918871c552f0079815f2bf17158fc4c8f0 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Thu, 6 Oct 2022 18:17:31 +0200 Subject: [PATCH 2/2] Add progress bar callback and argument (#1095) * Add progress bar callback and argument * Update doc * Update changelog * Upgrade pytype in docker image * Use tqdm.write in the logger to have cleaner output * Fix logger test * Fix when doing multiple calls to learn() * Address comments from code-review --- .gitlab-ci.yml | 2 + docs/guide/callbacks.rst | 23 ++++++++++ docs/guide/examples.rst | 4 +- docs/guide/quickstart.rst | 2 +- docs/misc/changelog.rst | 7 +-- setup.py | 3 ++ stable_baselines3/a2c/a2c.py | 2 + stable_baselines3/common/base_class.py | 16 +++++-- stable_baselines3/common/callbacks.py | 45 ++++++++++++++++++- stable_baselines3/common/logger.py | 11 ++++- .../common/off_policy_algorithm.py | 4 ++ .../common/on_policy_algorithm.py | 11 ++++- stable_baselines3/ddpg/ddpg.py | 2 + stable_baselines3/dqn/dqn.py | 2 + stable_baselines3/ppo/ppo.py | 2 + stable_baselines3/sac/sac.py | 2 + stable_baselines3/td3/td3.py | 2 + stable_baselines3/version.txt | 2 +- tests/test_callbacks.py | 4 +- 19 files changed, 131 insertions(+), 15 deletions(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 20953d2..dc09ed2 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -2,10 +2,12 @@ image: stablebaselines/stable-baselines3-cpu:1.4.1a0 type-check: script: + - pip install pytype --upgrade - make type pytest: script: + - pip install tqdm rich # for progress bar - python --version # MKL_THREADING_LAYER=GNU to avoid MKL_THREADING_LAYER=INTEL incompatibility error - MKL_THREADING_LAYER=GNU make pytest diff --git a/docs/guide/callbacks.rst b/docs/guide/callbacks.rst index 098f7c4..632743b 100644 --- a/docs/guide/callbacks.rst +++ b/docs/guide/callbacks.rst @@ -225,6 +225,29 @@ It will save the best model if ``best_model_save_path`` folder is specified and model = SAC("MlpPolicy", "Pendulum-v1") model.learn(5000, callback=eval_callback) +.. _ProgressBarCallback: + +ProgressBarCallback +^^^^^^^^^^^^^^^^^^^ + +Display a progress bar with the current progress, elapsed time and estimated remaining time. +This callback is integrated inside SB3 via the ``progress_bar`` argument of the ``learn()`` method. + +.. note:: + + This callback requires ``tqdm`` and ``rich`` packages to be installed. This is done automatically when using ``pip install stable-baselines3[extra]`` + + +.. code-block:: python + + from stable_baselines3 import PPO + from stable_baselines3.common.callbacks import ProgressBarCallback + + model = PPO("MlpPolicy", "Pendulum-v1") + # Display progress bar using the progress bar callback + # this is equivalent to model.learn(100_000, callback=ProgressBarCallback()) + model.learn(100_000, progress_bar=True) + .. _Callbacklist: diff --git a/docs/guide/examples.rst b/docs/guide/examples.rst index bcc206a..3433640 100644 --- a/docs/guide/examples.rst +++ b/docs/guide/examples.rst @@ -75,8 +75,8 @@ In the following example, we will train, save and load a DQN model on the Lunar # Instantiate the agent model = DQN("MlpPolicy", env, verbose=1) - # Train the agent - model.learn(total_timesteps=int(2e5)) + # Train the agent and display a progress bar + model.learn(total_timesteps=int(2e5), progress_bar=True) # Save the agent model.save("dqn_lunar") del model # delete trained model to demonstrate loading diff --git a/docs/guide/quickstart.rst b/docs/guide/quickstart.rst index 3365b49..7ad9e0e 100644 --- a/docs/guide/quickstart.rst +++ b/docs/guide/quickstart.rst @@ -17,7 +17,7 @@ Here is a quick example of how to train and run A2C on a CartPole environment: env = gym.make("CartPole-v1") model = A2C("MlpPolicy", env, verbose=1) - model.learn(total_timesteps=10000) + model.learn(total_timesteps=10_000) obs = env.reset() for i in range(1000): diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 96b2e2b..75d8df6 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -2,22 +2,23 @@ Changelog ========== -Release 1.6.2 WIP +Release 1.6.2a0 (WIP) --------------------------- -**Bug fix release** - Breaking Changes: ^^^^^^^^^^^^^^^^^ New Features: ^^^^^^^^^^^^^ +- Added ``progress_bar`` argument in the ``learn()`` method, displayed using TQDM and rich packages +- Added progress bar callback SB3-Contrib ^^^^^^^^^^^ Bug Fixes: ^^^^^^^^^^ +- ``self.num_timesteps`` was initialized properly only after the first call to ``on_step()`` for callbacks Deprecations: ^^^^^^^^^^^^^ diff --git a/setup.py b/setup.py index fd1ee47..7e40c44 100644 --- a/setup.py +++ b/setup.py @@ -127,6 +127,9 @@ setup( "tensorboard>=2.9.1", # Checking memory taken by replay buffer "psutil", + # For progress bar callback + "tqdm", + "rich", ], }, description="Pytorch version of Stable Baselines, implementations of reinforcement learning algorithms.", diff --git a/stable_baselines3/a2c/a2c.py b/stable_baselines3/a2c/a2c.py index 8b8cecb..226f6bc 100644 --- a/stable_baselines3/a2c/a2c.py +++ b/stable_baselines3/a2c/a2c.py @@ -195,6 +195,7 @@ class A2C(OnPolicyAlgorithm): tb_log_name: str = "A2C", eval_log_path: Optional[str] = None, reset_num_timesteps: bool = True, + progress_bar: bool = False, ) -> A2CSelf: return super().learn( @@ -207,4 +208,5 @@ class A2C(OnPolicyAlgorithm): tb_log_name=tb_log_name, eval_log_path=eval_log_path, reset_num_timesteps=reset_num_timesteps, + progress_bar=progress_bar, ) diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py index 33d3fac..c265b7a 100644 --- a/stable_baselines3/common/base_class.py +++ b/stable_baselines3/common/base_class.py @@ -12,7 +12,7 @@ import numpy as np import torch as th from stable_baselines3.common import utils -from stable_baselines3.common.callbacks import BaseCallback, CallbackList, ConvertCallback, EvalCallback +from stable_baselines3.common.callbacks import BaseCallback, CallbackList, ConvertCallback, EvalCallback, ProgressBarCallback from stable_baselines3.common.env_util import is_wrapped from stable_baselines3.common.logger import Logger from stable_baselines3.common.monitor import Monitor @@ -371,6 +371,7 @@ class BaseAlgorithm(ABC): eval_freq: int = 10000, n_eval_episodes: int = 5, log_path: Optional[str] = None, + progress_bar: bool = False, ) -> BaseCallback: """ :param callback: Callback(s) called at every step with state of the algorithm. @@ -378,6 +379,7 @@ class BaseAlgorithm(ABC): :param n_eval_episodes: How many episodes to play per evaluation :param n_eval_episodes: Number of episodes to rollout during evaluation. :param log_path: Path to a folder where the evaluations will be saved + :param progress_bar: Display a progress bar using tqdm and rich. :return: A hybrid callback calling `callback` and performing evaluation. """ # Convert a list of callbacks into a callback @@ -388,6 +390,10 @@ class BaseAlgorithm(ABC): if not isinstance(callback, BaseCallback): callback = ConvertCallback(callback) + # Add progress bar callback + if progress_bar: + callback = CallbackList([callback, ProgressBarCallback()]) + # Create eval callback in charge of the evaluation if eval_env is not None: eval_callback = EvalCallback( @@ -413,6 +419,7 @@ class BaseAlgorithm(ABC): log_path: Optional[str] = None, reset_num_timesteps: bool = True, tb_log_name: str = "run", + progress_bar: bool = False, ) -> Tuple[int, BaseCallback]: """ Initialize different variables needed for training. @@ -425,7 +432,8 @@ class BaseAlgorithm(ABC): :param log_path: Path to a folder where the evaluations will be saved :param reset_num_timesteps: Whether to reset or not the ``num_timesteps`` attribute :param tb_log_name: the name of the run for tensorboard log - :return: + :param progress_bar: Display a progress bar using tqdm and rich. + :return: Total timesteps and callback(s) """ self.start_time = time.time_ns() @@ -464,7 +472,7 @@ class BaseAlgorithm(ABC): self._logger = utils.configure_logger(self.verbose, self.tensorboard_log, tb_log_name, reset_num_timesteps) # Create eval callback if needed - callback = self._init_callback(callback, eval_env, eval_freq, n_eval_episodes, log_path) + callback = self._init_callback(callback, eval_env, eval_freq, n_eval_episodes, log_path, progress_bar) return total_timesteps, callback @@ -550,6 +558,7 @@ class BaseAlgorithm(ABC): n_eval_episodes: int = 5, eval_log_path: Optional[str] = None, reset_num_timesteps: bool = True, + progress_bar: bool = False, ) -> BaseAlgorithmSelf: """ Return a trained model. @@ -563,6 +572,7 @@ class BaseAlgorithm(ABC): :param n_eval_episodes: Number of episode to evaluate the agent :param eval_log_path: Path to a folder where the evaluations will be saved :param reset_num_timesteps: whether or not to reset the current timestep number (used in logging) + :param progress_bar: Display a progress bar using tqdm and rich. :return: the trained model """ diff --git a/stable_baselines3/common/callbacks.py b/stable_baselines3/common/callbacks.py index dfbf0ea..7f81a98 100644 --- a/stable_baselines3/common/callbacks.py +++ b/stable_baselines3/common/callbacks.py @@ -6,6 +6,17 @@ from typing import Any, Callable, Dict, List, Optional, Union import gym import numpy as np +try: + from tqdm import TqdmExperimentalWarning + + # Remove experimental warning + warnings.filterwarnings("ignore", category=TqdmExperimentalWarning) + from tqdm.rich import tqdm +except ImportError: + # Rich not installed, we only throw an error + # if the progress bar is used + tqdm = None + from stable_baselines3.common import base_class # pytype: disable=pyi-error from stable_baselines3.common.evaluation import evaluate_policy from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, sync_envs_normalization @@ -54,6 +65,8 @@ class BaseCallback(ABC): # Those are reference and will be updated automatically self.locals = locals_ self.globals = globals_ + # Update num_timesteps in case training was done before + self.num_timesteps = self.model.num_timesteps self._on_training_start() def _on_training_start(self) -> None: @@ -82,7 +95,6 @@ class BaseCallback(ABC): :return: If the callback returns False, training is aborted early. """ self.n_calls += 1 - # timesteps start at zero self.num_timesteps = self.model.num_timesteps return self._on_step() @@ -644,3 +656,34 @@ class StopTrainingOnNoModelImprovement(BaseCallback): ) return continue_training + + +class ProgressBarCallback(BaseCallback): + """ + Display a progress bar when training SB3 agent + using tqdm and rich packages. + """ + + def __init__(self) -> None: + super().__init__() + if tqdm is None: + raise ImportError( + "You must install tqdm and rich in order to use the progress bar callback. " + "It is included if you install stable-baselines with the extra packages: " + "`pip install stable-baselines3[extra]`" + ) + self.pbar = None + + def _on_training_start(self) -> None: + # Initialize progress bar + # Remove timesteps that were done in previous training sessions + self.pbar = tqdm(total=self.locals["total_timesteps"] - self.model.num_timesteps) + + def _on_step(self) -> bool: + # Update progress bar, we do num_envs steps per call to `env.step()` + self.pbar.update(self.training_env.num_envs) + return True + + def _on_training_end(self) -> None: + # Close progress bar + self.pbar.close() diff --git a/stable_baselines3/common/logger.py b/stable_baselines3/common/logger.py index 31e5655..51b6f6b 100644 --- a/stable_baselines3/common/logger.py +++ b/stable_baselines3/common/logger.py @@ -18,6 +18,10 @@ try: except ImportError: SummaryWriter = None +try: + from tqdm import tqdm +except ImportError: + tqdm = None DEBUG = 10 INFO = 20 @@ -222,7 +226,12 @@ class HumanOutputFormat(KVWriter, SeqWriter): val_space = " " * (val_width - len(value)) lines.append(f"| {key}{key_space} | {value}{val_space} |") lines.append(dashes) - self.file.write("\n".join(lines) + "\n") + + if tqdm is not None and hasattr(self.file, "name") and self.file.name == "": + # Do not mess up with progress bar + tqdm.write("\n".join(lines) + "\n", file=sys.stdout, end="") + else: + self.file.write("\n".join(lines) + "\n") # Flush the output to the file self.file.flush() diff --git a/stable_baselines3/common/off_policy_algorithm.py b/stable_baselines3/common/off_policy_algorithm.py index c23223d..99b5d22 100644 --- a/stable_baselines3/common/off_policy_algorithm.py +++ b/stable_baselines3/common/off_policy_algorithm.py @@ -276,6 +276,7 @@ class OffPolicyAlgorithm(BaseAlgorithm): log_path: Optional[str] = None, reset_num_timesteps: bool = True, tb_log_name: str = "run", + progress_bar: bool = False, ) -> Tuple[int, BaseCallback]: """ cf `BaseAlgorithm`. @@ -318,6 +319,7 @@ class OffPolicyAlgorithm(BaseAlgorithm): log_path, reset_num_timesteps, tb_log_name, + progress_bar, ) def learn( @@ -331,6 +333,7 @@ class OffPolicyAlgorithm(BaseAlgorithm): tb_log_name: str = "run", eval_log_path: Optional[str] = None, reset_num_timesteps: bool = True, + progress_bar: bool = False, ) -> OffPolicyAlgorithmSelf: total_timesteps, callback = self._setup_learn( @@ -342,6 +345,7 @@ class OffPolicyAlgorithm(BaseAlgorithm): eval_log_path, reset_num_timesteps, tb_log_name, + progress_bar, ) callback.on_training_start(locals(), globals()) diff --git a/stable_baselines3/common/on_policy_algorithm.py b/stable_baselines3/common/on_policy_algorithm.py index 0589fe1..a91301d 100644 --- a/stable_baselines3/common/on_policy_algorithm.py +++ b/stable_baselines3/common/on_policy_algorithm.py @@ -237,11 +237,20 @@ class OnPolicyAlgorithm(BaseAlgorithm): tb_log_name: str = "OnPolicyAlgorithm", eval_log_path: Optional[str] = None, reset_num_timesteps: bool = True, + progress_bar: bool = False, ) -> OnPolicyAlgorithmSelf: iteration = 0 total_timesteps, callback = self._setup_learn( - total_timesteps, eval_env, callback, eval_freq, n_eval_episodes, eval_log_path, reset_num_timesteps, tb_log_name + total_timesteps, + eval_env, + callback, + eval_freq, + n_eval_episodes, + eval_log_path, + reset_num_timesteps, + tb_log_name, + progress_bar, ) callback.on_training_start(locals(), globals()) diff --git a/stable_baselines3/ddpg/ddpg.py b/stable_baselines3/ddpg/ddpg.py index 531acd1..627a2e6 100644 --- a/stable_baselines3/ddpg/ddpg.py +++ b/stable_baselines3/ddpg/ddpg.py @@ -127,6 +127,7 @@ class DDPG(TD3): tb_log_name: str = "DDPG", eval_log_path: Optional[str] = None, reset_num_timesteps: bool = True, + progress_bar: bool = False, ) -> DDPGSelf: return super().learn( @@ -139,4 +140,5 @@ class DDPG(TD3): tb_log_name=tb_log_name, eval_log_path=eval_log_path, reset_num_timesteps=reset_num_timesteps, + progress_bar=progress_bar, ) diff --git a/stable_baselines3/dqn/dqn.py b/stable_baselines3/dqn/dqn.py index ea7d9f3..7263fdd 100644 --- a/stable_baselines3/dqn/dqn.py +++ b/stable_baselines3/dqn/dqn.py @@ -267,6 +267,7 @@ class DQN(OffPolicyAlgorithm): tb_log_name: str = "DQN", eval_log_path: Optional[str] = None, reset_num_timesteps: bool = True, + progress_bar: bool = False, ) -> DQNSelf: return super().learn( @@ -279,6 +280,7 @@ class DQN(OffPolicyAlgorithm): tb_log_name=tb_log_name, eval_log_path=eval_log_path, reset_num_timesteps=reset_num_timesteps, + progress_bar=progress_bar, ) def _excluded_save_params(self) -> List[str]: diff --git a/stable_baselines3/ppo/ppo.py b/stable_baselines3/ppo/ppo.py index cfcdfb1..81094aa 100644 --- a/stable_baselines3/ppo/ppo.py +++ b/stable_baselines3/ppo/ppo.py @@ -309,6 +309,7 @@ class PPO(OnPolicyAlgorithm): tb_log_name: str = "PPO", eval_log_path: Optional[str] = None, reset_num_timesteps: bool = True, + progress_bar: bool = False, ) -> PPOSelf: return super().learn( @@ -321,4 +322,5 @@ class PPO(OnPolicyAlgorithm): tb_log_name=tb_log_name, eval_log_path=eval_log_path, reset_num_timesteps=reset_num_timesteps, + progress_bar=progress_bar, ) diff --git a/stable_baselines3/sac/sac.py b/stable_baselines3/sac/sac.py index 8505e88..abd6879 100644 --- a/stable_baselines3/sac/sac.py +++ b/stable_baselines3/sac/sac.py @@ -301,6 +301,7 @@ class SAC(OffPolicyAlgorithm): tb_log_name: str = "SAC", eval_log_path: Optional[str] = None, reset_num_timesteps: bool = True, + progress_bar: bool = False, ) -> SACSelf: return super().learn( @@ -313,6 +314,7 @@ class SAC(OffPolicyAlgorithm): tb_log_name=tb_log_name, eval_log_path=eval_log_path, reset_num_timesteps=reset_num_timesteps, + progress_bar=progress_bar, ) def _excluded_save_params(self) -> List[str]: diff --git a/stable_baselines3/td3/td3.py b/stable_baselines3/td3/td3.py index 62e33f5..87d94b3 100644 --- a/stable_baselines3/td3/td3.py +++ b/stable_baselines3/td3/td3.py @@ -217,6 +217,7 @@ class TD3(OffPolicyAlgorithm): tb_log_name: str = "TD3", eval_log_path: Optional[str] = None, reset_num_timesteps: bool = True, + progress_bar: bool = False, ) -> TD3Self: return super().learn( @@ -229,6 +230,7 @@ class TD3(OffPolicyAlgorithm): tb_log_name=tb_log_name, eval_log_path=eval_log_path, reset_num_timesteps=reset_num_timesteps, + progress_bar=progress_bar, ) def _excluded_save_params(self) -> List[str]: diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 9c6d629..2f7c3d4 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -1.6.1 +1.6.2a0 diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py index 2c7e0ba..f749d5a 100644 --- a/tests/test_callbacks.py +++ b/tests/test_callbacks.py @@ -71,8 +71,8 @@ def test_callbacks(tmp_path, model_class): assert event_callback.n_calls == model.num_timesteps model.learn(500, callback=None) - # Transform callback into a callback list automatically - model.learn(500, callback=[checkpoint_callback, eval_callback]) + # Transform callback into a callback list automatically and use progress bar + model.learn(500, callback=[checkpoint_callback, eval_callback], progress_bar=True) # Automatic wrapping, old way of doing callbacks model.learn(500, callback=lambda _locals, _globals: True)