2022-11-04 11:42:10 +00:00
# ONNX Runtime Training Guidelines
## 1. Installation and Configuration
Note: this section mainly demonstrates setup steps for development; see [Torch-ORT](https://github.com/pytorch/ort) for the end-user setup experience.
Refer to [https://onnxruntime.ai/](https://onnxruntime.ai/) to download the training wheel, or build from source:
```bash
export CUDA_HOME=/usr/local/cuda
export CUDNN_HOME=/usr/local/cuda
export CUDACXX=$CUDA_HOME/bin/nvcc
./build.sh --config RelWithDebInfo --use_cuda --enable_training --build_wheel --skip_tests --cuda_version=11.8 --parallel 8 --use_mpi
```
Install the Python wheel.
Configure the ORTModule torch C++ extensions (**avoid** running this from the ORT *repo root directory*):
```bash
python -m onnxruntime.training.ortmodule.torch_cpp_extensions.install
```
## 2. Use `ORTModule` to Accelerate Forward/Backward
Wrap your `torch.nn.Module` model with `ORTModule` to leverage the fast ONNX Runtime training backend.
Sample usage:
```diff
model = build_model()
+ from onnxruntime.training.ortmodule import ORTModule
+ model = ORTModule(model)
```
> It is strongly recommended to wrap the model with `ORTModule` before applying any other module wrapper (for example, DeepSpeed or `torch.nn.parallel.DistributedDataParallel`); this order has been validated in more scenarios.
> Also note that `ORTModule` is **NOT** compatible with `torch.nn.DataParallel` (which PyTorch itself does not recommend). Use `torch.nn.parallel.DistributedDataParallel` instead.
More options for **developers**:
```diff
model = build_model()
+ from onnxruntime.training.ortmodule import ORTModule, DebugOptions, LogLevel
+ model = ORTModule(model, DebugOptions(save_onnx=True, log_level=LogLevel.VERBOSE, onnx_prefix="model_name"))
```
Check the [DebugOptions implementation](../orttraining/orttraining/python/training/ortmodule/options.py) for more details.
#### Log Level Explanations
<table>
<tr>
<th style="width:20%">Log Level</th>
<th style="width:80%">Description</th>
</tr>
<tr>
<td>

`FATAL` | `ERROR` | `WARNING` (For Users)

<sup>`WARNING` is the default and recommended level for
<br>users.</sup>
</td>
<td>

- ONNX Runtime backend log level - `FATAL` | `ERROR` | `WARNING`.
- ORTModule log level - `FATAL` | `ERROR` | `WARNING`.
- Rank-0 log filtering is `ON` (i.e. logging on rank-0 only).
- PyTorch exporter export logs filtering is `ON`.
- PyTorch exporter verbose logs (including tracing graph) filtering is `ON`.

</td>
</tr>
<tr>
<td>

`INFO` (For Users | ORT Developers)

<sup>`INFO` is used for collecting experimental
<br>feature stats, or slightly more detailed error messages.</sup>
</td>
<td>

- ONNX Runtime backend log level - `WARNING`.
- ORTModule log level - `INFO`.
- Rank-0 log filtering is `ON` (i.e. logging on rank-0 only).
- PyTorch exporter export logs filtering is `ON`.
- PyTorch exporter verbose logs (including tracing graph) filtering is `OFF`.

</td>
</tr>
<tr>
<td>

`DEVINFO` (For ORT Developers)

<sup>`DEVINFO` is the recommended level for
<br>debugging purposes.</sup>
</td>
<td>

- ONNX Runtime backend log level - `INFO`.
- ORTModule log level - `INFO`.
- Rank-0 log filtering is `OFF` (i.e. logging on all ranks).
- PyTorch exporter export logs filtering is `OFF`.
- PyTorch exporter verbose logs (including tracing graph) filtering is `OFF`.

</td>
</tr>
<tr>
<td>

`VERBOSE` (For ORT Developers)

<sup>`VERBOSE` is the last resort for debugging
<br>hard problems.</sup>
</td>
<td>

- ONNX Runtime backend log level - `VERBOSE`.
- ORTModule log level - `VERBOSE`.
- Rank-0 log filtering is `OFF` (i.e. logging on all ranks).
- PyTorch exporter export logs filtering is `OFF`.
- PyTorch exporter verbose logs (including tracing graph) filtering is `OFF`.

</td>
</tr>
</table>
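The table above can also be read as a small lookup structure. The sketch below encodes it in plain Python so the per-level differences (backend level, rank-0 filtering, exporter filtering) are easy to compare. `LogLevelBehavior`, `LOG_LEVEL_TABLE`, and `emits_on_rank` are hypothetical names for illustration, not part of the ORTModule API.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class LogLevelBehavior:
    backend_level: str         # ONNX Runtime backend log level
    ortmodule_level: str       # ORTModule log level
    rank0_only: bool           # True when rank-0 log filtering is ON
    filter_export_logs: bool   # True when PyTorch exporter export-log filtering is ON
    filter_verbose_logs: bool  # True when exporter verbose logs (tracing graph) are filtered


# One row per level, transcribed from the table above
# (assumption: the FATAL/ERROR/WARNING row maps one-to-one onto the backend level).
LOG_LEVEL_TABLE = {
    "FATAL": LogLevelBehavior("FATAL", "FATAL", True, True, True),
    "ERROR": LogLevelBehavior("ERROR", "ERROR", True, True, True),
    "WARNING": LogLevelBehavior("WARNING", "WARNING", True, True, True),
    "INFO": LogLevelBehavior("WARNING", "INFO", True, True, False),
    "DEVINFO": LogLevelBehavior("INFO", "INFO", False, False, False),
    "VERBOSE": LogLevelBehavior("VERBOSE", "VERBOSE", False, False, False),
}


def emits_on_rank(level: str, rank: int) -> bool:
    """Return True if ORTModule logs are expected on `rank` at `level`."""
    return rank == 0 or not LOG_LEVEL_TABLE[level].rank0_only


# Which ranks print at INFO vs DEVINFO on a 4-process job.
print([r for r in range(4) if emits_on_rank("INFO", r)])     # [0]
print([r for r in range(4) if emits_on_rank("DEVINFO", r)])  # [0, 1, 2, 3]
```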
### 2.1 Environment Variables
`ORTModule` provides environment variables targeting different use cases.
#### ORTMODULE_ONNX_OPSET_VERSION
- **Feature Area**: *ORTMODULE/ONNXOPSET*
- **Description**: By default, the ONNX opset version used for export is updated periodically as new ONNX Runtime versions are released. Customers who want to stick to a fixed opset, where both performance and accuracy have been well validated, can use this environment variable to pin it.
```bash
export ORTMODULE_ONNX_OPSET_VERSION=14
```
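To configure this from a Python training script instead of the shell, one option is to set the variable through `os.environ`. The sketch assumes the variable is read when `ORTModule` is set up, so set it before importing or constructing `ORTModule`. `pin_onnx_opset` is a hypothetical helper.

```python
import os
from typing import MutableMapping


def pin_onnx_opset(version: int, environ: MutableMapping[str, str] = os.environ) -> None:
    """Hypothetical helper: pin the ONNX opset that ORTModule exports with."""
    if isinstance(version, bool) or not isinstance(version, int) or version <= 0:
        raise ValueError(f"opset version must be a positive int, got {version!r}")
    # Environment values must be strings. Call this before importing or
    # constructing ORTModule (assumption: the value is read at setup time).
    environ["ORTMODULE_ONNX_OPSET_VERSION"] = str(version)


pin_onnx_opset(14)
print(os.environ["ORTMODULE_ONNX_OPSET_VERSION"])  # 14
```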
#### ORTMODULE_FALLBACK_POLICY
- **Feature Area**: *ORTMODULE/FallbackToPytorch*
- **Description**: By default, if `ORTModule` fails to run the model using the ONNX Runtime backend, it falls back to PyTorch to continue training. When developers are optimizing models or benchmarking, they may want to force the model to run on the ORT backend. To disable the fallback:
```bash
export ORTMODULE_FALLBACK_POLICY="FALLBACK_DISABLE"
```
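Conceptually, the fallback behaves like a try/except around the ORT run. The sketch below is a simplified model of that control flow, not `ORTModule`'s implementation: the real policy supports more values and may only fall back for certain kinds of failures. `run_with_fallback`, `ort_run`, and `torch_run` are hypothetical names.

```python
import os
from typing import Any, Callable, Mapping


def run_with_fallback(
    ort_run: Callable[..., Any],
    torch_run: Callable[..., Any],
    *args: Any,
    environ: Mapping[str, str] = os.environ,
) -> Any:
    """Simplified model of ORTModule's fallback (hypothetical helper)."""
    policy = environ.get("ORTMODULE_FALLBACK_POLICY", "")
    try:
        return ort_run(*args)
    except Exception:
        if policy == "FALLBACK_DISABLE":
            raise  # surface the ORT failure instead of retrying with PyTorch
        return torch_run(*args)


def failing_ort(x):
    raise RuntimeError("ORT backend failed")


print(run_with_fallback(failing_ort, lambda x: x * 2, 21, environ={}))  # 42
```

With `FALLBACK_DISABLE` set, the same call re-raises the `RuntimeError`, which is what you want while benchmarking: a silent fallback would report PyTorch numbers as if they were ORT numbers.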
#### ORTMODULE_LOG_LEVEL
- **Feature Area**: *ORTMODULE/DebugOptions*
- **Description**: Configures the `ORTModule` log level. Defaults to `LogLevel.WARNING`; can be set to one of "VERBOSE", "DEVINFO", "INFO", "WARNING", "ERROR", "FATAL". If `DebugOptions` also sets `log_level`, the environment variable takes precedence.
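The precedence rule can be sketched as follows. `resolve_log_level` is a hypothetical helper that uses plain strings instead of the `LogLevel` enum. The accepted names are assumed to match the levels in the log level table.

```python
import os
from typing import Mapping, Optional

# Assumed to mirror the levels in the log level table; not read from ORTModule.
KNOWN_LEVELS = ("VERBOSE", "DEVINFO", "INFO", "WARNING", "ERROR", "FATAL")


def resolve_log_level(
    debug_options_level: Optional[str] = None,
    environ: Mapping[str, str] = os.environ,
) -> str:
    """Hypothetical helper: the env var wins over DebugOptions, else WARNING."""
    env_level = environ.get("ORTMODULE_LOG_LEVEL")
    if env_level:
        if env_level not in KNOWN_LEVELS:
            raise ValueError(f"unknown ORTMODULE_LOG_LEVEL: {env_level!r}")
        return env_level
    return debug_options_level or "WARNING"


print(resolve_log_level("VERBOSE", environ={"ORTMODULE_LOG_LEVEL": "INFO"}))  # INFO
```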
#### ORTMODULE_SAVE_ONNX_PATH
- **Feature Area**: *ORTMODULE/DebugOptions*
- **Description**: Configures where `ORTModule` saves exported ONNX models. Saving is off by default (see `DebugOptions(save_onnx=True)` above).
By default, the ONNX models are written to the current working directory. To change the output directory, set the environment variable `ORTMODULE_SAVE_ONNX_PATH` to the destination directory path.
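The sketch below shows how the output directory is chosen, following the description above. `resolve_onnx_save_dir` is a hypothetical helper, and the sketch does not model the exported file names.

```python
import os
from typing import Mapping


def resolve_onnx_save_dir(environ: Mapping[str, str] = os.environ) -> str:
    """Hypothetical helper: where exported ONNX models are expected to land."""
    path = environ.get("ORTMODULE_SAVE_ONNX_PATH")
    # Without the variable, the documented default is the current working directory.
    return os.path.abspath(path) if path else os.getcwd()


print(resolve_onnx_save_dir(environ={}) == os.getcwd())  # True
```

If you point the variable at a new location, creating the directory up front (for example with `os.makedirs(path, exist_ok=True)`) avoids depending on whether `ORTModule` creates it for you.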
#### ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT
- **Feature Area**: *ORTMODULE/PythonOp (torch.autograd.Function)*
- **Description**: By default, `ORTModule` fails with an exception when handling PythonOp export for some `torch.autograd.Function`s (one example is torch's `CheckpointFunction`). Set
this environment variable to `1` to explicitly allow it.
```bash
export ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT=1
```
> Take `torch.utils.checkpoint.CheckpointFunction` as an example: if it is exported as a PythonOp, the checkpointed computation may be computed by PyTorch rather than ORT. This matters especially for big models such as GPT-2, where every few layers are wrapped for recomputation, so a large share of the computation ends up in PyTorch. Currently a failure is reported to notify users that `ORTModule` may have fewer opportunities to optimize further.
> On the other hand, if the wrapped computation graph is small, it is reasonable to allow it.
> Overall, users should be aware that the ORT performance boost might be trivial when they explicitly allow it.
#### ORTMODULE_ENABLE_CUSTOM_AUTOGRAD
- **Feature Area**: *ORTMODULE/PythonOp (torch.autograd.Function)*
- **Description**: By default, all `torch.autograd.Function` classes are exported to ORT PythonOp. In some cases you might consider disabling this. For example, if you have confirmed that the computations defined by those `torch.autograd.Function` classes can be inline-exported by PyTorch, and that training with the inline-exported ONNX graph is safe, you can disable it; as a result, ORT has more opportunities to optimize.
```bash
export ORTMODULE_ENABLE_CUSTOM_AUTOGRAD=1 # Enable
export ORTMODULE_ENABLE_CUSTOM_AUTOGRAD=0 # Disable
```
An alternative way to disable it without an environment variable:
```python
from onnxruntime.training.ortmodule._custom_autograd_function import enable_custom_autograd_support
enable_custom_autograd_support(False)
```
#### ORTMODULE_ENABLE_COMPUTE_OPTIMIZER
- **Feature Area**: *ORTMODULE/Optimizations*
- **Description**: By default, this is enabled so that some computation can be saved. This environment variable can be used to disable
the optimization to guarantee exactly the same computation as the baseline (for example PyTorch, when debugging convergence parity).
```bash
export ORTMODULE_ENABLE_COMPUTE_OPTIMIZER=1 # Enable
export ORTMODULE_ENABLE_COMPUTE_OPTIMIZER=0 # Disable
```
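To see why a graph rewrite can save work without changing results, here is a plain-Python sketch of one such transformation. It is a conceptual illustration, not ORT's implementation. The rewrite moves a slice that keeps only the first sequence position, as a sentence-classification head does, ahead of elementwise ops. The function names and the `tanh(x + bias)` op are hypothetical stand-ins.

```python
import math
from typing import List

Matrix = List[List[float]]  # [sequence_length][hidden] for a single sample


def bias_tanh(rows: Matrix, bias: List[float]) -> Matrix:
    """Hypothetical elementwise op: tanh(x + bias) applied to every position."""
    return [[math.tanh(x + b) for x, b in zip(row, bias)] for row in rows]


def first_token_unoptimized(features: Matrix, bias: List[float]) -> List[float]:
    # Baseline order: compute all sequence positions, then slice position 0.
    return bias_tanh(features, bias)[0]


def first_token_sliced_early(features: Matrix, bias: List[float]) -> List[float]:
    # Rewritten order: slice position 0 first, so only one row is computed.
    return bias_tanh(features[:1], bias)[0]


features = [[0.1, -0.2], [0.3, 0.4], [-0.5, 0.6]]
bias = [0.05, -0.05]
# Same values, with 1/sequence_length of the elementwise work in the second version.
assert first_token_unoptimized(features, bias) == first_token_sliced_early(features, bias)
```

In a real graph, rewrites of this kind can interact with fused kernels and change floating-point rounding slightly. That is one reason to disable the optimizer when checking exact parity against a baseline.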
#### ORTMODULE_ENABLE_SPARSE_OPTIMIZER
- **Feature Area**: *ORTMODULE/Optimizations*
- **Description**: By default, this is enabled. This environment variable can be used to enable or disable performance optimizations based on input data sparsity, including embedding sparsity and label sparsity.
This optimization is applicable when using optimum, which implements a `ModuleWithLoss` class that wraps HuggingFace training so that the loss is computed inside ONNX Runtime (ORT).
If you're not using optimum but want to implement a similar wrapper in your codebase to compute the loss inside ORT, refer to [this guide](ORTModule_ModuleWithLoss_Wrapper.md) for detailed steps and guidelines.
```bash
export ORTMODULE_ENABLE_SPARSE_OPTIMIZER=1 # Enable
export ORTMODULE_ENABLE_SPARSE_OPTIMIZER=0 # Disable
```
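Label sparsity means many label positions equal the ignore index, so any work spent on those positions is wasted. The plain-Python sketch below illustrates the idea; it does not reflect ORT's kernels. It compares two ways to get the same loss: computing log-softmax on every position and then masking, or gathering only the valid positions first. It assumes PyTorch's default `ignore_index=-100`, and all function names are hypothetical.

```python
import math
from typing import List, Sequence

IGNORE_INDEX = -100  # PyTorch's default ignore_index for cross-entropy


def log_softmax(row: Sequence[float]) -> List[float]:
    m = max(row)  # subtract the max for numerical stability
    log_sum = m + math.log(sum(math.exp(z - m) for z in row))
    return [z - log_sum for z in row]


def dense_loss(logits: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """Baseline: log-softmax on every position, then drop ignored labels.
    Assumes at least one valid label."""
    log_probs = [log_softmax(row) for row in logits]
    picked = [-lp[y] for lp, y in zip(log_probs, labels) if y != IGNORE_INDEX]
    return sum(picked) / len(picked)


def sparse_loss(logits: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """Sparsity-aware: gather valid positions first, compute only those.
    Assumes at least one valid label."""
    keep = [i for i, y in enumerate(labels) if y != IGNORE_INDEX]
    picked = [-log_softmax(logits[i])[labels[i]] for i in keep]
    return sum(picked) / len(picked)


logits = [[1.0, 2.0, 0.5], [0.1, 0.2, 0.3], [2.0, 1.0, 0.0], [0.0, 0.0, 3.0]]
labels = [1, IGNORE_INDEX, 0, IGNORE_INDEX]
assert math.isclose(dense_loss(logits, labels), sparse_loss(logits, labels))
```

The sparse version computes log-softmax for 2 of 4 positions here. With label densities in the low teens of percent, as is common for masked language modeling, the savings on the loss computation grow accordingly.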
#### ORTMODULE_PRINT_INPUT_DENSITY
- **Feature Area**: *ORTMODULE/RuntimeInspector*
- **Description**: By default, this is disabled. This environment variable can be used to print the input data sparsity
inspection results to standard output.
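Input density here means valid (non-padding) tokens divided by total tokens. The sketch below shows that arithmetic in plain Python; ORTModule's inspector computes and prints its own report, and `embed_density` and `label_density` are hypothetical helpers. It assumes padding is marked by a model-specific pad token id (for example 0 for BERT, 1 for RoBERTa) and ignored labels by `-100`.

```python
from typing import List, Sequence, Tuple


def embed_density(
    input_ids: Sequence[Sequence[int]], pad_idx: int
) -> Tuple[float, int, int, List[int]]:
    """Return (density %, valid tokens, total tokens, valid tokens per sample)."""
    per_sample = [sum(1 for tok in row if tok != pad_idx) for row in input_ids]
    total = sum(len(row) for row in input_ids)
    valid = sum(per_sample)
    return 100.0 * valid / total, valid, total, per_sample


def label_density(
    labels: Sequence[Sequence[int]], ignore_index: int = -100
) -> Tuple[float, int, int]:
    """Return (density %, valid labels, total labels)."""
    flat = [y for row in labels for y in row]
    valid = sum(1 for y in flat if y != ignore_index)
    return 100.0 * valid / len(flat), valid, len(flat)


# Two padded samples of length 4 with pad id 1: 3 of 8 tokens are real.
print(embed_density([[5, 6, 1, 1], [7, 1, 1, 1]], pad_idx=1))  # (37.5, 3, 8, [2, 1])
```

A low embedding density suggests time spent on padding, which techniques such as dynamic batching can reduce. A low label density suggests the loss computation is a candidate for the sparse optimizer described above.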
Introduce padding inspector in ORTModule (#14652)
### Introduce padding inspector in ORTModule
In some Transformer-based LLM training recipes, high data sparsity is
observed due to 1). token padding (to max sequence length), 2). labels
contains many ignore_index for calculate loss.
This PR introduces a switch to enable data sparsity inspection, which
1). in short term, can inform training users to use techniques like
dynamic batching to amortize the issue.
2). in medium and longer term, also helps us (training team) to have
better understanding what our training customers' models looks like from
perspective of data sparsity (and potentially motivate us to improve
with runtime).
Here is an example of different data sparsity with same training model
arch, same training input, but with different user models.
**Low Embed Density, High Label Density Case - Sentence Classification**
`
python -m torch.distributed.launch --nproc_per_node=4
examples/onnxruntime/training/text-classification/run_glue.py
--model_name_or_path roberta-large-openai-detector --task_name mnli
--do_train --do_eval --max_seq_length 128 --per_device_train_batch_size
32 --learning_rate 2e-5 --num_train_epochs 3 --overwrite_output_dir
--output_dir ./outputs/ --per_device_eval_batch_size 32 --seed 1137
--fp16 True --ignore_mismatched_sizes True --optim adamw_ort_fused
`
```
>>>Valid token/label density (e.g. valid/total) in passing 10 steps:
| STEP | INPUT TYPE | INPUT NAME | PAD IDX | DENSITY | VALID TOKENS | TOTAL TOKENS | VALID TOKENS/BATCH |
| 60 | EMBED | input_ids | 1 | 35.21 % | 1442 | 4096 | [50, 81, 35, 11, 29, 36, 66, 19, 40, 22, 21, 42, 17, 37, 40, 41, 26, 58, 38, 54, 41, 73, 48, 57, 50, 51, 49, 85, 48, 36, 79, 62] |
| 61 | LABEL | labels | -100 | 100.00 % | 32 | 32 | N/A |
| 62 | EMBED | input_ids | 1 | 30.00 % | 1229 | 4096 | [36, 73, 13, 47, 27, 33, 53, 25, 51, 28, 36, 42, 42, 32, 39, 52, 27, 13, 31, 66, 42, 45, 52, 45, 58, 42, 37, 66, 12, 18, 29, 17] |
| 63 | LABEL | labels | -100 | 100.00 % | 32 | 32 | N/A |
| 64 | EMBED | input_ids | 1 | 26.73 % | 1095 | 4096 | [37, 28, 20, 53, 16, 20, 44, 52, 27, 28, 16, 19, 16, 24, 63, 31, 24, 42, 33, 41, 44, 60, 44, 67, 54, 30, 20, 19, 33, 23, 24, 43] |
| 65 | LABEL | labels | -100 | 100.00 % | 32 | 32 | N/A |
| 66 | EMBED | input_ids | 1 | 30.03 % | 1230 | 4096 | [22, 46, 36, 41, 46, 43, 26, 50, 60, 16, 24, 42, 56, 35, 35, 59, 29, 39, 34, 20, 66, 23, 47, 53, 19, 35, 44, 23, 34, 81, 21, 25] |
| 67 | LABEL | labels | -100 | 100.00 % | 32 | 32 | N/A |
| 68 | EMBED | input_ids | 1 | 31.62 % | 1295 | 4096 | [75, 36, 48, 20, 38, 21, 49, 54, 38, 41, 26, 28, 80, 45, 48, 16, 22, 41, 34, 28, 37, 16, 74, 63, 62, 34, 22, 45, 23, 27, 37, 67] |
| 69 | LABEL | labels | -100 | 100.00 % | 32 | 32 | N/A |
<<<
```
**High Embed Density, Low Label Density Case - masked language model**
`
python -m torch.distributed.launch --nproc_per_node=4
examples/onnxruntime/training/language-modeling/run_mlm.py
--model_name_or_path bert-base-uncased --dataset_name wikitext
--dataset_config_name wikitext-2-raw-v1 --num_train_epochs 10
--per_device_train_batch_size 8 --per_device_eval_batch_size 8
--do_train --do_eval --overwrite_output_dir --output_dir ./outputs/
--seed 1137 --fp16 --report_to none --optim adamw_ort_fused
`
```
>>>Valid token/label density (e.g. valid/total) in passing 10 steps:
| STEP | INPUT TYPE | INPUT NAME | PAD IDX | DENSITY | VALID TOKENS | TOTAL TOKENS | VALID TOKENS/BATCH |
| 710 | EMBED | input_ids | 0 | 100.00 % | 4096 | 4096 | [512, 512, 512, 512, 512, 512, 512, 512] |
| 711 | LABEL | labels | -100 | 13.77 % | 564 | 4096 | N/A |
| 712 | EMBED | input_ids | 0 | 100.00 % | 4096 | 4096 | [512, 512, 512, 512, 512, 512, 512, 512] |
| 713 | LABEL | labels | -100 | 14.48 % | 593 | 4096 | N/A |
| 714 | EMBED | input_ids | 0 | 100.00 % | 4096 | 4096 | [512, 512, 512, 512, 512, 512, 512, 512] |
| 715 | LABEL | labels | -100 | 14.18 % | 581 | 4096 | N/A |
| 716 | EMBED | input_ids | 0 | 100.00 % | 4096 | 4096 | [512, 512, 512, 512, 512, 512, 512, 512] |
| 717 | LABEL | labels | -100 | 14.53 % | 595 | 4096 | N/A |
| 718 | EMBED | input_ids | 0 | 100.00 % | 4096 | 4096 | [512, 512, 512, 512, 512, 512, 512, 512] |
| 719 | LABEL | labels | -100 | 15.31 % | 627 | 4096 | N/A |
<<<
```
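
The DENSITY column in the logs above is VALID TOKENS divided by TOTAL TOKENS, where a token counts as valid when it differs from PAD IDX (for embedding inputs) or from the ignore index -100 (for labels); for example, 564 / 4096 ≈ 13.77 %. The minimal NumPy sketch below reproduces that arithmetic for one batch. The `inspect_density` helper and the toy tensors are hypothetical and only illustrate how the columns relate; this is not ORTModule's inspector implementation.

```python
import numpy as np


def inspect_density(tensor: np.ndarray, ignore_value: int):
    """Return (density, valid_tokens, total_tokens, valid_tokens_per_row).

    Hypothetical helper: a token is "valid" when it differs from `ignore_value`
    (the padding index for embedding inputs, or -100 for labels).
    """
    mask = tensor != ignore_value
    valid = int(np.count_nonzero(mask))
    total = int(tensor.size)
    # Per-row counts are only computed for 2-D [batch, seq_len] tensors.
    per_row = np.count_nonzero(mask, axis=1).tolist() if tensor.ndim == 2 else None
    return valid / total, valid, total, per_row


# Toy batch of 2 sequences padded to length 5 with pad index 0.
input_ids = np.array([[101, 7, 8, 0, 0],
                      [101, 9, 0, 0, 0]])
print(inspect_density(input_ids, ignore_value=0))  # (0.5, 5, 10, [3, 2])

# Toy MLM-style labels: only masked positions carry a real label.
labels = np.array([[-100, 7, -100, -100, -100],
                   [-100, -100, -100, 9, -100]])
print(inspect_density(labels, ignore_value=-100))  # (0.2, 2, 10, [1, 1])
```

Note that the logs print per-row counts only for EMBED inputs (LABEL rows show N/A), while the sketch computes them for any 2-D tensor.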
#### Next Step
Next, let's see how to leverage this data sparsity for performance improvements. Optimizations planned as part of compute optimizer wave 2:
- Reducing loss computation FLOPs.
- Flattening/unflattening embedding tokens to save compute FLOPs.
```bash
export ORTMODULE_PRINT_INPUT_DENSITY=1 # Enable
export ORTMODULE_PRINT_INPUT_DENSITY=0 # Disable
```
### Optimize computation orders (#13672)
In RoBERTa/ELECTRA, when the classification head (e.g. `RobertaClassificationHead`) is used, the features are sliced along the `sequence_length` dimension (axis 1), and the loss computation depends only on this sliced data. Before slicing, the shape is `[batch, sequence_length, hidden]`; after slicing, it becomes `[batch, hidden]`.
This gives us an opportunity to move the slicing as early as possible in the graph, by passing it upstream through simple elementwise ops (like Add/Div), through LayerNorm/Softmax (if their reduction axis comes after the slicing axis), or even through the left operand of a MatMul (as long as the slicing does not affect its last dimension).
Operators like Reshape/Transpose need special handling: Reshape has its target shape specified as data (which must be updated after slicing), and Transpose has a `perm` specified, which requires the input rank to remain unchanged. For these operators, we keep the original rank by leaving the sliced dim with size 1, and apply a Squeeze after the computation completes.
```python
import torch
from torch import nn


class RobertaClassificationHead(nn.Module):
    """Head for sentence-level classification tasks."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, features, **kwargs):
        # features: [batch, sequence_length, hidden] -> x: [batch, hidden]
        x = features[:, 0, :]  # take <s> token (equiv. to [CLS])
        x = self.dropout(x)
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.dropout(x)
        x = self.out_proj(x)
        return x
```
Source files (Hugging Face Transformers): `src/transformers/models/roberta/modeling_roberta.py`, `src/transformers/models/electra/modeling_electra.py`.
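
The reordering argument above can be checked numerically. The sketch below is a NumPy illustration, not the ORT graph transformer: it compares running Add + LayerNorm + MatMul on the full `[batch, sequence_length, hidden]` tensor and slicing at the end against slicing first, and it shows the keep-rank-then-Squeeze trick for a Transpose with a fixed `perm`. The shapes, the simplified LayerNorm (no scale/bias), and the random inputs are assumptions chosen for illustration.

```python
import numpy as np


def layer_norm(x: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    # Simplified LayerNorm (no scale/bias) reducing over the last (hidden) axis,
    # which comes after the sliced axis 1, so slicing commutes with it.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


rng = np.random.default_rng(0)
batch, seq_len, hidden = 2, 5, 4
features = rng.standard_normal((batch, seq_len, hidden))
bias = rng.standard_normal(hidden)
weight = rng.standard_normal((hidden, hidden))

# Original order: elementwise Add, LayerNorm and MatMul run on all seq_len positions,
# then the result is sliced at axis 1.
late = (layer_norm(features + bias) @ weight)[:, 0, :]
# Reordered: slice first, so the same ops only process [batch, hidden].
early = layer_norm(features[:, 0, :] + bias) @ weight
print(late.shape, early.shape, np.allclose(late, early))  # (2, 4) (2, 4) True

# Rank-sensitive op (Transpose with perm=(0, 2, 1)): slice with 0:1 to keep rank 3,
# run the op unchanged, then Squeeze the size-1 dim (now at axis 2).
perm = (0, 2, 1)
late_t = features.transpose(perm)[:, :, 0]
early_t = np.squeeze(features[:, 0:1, :].transpose(perm), axis=2)
print(np.allclose(late_t, early_t))  # True
```

The reordered path performs the Add, LayerNorm, and MatMul on `batch` rows instead of `batch * seq_len` rows, which is where the compute saving comes from.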
#### Benchmark
A simple benchmark shows RoBERTa training latency dropping from 208 ms to 199 ms, a reduction of roughly 4.3% (9 ms).
More comprehensive tests are on the way.
#### ORTMODULE_PRINT_MEMORY_STATS
- **Feature Area**: *ORTMODULE/RuntimeInspector*
- **Description**: By default, this is disabled. This env var can be used for printing the memory inspection results to standard output.
```bash
export ORTMODULE_PRINT_MEMORY_STATS=1 # Enable
export ORTMODULE_PRINT_MEMORY_STATS=0 # Disable
```
#### ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER
- **Feature Area**: *ORTMODULE/Optimizations*
- **Description**: By default, this is disabled. This env var can be used for enabling or disabling the embedding input
data sparsity based performance optimizations.
```bash
export ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER=1 # Enable
export ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER=0 # Disable
```
#### ORTMODULE_CACHE_DIR
- **Feature Area**: *ORTMODULE/RuntimeOptions*
- **Description**: By default, this is disabled. This env var can be used to cache the exported model for future runs. This optimization is intended to reduce experimentation time by reusing the PyTorch-to-ONNX exported model architecture when available.
```bash
export ORTMODULE_CACHE_DIR="/path/to/cache_dir" # Enable
unset ORTMODULE_CACHE_DIR # Disable
```
### 2.2 Memory Optimization
Q: *Want to run a bigger batch size?*
Q: *The model training hits OOM, even with minimum required batch size?*
Check [Memory Optimizer for ONNX Runtime Training ](Memory_Optimizer.md ) for how to leverage ORT's recomputation techniques.
## 3. Use `FusedAdam` to Accelerate Parameter Update
Parameter update is done by optimizers (for example AdamW) with many elementwise operations. `FusedAdam` launches the elementwise update kernels with multi-tensor apply, allowing a batch of gradients to be applied to their corresponding parameters in each kernel launch.
Here is a sample switch from the torch `AdamW` optimizer to `FusedAdam`.
```diff
model = build_model()
- optimizer = AdamW(model.parameters(), lr=1)
+ from onnxruntime.training.optim import FusedAdam
+ optimizer = FusedAdam(model.parameters(), lr=1)
```
Check [FusedAdam implementation ](../orttraining/orttraining/python/training/optim/fused_adam.py ) for more details.
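
To make the multi-tensor apply idea concrete, here is a minimal NumPy sketch of one AdamW step applied to a whole list of tensors in a single vectorized pass: the tensors are flattened into one buffer, the elementwise update runs once, and the results are scattered back. It is a conceptual illustration under stated assumptions (a standard AdamW update with bias correction; the function name `adamw_multi_tensor_step` and its defaults are hypothetical), not the CUDA kernel behind `FusedAdam`.

```python
import numpy as np


def adamw_multi_tensor_step(params, grads, exp_avgs, exp_avg_sqs, step,
                            lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.01):
    """Hypothetical sketch: one AdamW update over many tensors in one vectorized pass.

    The tensors are concatenated into flat buffers so the elementwise math runs once,
    mimicking how a multi-tensor-apply kernel handles a batch of tensors per launch.
    """
    sizes = [x.size for x in params]
    p = np.concatenate([x.ravel() for x in params])
    g = np.concatenate([x.ravel() for x in grads])
    m = np.concatenate([x.ravel() for x in exp_avgs])
    v = np.concatenate([x.ravel() for x in exp_avg_sqs])

    p *= 1.0 - lr * weight_decay              # decoupled weight decay (AdamW)
    m = beta1 * m + (1.0 - beta1) * g         # first moment estimate
    v = beta2 * v + (1.0 - beta2) * g * g     # second moment estimate
    m_hat = m / (1.0 - beta1 ** step)         # bias correction
    v_hat = v / (1.0 - beta2 ** step)
    p -= lr * m_hat / (np.sqrt(v_hat) + eps)

    # Scatter the flat buffers back into the per-tensor arrays, in place.
    offsets = np.cumsum([0] + sizes)
    for i, (a, b) in enumerate(zip(offsets[:-1], offsets[1:])):
        params[i][...] = p[a:b].reshape(params[i].shape)
        exp_avgs[i][...] = m[a:b].reshape(exp_avgs[i].shape)
        exp_avg_sqs[i][...] = v[a:b].reshape(exp_avg_sqs[i].shape)


# Usage: three differently shaped "parameters" updated in a single call.
rng = np.random.default_rng(0)
shapes = [(3, 4), (5,), (2, 2, 2)]
params = [rng.standard_normal(s) for s in shapes]
grads = [rng.standard_normal(s) for s in shapes]
exp_avgs = [np.zeros(s) for s in shapes]
exp_avg_sqs = [np.zeros(s) for s in shapes]
adamw_multi_tensor_step(params, grads, exp_avgs, exp_avg_sqs, step=1, lr=0.1)
```

A real fused kernel avoids the concatenate/scatter copies by passing pointers to all tensors into one launch; the flat buffer here only stands in for that batching.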
## 4. Use `FP16_Optimizer` to Complement DeepSpeed/APEX
If user models utilize the DeepSpeed or Apex libraries, ORT's `FP16_Optimizer` can be used to address some of the inefficiencies they introduce.
Use `FP16_Optimizer` with DeepSpeed ZeRO Optimizer:
```diff
optimizer = AdamW(model.parameters(), lr=1)
model, optimizer, _, lr_scheduler = deepspeed.initialize(
model=model,
optimizer=optimizer,
args=args,
lr_scheduler=lr_scheduler,
mpu=mpu,
dist_init_required=False)
+ from onnxruntime.training.optim.fp16_optimizer import FP16_Optimizer
+ optimizer = FP16_Optimizer(optimizer)
```
Use `FP16_Optimizer` with Apex Optimizer:
```diff
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
model, optimizer = amp.initialize(model, optimizer, opt_level="O2")
+ from onnxruntime.training.optim.fp16_optimizer import FP16_Optimizer as ORT_FP16_Optimizer
+ optimizer = ORT_FP16_Optimizer(optimizer)
```
Check [FP16_Optimizer implementation ](../orttraining/orttraining/python/training/optim/fp16_optimizer.py ) for more details.
## 5. Putting All Together `ORTModule` + `FusedAdam` + `FP16_Optimizer`
```diff
model = build_model()
+ from onnxruntime.training.ortmodule import ORTModule
+ model = ORTModule(model)
- optimizer = AdamW(model.parameters(), lr=1)
+ from onnxruntime.training.optim import FusedAdam
+ optimizer = FusedAdam(model.parameters(), lr=1)
model, optimizer, _, lr_scheduler = deepspeed.initialize(
model=model,
optimizer=optimizer,
args=args,
lr_scheduler=lr_scheduler,
mpu=mpu,
dist_init_required=False)
+ from onnxruntime.training.optim.fp16_optimizer import FP16_Optimizer
+ optimizer = FP16_Optimizer(optimizer)
```
## 6. Use OpenAI Triton to Compute ONNX Sub-graph
`ORTModule` provides a way to switch to OpenAI Triton for executing some Ops to further accelerate training.
### 6.1 Environment Variables
#### ORTMODULE_USE_TRITON
- **Feature Area**: *ORTMODULE/TritonOp*
- **Description**: By default, this is disabled. This env var can be used for enabling Triton optimization.
```bash
export ORTMODULE_USE_TRITON=1
```
#### ORTMODULE_ENABLE_TUNING
- **Feature Area**: *ORTMODULE/TritonOp*
- **Description**: By default, this is disabled. This env var can be used for enabling online Op tuning for those Ops that have multiple implementations on the target execution provider (EP).
```bash
export ORTMODULE_ENABLE_TUNING=1
```
#### ORTMODULE_MAX_TUNING_DURATION_MS
- **Feature Area**: *ORTMODULE/TritonOp*
- **Description**: When `ORTMODULE_ENABLE_TUNING` is enabled, this env var can be used to set the maximum tuning duration in milliseconds to avoid long tuning times.
```bash
export ORTMODULE_MAX_TUNING_DURATION_MS=9999
```
#### ORTMODULE_TUNING_RESULTS_PATH
- **Feature Area**: *ORTMODULE/TritonOp*
- **Description**: When `ORTMODULE_ENABLE_TUNING` is enabled, this env var can be used to specify where the online Op tuning results will be saved for later use. By default, the results are not saved. When `ORTMODULE_ENABLE_TUNING` is NOT enabled, this env var can be used to specify where Op tuning results are fetched from as offline tuning results.
```bash
export ORTMODULE_TUNING_RESULTS_PATH=/tmp/tuning_results
```
#### ORTMODULE_TRITON_DEBUG
- **Feature Area**: *ORTMODULE/TritonOp*
- **Description**: By default, this is disabled. This env var can be used for enabling Triton debug mode. All original and processed sub-graphs and the corresponding generated Triton code will be saved into a `triton_debug` folder under the working directory.
```bash
export ORTMODULE_TRITON_DEBUG=1
```
## 7. One More Thing - `LoadBalancingDistributedBatchSampler`
`LoadBalancingDistributedBatchSampler` balances the data load across workers based on the sample's complexity.
This is useful in scenarios like speech and NLP, where each batch has variable length and distributed training suffers from the **straggler problem**. In such scenarios, the complexity function could be defined to return the length of the input sample sequence. The usage is similar to `torch.utils.data.DistributedSampler`, where each process loads a subset of the original dataset that is exclusive to it.
A sample is shown below:
```python
import torch
from onnxruntime.training.utils.data import LoadBalancingDistributedSampler, \
    LoadBalancingDistributedBatchSampler

# `dataset`, `complexity_fn`, `batch_fn`, `train`, `start_epoch` and `n_epochs` are user-provided.
sampler = LoadBalancingDistributedSampler(dataset, complexity_fn=complexity_fn)
batch_sampler = LoadBalancingDistributedBatchSampler(sampler, batch_fn=batch_fn)
loader = torch.utils.data.DataLoader(dataset, batch_sampler=batch_sampler)
for epoch in range(start_epoch, n_epochs):
    batch_sampler.set_epoch(epoch)
    train(loader)
```
Check [LoadBalancingDistributedBatchSampler implementation ](../orttraining/orttraining/python/training/utils/data/sampler.py ) for more details.
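
The goal of balancing by complexity can be illustrated with a generic greedy heuristic: visit samples from most to least complex and give each one to the currently least-loaded worker. This is a swapped-in textbook technique shown only to clarify the idea; it is not necessarily the algorithm `LoadBalancingDistributedSampler` uses, and `balance_across_workers` and the length-based `complexity_fn` below are hypothetical.

```python
import heapq
from typing import Callable, List, Sequence


def balance_across_workers(samples: Sequence, complexity_fn: Callable, world_size: int) -> List[List[int]]:
    """Hypothetical sketch: greedily assign sample indices so per-worker total complexity is similar."""
    # Visit the most complex samples first (largest-first greedy heuristic).
    order = sorted(range(len(samples)), key=lambda i: complexity_fn(samples[i]), reverse=True)
    # Min-heap of (current total complexity, worker rank).
    heap = [(0, rank) for rank in range(world_size)]
    heapq.heapify(heap)
    assignment: List[List[int]] = [[] for _ in range(world_size)]
    for i in order:
        load, rank = heapq.heappop(heap)  # least-loaded worker so far
        assignment[rank].append(i)
        heapq.heappush(heap, (load + complexity_fn(samples[i]), rank))
    return assignment


# Hypothetical complexity function: the sequence length of a sample.
def complexity_fn(sample):
    return len(sample)


# Six toy samples whose complexity is their sequence length.
samples = [[0] * n for n in (9, 7, 6, 5, 4, 3)]
assignment = balance_across_workers(samples, complexity_fn, world_size=2)
print(assignment)  # [[0, 3, 5], [1, 2, 4]]
print([sum(complexity_fn(samples[i]) for i in shard) for shard in assignment])  # [17, 17]
```

With equal total complexity per worker, no rank has to wait long for a slower peer at each synchronization point, which is the straggler effect described above.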