transformers/docs/source/main_classes/callback.rst

90 lines
4.1 KiB
ReStructuredText
Raw Normal View History

..
Copyright 2020 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
Callbacks
-----------------------------------------------------------------------------------------------------------------------
Callbacks are objects that can customize the behavior of the training loop in the PyTorch
:class:`~transformers.Trainer` (this feature is not yet implemented in TensorFlow) that can inspect the training loop
state (for progress reporting, logging on TensorBoard or other ML platforms...) and take decisions (like early
stopping).
Callbacks are "read only" pieces of code, apart from the :class:`~transformers.TrainerControl` object they return, they
cannot change anything in the training loop. For customizations that require changes in the training loop, you should
subclass :class:`~transformers.Trainer` and override the methods you need (see :doc:`trainer` for examples).
By default a :class:`~transformers.Trainer` will use the following callbacks:
2020-10-13 14:00:20 +00:00
- :class:`~transformers.DefaultFlowCallback` which handles the default behavior for logging, saving and evaluation.
- :class:`~transformers.PrinterCallback` or :class:`~transformers.ProgressCallback` to display progress and print the
logs (the first one is used if you deactivate tqdm through the :class:`~transformers.TrainingArguments`, otherwise
it's the second one).
- :class:`~transformers.integrations.TensorBoardCallback` if tensorboard is accessible (either through PyTorch >= 1.4
or tensorboardX).
- :class:`~transformers.integrations.WandbCallback` if `wandb <https://www.wandb.com/>`__ is installed.
- :class:`~transformers.integrations.CometCallback` if `comet_ml <https://www.comet.ml/site/>`__ is installed.
Mlflow integration callback (#8016) * Add MLflow integration class Add integration code for MLflow in integrations.py along with the code that checks that MLflow is installed. * Add MLflowCallback import Add import of MLflowCallback in trainer.py * Handle model argument Allow the callback to handle model argument and store model config items as hyperparameters. * Log parameters to MLflow in batches MLflow cannot log more than a hundred parameters at once. Code added to split the parameters into batches of 100 items and log the batches one by one. * Fix style * Add docs on MLflow callback * Fix issue with unfinished runs The "fluent" api used in MLflow integration allows only one run to be active at any given moment. If the Trainer is disposed off and a new one is created, but the training is not finished, it will refuse to log the results when the next trainer is created. * Add MLflow integration class Add integration code for MLflow in integrations.py along with the code that checks that MLflow is installed. * Add MLflowCallback import Add import of MLflowCallback in trainer.py * Handle model argument Allow the callback to handle model argument and store model config items as hyperparameters. * Log parameters to MLflow in batches MLflow cannot log more than a hundred parameters at once. Code added to split the parameters into batches of 100 items and log the batches one by one. * Fix style * Add docs on MLflow callback * Fix issue with unfinished runs The "fluent" api used in MLflow integration allows only one run to be active at any given moment. If the Trainer is disposed off and a new one is created, but the training is not finished, it will refuse to log the results when the next trainer is created.
2020-10-26 13:41:58 +00:00
- :class:`~transformers.integrations.MLflowCallback` if `mlflow <https://www.mlflow.org/>`__ is installed.
- :class:`~transformers.integrations.AzureMLCallback` if `azureml-sdk <https://pypi.org/project/azureml-sdk/>`__ is
installed.
The main class that implements callbacks is :class:`~transformers.TrainerCallback`. It gets the
:class:`~transformers.TrainingArguments` used to instantiate the :class:`~transformers.Trainer`, can access that
Trainer's internal state via :class:`~transformers.TrainerState`, and can take some actions on the training loop via
:class:`~transformers.TrainerControl`.
Available Callbacks
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Here is the list of the available :class:`~transformers.TrainerCallback` in the library:
.. autoclass:: transformers.integrations.CometCallback
:members: setup
.. autoclass:: transformers.DefaultFlowCallback
.. autoclass:: transformers.PrinterCallback
.. autoclass:: transformers.ProgressCallback
.. autoclass:: transformers.EarlyStoppingCallback
.. autoclass:: transformers.integrations.TensorBoardCallback
.. autoclass:: transformers.integrations.WandbCallback
:members: setup
Mlflow integration callback (#8016) * Add MLflow integration class Add integration code for MLflow in integrations.py along with the code that checks that MLflow is installed. * Add MLflowCallback import Add import of MLflowCallback in trainer.py * Handle model argument Allow the callback to handle model argument and store model config items as hyperparameters. * Log parameters to MLflow in batches MLflow cannot log more than a hundred parameters at once. Code added to split the parameters into batches of 100 items and log the batches one by one. * Fix style * Add docs on MLflow callback * Fix issue with unfinished runs The "fluent" api used in MLflow integration allows only one run to be active at any given moment. If the Trainer is disposed off and a new one is created, but the training is not finished, it will refuse to log the results when the next trainer is created. * Add MLflow integration class Add integration code for MLflow in integrations.py along with the code that checks that MLflow is installed. * Add MLflowCallback import Add import of MLflowCallback in trainer.py * Handle model argument Allow the callback to handle model argument and store model config items as hyperparameters. * Log parameters to MLflow in batches MLflow cannot log more than a hundred parameters at once. Code added to split the parameters into batches of 100 items and log the batches one by one. * Fix style * Add docs on MLflow callback * Fix issue with unfinished runs The "fluent" api used in MLflow integration allows only one run to be active at any given moment. If the Trainer is disposed off and a new one is created, but the training is not finished, it will refuse to log the results when the next trainer is created.
2020-10-26 13:41:58 +00:00
.. autoclass:: transformers.integrations.MLflowCallback
:members: setup
.. autoclass:: transformers.integrations.AzureMLCallback
TrainerCallback
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TrainerCallback
:members:
TrainerState
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TrainerState
:members:
TrainerControl
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TrainerControl
:members: