### Introduce padding inspector in ORTModule
In some Transformer-based LLM training recipes, high data sparsity is observed due to 1) token padding (to the max sequence length) and 2) labels containing many `ignore_index` values, which are excluded from the loss calculation.

This PR introduces a switch to enable data sparsity inspection, which:

1. In the short term, can inform training users to use techniques like dynamic batching to amortize the issue.
2. In the medium and longer term, helps us (the training team) better understand what our training customers' models look like from the perspective of data sparsity (and potentially motivates runtime improvements).
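
Dynamic batching reduces padding waste by grouping samples of similar length and capping each batch by a token budget instead of a fixed sample count. Below is a minimal pure-Python sketch of one such scheme, assuming each batch is padded to its own longest sample; `token_budget_batches` and `padded_density` are hypothetical helpers, not ORT or Hugging Face APIs, and the PR does not prescribe any particular scheme:

```python
from typing import List, Sequence


def token_budget_batches(lengths: Sequence[int], max_tokens: int) -> List[List[int]]:
    """Group sample indices into batches whose padded size stays within max_tokens.

    Hypothetical helper. A batch is padded to its longest sample, so its cost is
    len(batch) * longest. A single sample longer than max_tokens still gets its
    own (over-budget) batch, since it cannot be split here.
    """
    order = sorted(range(len(lengths)), key=lambda i: lengths[i])  # shortest first
    batches: List[List[int]] = []
    current: List[int] = []
    longest = 0
    for idx in order:
        candidate_longest = max(longest, lengths[idx])
        if current and candidate_longest * (len(current) + 1) > max_tokens:
            batches.append(current)  # adding idx would exceed the budget
            current, candidate_longest = [], lengths[idx]
        current.append(idx)
        longest = candidate_longest
    if current:
        batches.append(current)
    return batches


def padded_density(lengths: Sequence[int], batches: List[List[int]]) -> float:
    """Valid tokens divided by total tokens after padding each batch to its longest sample."""
    valid = sum(lengths[i] for batch in batches for i in batch)
    total = sum(len(batch) * max(lengths[i] for i in batch) for batch in batches)
    return valid / total


# Valid token counts of six samples, taken from the first EMBED row of the table below.
lengths = [50, 81, 35, 11, 29, 36]
batches = token_budget_batches(lengths, max_tokens=128)
print(batches)                                         # [[3, 4, 2], [5, 0], [1]]
print(round(padded_density(lengths, batches), 4))      # 0.8462
print(round(sum(lengths) / (len(lengths) * 128), 4))   # 0.3151 when padding every sample to 128
```

Sorting by length is what keeps padding low; the trade-off is less randomness in batch composition, so batches are usually shuffled afterwards.
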

Here are examples of how data sparsity differs across user models that share a similar Transformer architecture and the same kinds of training inputs (`input_ids` and `labels`).

**Low Embed Density, High Label Density Case - Sentence Classification**

```bash
python -m torch.distributed.launch --nproc_per_node=4 \
    examples/onnxruntime/training/text-classification/run_glue.py \
    --model_name_or_path roberta-large-openai-detector --task_name mnli \
    --do_train --do_eval --max_seq_length 128 --per_device_train_batch_size 32 \
    --learning_rate 2e-5 --num_train_epochs 3 --overwrite_output_dir \
    --output_dir ./outputs/ --per_device_eval_batch_size 32 --seed 1137 \
    --fp16 True --ignore_mismatched_sizes True --optim adamw_ort_fused
```

```
>>>Valid token/label density (e.g. valid/total) in passing 10 steps:
| STEP | INPUT TYPE | INPUT NAME | PAD IDX | DENSITY | VALID TOKENS | TOTAL TOKENS | VALID TOKENS/BATCH |
| 60 | EMBED | input_ids | 1 | 35.21 % | 1442 | 4096 | [50, 81, 35, 11, 29, 36, 66, 19, 40, 22, 21, 42, 17, 37, 40, 41, 26, 58, 38, 54, 41, 73, 48, 57, 50, 51, 49, 85, 48, 36, 79, 62] |
| 61 | LABEL | labels | -100 | 100.00 % | 32 | 32 | N/A |
| 62 | EMBED | input_ids | 1 | 30.00 % | 1229 | 4096 | [36, 73, 13, 47, 27, 33, 53, 25, 51, 28, 36, 42, 42, 32, 39, 52, 27, 13, 31, 66, 42, 45, 52, 45, 58, 42, 37, 66, 12, 18, 29, 17] |
| 63 | LABEL | labels | -100 | 100.00 % | 32 | 32 | N/A |
| 64 | EMBED | input_ids | 1 | 26.73 % | 1095 | 4096 | [37, 28, 20, 53, 16, 20, 44, 52, 27, 28, 16, 19, 16, 24, 63, 31, 24, 42, 33, 41, 44, 60, 44, 67, 54, 30, 20, 19, 33, 23, 24, 43] |
| 65 | LABEL | labels | -100 | 100.00 % | 32 | 32 | N/A |
| 66 | EMBED | input_ids | 1 | 30.03 % | 1230 | 4096 | [22, 46, 36, 41, 46, 43, 26, 50, 60, 16, 24, 42, 56, 35, 35, 59, 29, 39, 34, 20, 66, 23, 47, 53, 19, 35, 44, 23, 34, 81, 21, 25] |
| 67 | LABEL | labels | -100 | 100.00 % | 32 | 32 | N/A |
| 68 | EMBED | input_ids | 1 | 31.62 % | 1295 | 4096 | [75, 36, 48, 20, 38, 21, 49, 54, 38, 41, 26, 28, 80, 45, 48, 16, 22, 41, 34, 28, 37, 16, 74, 63, 62, 34, 22, 45, 23, 27, 37, 67] |
| 69 | LABEL | labels | -100 | 100.00 % | 32 | 32 | N/A |
<<<
```

**High Embed Density, Low Label Density Case - Masked Language Model**

```bash
python -m torch.distributed.launch --nproc_per_node=4 \
    examples/onnxruntime/training/language-modeling/run_mlm.py \
    --model_name_or_path bert-base-uncased --dataset_name wikitext \
    --dataset_config_name wikitext-2-raw-v1 --num_train_epochs 10 \
    --per_device_train_batch_size 8 --per_device_eval_batch_size 8 \
    --do_train --do_eval --overwrite_output_dir --output_dir ./outputs/ \
    --seed 1137 --fp16 --report_to none --optim adamw_ort_fused
```

```
>>>Valid token/label density (e.g. valid/total) in passing 10 steps:
| STEP | INPUT TYPE | INPUT NAME | PAD IDX | DENSITY | VALID TOKENS | TOTAL TOKENS | VALID TOKENS/BATCH |
| 710 | EMBED | input_ids | 0 | 100.00 % | 4096 | 4096 | [512, 512, 512, 512, 512, 512, 512, 512] |
| 711 | LABEL | labels | -100 | 13.77 % | 564 | 4096 | N/A |
| 712 | EMBED | input_ids | 0 | 100.00 % | 4096 | 4096 | [512, 512, 512, 512, 512, 512, 512, 512] |
| 713 | LABEL | labels | -100 | 14.48 % | 593 | 4096 | N/A |
| 714 | EMBED | input_ids | 0 | 100.00 % | 4096 | 4096 | [512, 512, 512, 512, 512, 512, 512, 512] |
| 715 | LABEL | labels | -100 | 14.18 % | 581 | 4096 | N/A |
| 716 | EMBED | input_ids | 0 | 100.00 % | 4096 | 4096 | [512, 512, 512, 512, 512, 512, 512, 512] |
| 717 | LABEL | labels | -100 | 14.53 % | 595 | 4096 | N/A |
| 718 | EMBED | input_ids | 0 | 100.00 % | 4096 | 4096 | [512, 512, 512, 512, 512, 512, 512, 512] |
| 719 | LABEL | labels | -100 | 15.31 % | 627 | 4096 | N/A |
<<<
```
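
In both tables, DENSITY is VALID TOKENS / TOTAL TOKENS, where a position counts as valid when its value differs from PAD IDX (the padding token id for `input_ids`, or the ignore index `-100` for `labels`); for example, 1442 / 4096 = 35.21 % in step 60. A minimal pure-Python sketch of that arithmetic for a 2-D batch follows; `inspect_density` and the token ids are made up for illustration, and this is not the inspector's actual implementation:

```python
from typing import List, Sequence, Tuple


def inspect_density(batch: Sequence[Sequence[int]], pad_idx: int) -> Tuple[float, int, int, List[int]]:
    """Return (density in %, valid tokens, total tokens, valid tokens per sample).

    Hypothetical helper: a position is valid when its value is not pad_idx.
    """
    per_sample = [sum(1 for tok in row if tok != pad_idx) for row in batch]
    valid = sum(per_sample)
    total = sum(len(row) for row in batch)
    return 100.0 * valid / total, valid, total, per_sample


# Made-up RoBERTa-style input_ids (padding id 1) and MLM labels (ignore index -100).
input_ids = [[0, 713, 16, 2, 1, 1, 1, 1],
             [0, 31414, 2, 1, 1, 1, 1, 1]]
labels = [[-100, 713, -100, -100, -100, -100, -100, -100],
          [-100, -100, 2, -100, -100, -100, -100, -100]]

print(inspect_density(input_ids, pad_idx=1))   # (43.75, 7, 16, [4, 3])
print(inspect_density(labels, pad_idx=-100))   # (12.5, 2, 16, [1, 1])
```
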

#### Next Step

Next, we will look at how to leverage data sparsity for performance improvements. Optimizations underway as part of compute optimizer wave 2:

- Loss computation FLOPs reduction.
- Flattening/unflattening embedding tokens to save compute FLOPs.

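
The idea behind loss computation FLOPs reduction can be illustrated with label sparsity: positions whose label equals the ignore index do not contribute to the loss, so they can be dropped before the loss is computed. The pure-Python sketch below uses hypothetical helpers (not ORT's compute optimizer) to check that evaluating only the valid rows gives the same mean cross-entropy as evaluating every row and masking afterwards. In a real model, most of the savings would come from gathering the valid rows before the large vocabulary projection, which this sketch does not model.

```python
import math
from typing import List, Sequence, Tuple

IGNORE_INDEX = -100  # PyTorch's default ignore_index for cross-entropy


def cross_entropy_rows(logits: Sequence[Sequence[float]], labels: Sequence[int]) -> List[float]:
    """Per-row cross-entropy using a numerically stable log-sum-exp."""
    losses = []
    for row, label in zip(logits, labels):
        m = max(row)
        log_sum_exp = m + math.log(sum(math.exp(x - m) for x in row))
        losses.append(log_sum_exp - row[label])
    return losses


def dense_masked_loss(logits, labels) -> Tuple[float, int]:
    """Baseline: evaluate every row, then average only rows with a valid label."""
    safe_labels = [0 if y == IGNORE_INDEX else y for y in labels]  # placeholder class for ignored rows
    per_row = cross_entropy_rows(logits, safe_labels)
    kept = [loss for loss, y in zip(per_row, labels) if y != IGNORE_INDEX]
    return sum(kept) / len(kept), len(per_row)


def sparse_loss(logits, labels) -> Tuple[float, int]:
    """Gather valid rows first, then evaluate only those (assumes at least one valid label)."""
    valid = [i for i, y in enumerate(labels) if y != IGNORE_INDEX]
    per_row = cross_entropy_rows([logits[i] for i in valid], [labels[i] for i in valid])
    return sum(per_row) / len(per_row), len(per_row)


logits = [[2.0, 0.5, -1.0], [0.1, 0.2, 0.3], [1.0, 1.0, 1.0], [-0.5, 3.0, 0.0]]
labels = [0, IGNORE_INDEX, IGNORE_INDEX, 1]
print(dense_masked_loss(logits, labels))  # evaluates 4 rows
print(sparse_loss(logits, labels))        # same loss, evaluates 2 rows
```
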
# ONNX Runtime Training Guidelines

## 1. Installation and Configuration

Note: this mainly demonstrates setup steps for development; check [Torch-ORT](https://github.com/pytorch/ort) for the end-user setup experience.

Refer to [https://onnxruntime.ai/](https://onnxruntime.ai/) to download the training wheel, or build from source:

```bash
export CUDA_HOME=/usr/local/cuda
export CUDNN_HOME=/usr/local/cuda
export CUDACXX=$CUDA_HOME/bin/nvcc

./build.sh --config RelWithDebInfo --use_cuda --enable_training --build_wheel --skip_tests --cuda_version=11.6 --parallel 8 --use_mpi
```

Install the Python wheel.

Configure the ORTModule torch cpp extensions (**avoid** doing this in the ORT code *repo root directory*):

```bash
python -m onnxruntime.training.ortmodule.torch_cpp_extensions.install
```

## 2. Use `ORTModule` to Accelerate Forward/Backward

Plug your `torch.nn.Module` model into `ORTModule` to leverage the ONNX Runtime fast training backend.

Sample usage:

```diff
model = build_model()

+ from onnxruntime.training.ortmodule import ORTModule
+ model = ORTModule(model)
```

> It is strongly recommended to wrap the model with `ORTModule` before applying other module wrappers (for example, DeepSpeed, `torch.nn.parallel.DistributedDataParallel`, etc.), as this order is validated in more scenarios.

> Note that `ORTModule` is **NOT** compatible with `torch.nn.DataParallel` (whose use PyTorch itself does not recommend). Please use `torch.nn.parallel.DistributedDataParallel` instead.

More options for **developers**:

```diff
model = build_model()

+ from onnxruntime.training.ortmodule import ORTModule, DebugOptions, LogLevel
+ model = ORTModule(model, DebugOptions(save_onnx=True, log_level=LogLevel.VERBOSE, onnx_prefix="model_name"))
```

Check [DebugOptions implementation](../orttraining/orttraining/python/training/ortmodule/debug_options.py) for more details.

### 2.1 Environment Variables

`ORTModule` provides environment variables targeting different use cases.
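
These variables are normally exported in the shell. When a launcher script cannot easily change the environment, they can also be set from Python. A minimal sketch, assuming it is enough to set them before `onnxruntime.training.ortmodule` is imported and `ORTModule` is constructed (some variables may be read at import time, so setting them as early as possible is the safe choice); `apply_ortmodule_env` is a hypothetical helper, and the values are only examples taken from the sections below:

```python
import os
from typing import Dict, Mapping

# Example values; see each variable's section for its meaning.
ORTMODULE_SETTINGS = {
    "ORTMODULE_ONNX_OPSET_VERSION": "14",
    "ORTMODULE_FALLBACK_POLICY": "FALLBACK_DISABLE",
    "ORTMODULE_LOG_LEVEL": "INFO",
}


def apply_ortmodule_env(settings: Mapping[str, str], override: bool = False) -> Dict[str, str]:
    """Hypothetical helper: set each variable unless the shell already exported it.

    Call this before importing onnxruntime.training.ortmodule (assumption: the
    variables are read at import time or when ORTModule is constructed).
    """
    effective = {}
    for name, value in settings.items():
        if override or name not in os.environ:
            os.environ[name] = value
        effective[name] = os.environ[name]
    return effective


print(apply_ortmodule_env(ORTMODULE_SETTINGS))
```

Not overriding by default lets a value exported in the shell win over the script's defaults.
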

#### ORTMODULE_ONNX_OPSET_VERSION

- **Feature Area**: *ORTMODULE/ONNXOPSET*
- **Description**: By default, the ONNX OPSET version used is updated periodically as new ONNX Runtime versions are released. Some customers prefer to stick to a fixed OPSET for which both performance and accuracy are well validated; this env variable can be used to control that.

```bash
export ORTMODULE_ONNX_OPSET_VERSION=14
```

#### ORTMODULE_FALLBACK_POLICY

- **Feature Area**: *ORTMODULE/FallbackToPytorch*
- **Description**: By default, if `ORTModule` fails to run the model using the ONNX Runtime backend, it falls back to PyTorch to continue training. When developers are optimizing models or benchmarking, they may want to explicitly require the ORT backend to run the model. To disable the fallback:

```bash
export ORTMODULE_FALLBACK_POLICY="FALLBACK_DISABLE"
```

#### ORTMODULE_LOG_LEVEL

- **Feature Area**: *ORTMODULE/DebugOptions*
- **Description**: Configures the `ORTModule` log level. Defaults to `LogLevel.WARNING`; can be set to one of "VERBOSE", "INFO", "WARNING", "ERROR", "FATAL". The environment variable takes precedence if `DebugOptions` also sets `log_level`.
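
The precedence rule above can be expressed as a small resolution function. This is a hypothetical sketch of the documented behavior (environment variable first, then `DebugOptions.log_level`, then the `WARNING` default), using level names as strings instead of the real `LogLevel` enum; it is not ORT's code, and it rejects names outside the documented list rather than guessing how ORT treats them:

```python
import os
from typing import Mapping, Optional

DOCUMENTED_LEVELS = ("VERBOSE", "INFO", "WARNING", "ERROR", "FATAL")


def resolve_log_level(debug_options_level: Optional[str] = None,
                      environ: Mapping[str, str] = os.environ) -> str:
    """Hypothetical: ORTMODULE_LOG_LEVEL wins over DebugOptions, which wins over the default."""
    env_level = environ.get("ORTMODULE_LOG_LEVEL")
    if env_level is not None:
        if env_level not in DOCUMENTED_LEVELS:
            raise ValueError(f"unsupported ORTMODULE_LOG_LEVEL: {env_level!r}")
        return env_level
    if debug_options_level is not None:
        return debug_options_level
    return "WARNING"


print(resolve_log_level("VERBOSE", environ={}))                                # VERBOSE
print(resolve_log_level("VERBOSE", environ={"ORTMODULE_LOG_LEVEL": "ERROR"}))  # ERROR
print(resolve_log_level(environ={}))                                           # WARNING
```
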

#### ORTMODULE_SAVE_ONNX_PATH

- **Feature Area**: *ORTMODULE/DebugOptions*
- **Description**: Configures `ORTModule` to save ONNX models (saving defaults to False). By default, the ONNX models are written to the current working directory. To change the output directory, set the environment variable `ORTMODULE_SAVE_ONNX_PATH` to the destination directory path.

#### ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT

- **Feature Area**: *ORTMODULE/PythonOp (torch.autograd.Function)*
- **Description**: By default, `ORTModule` fails with an exception when handling PythonOp export for some `autograd.Function`s (one example is torch's `CheckpointFunction`). Set this env variable to `1` to explicitly allow it.

```bash
export ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT=1
```

> Take `torch.utils.checkpoint.CheckpointFunction` as an example: if it is exported as a PythonOp, the checkpointed computation may be executed by PyTorch, not ORT. This matters especially for big models such as GPT-2, where every few layers are wrapped for re-computation, so a large number of computations are done by PyTorch. Currently, a failure is reported to notify users that `ORTModule` may have fewer opportunities to optimize further.

> On the other hand, if the wrapped computation graph is small, it is reasonable to allow it.
> Overall, users should be aware that the ORT performance boost might be trivial when they explicitly allow it.

#### ORTMODULE_DISABLE_CUSTOM_AUTOGRAD_SUPPORT

- **Feature Area**: *ORTMODULE/PythonOp (torch.autograd.Function)*
- **Description**: By default, all `torch.autograd.Function` classes are exported to ORT PythonOp. In some cases, you might consider disabling this. For example, if you have confirmed that the computations defined in those `torch.autograd.Function` classes can be exported inline by PyTorch, and that it is safe to train with the inline-exported ONNX graph, you can disable it; as a result, ORT has more opportunities to optimize.

```bash
export ORTMODULE_DISABLE_CUSTOM_AUTOGRAD_SUPPORT=1
```

Alternatively, disable it without using the environment variable:

```python
from onnxruntime.training.ortmodule._custom_autograd_function import enable_custom_autograd_support

enable_custom_autograd_support(False)
```

#### ORTMODULE_SKIPPED_AUTOGRAD_FUNCTIONS

- **Feature Area**: *ORTMODULE/PythonOp (torch.autograd.Function)*
- **Description**: By default, this is empty. A user model's setup may depend on libraries that define multiple `torch.autograd.Function` classes with the same name but different Python import paths (e.g. 'namespace'). Due to a limitation of the PyTorch exporter (https://github.com/microsoft/onnx-converters-private/issues/115), the ORT backend cannot infer which one to call, so an exception is thrown in this case.

Until the fully qualified name can be obtained from the exporter, this environment variable can be used to specify which `torch.autograd.Function` classes should be ignored. Note that the fully qualified name is required here. If there are multiple classes to be ignored, use a comma as the separator. For example:

```bash
export ORTMODULE_SKIPPED_AUTOGRAD_FUNCTIONS="megatron.fp16.fp16.fused_kernels.GELUFunction"
```
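
The matching this variable relies on can be sketched as follows: split the comma-separated value into fully qualified names and compare each against `module.qualname` of a class. The helpers are hypothetical and only illustrate why the fully qualified name is needed; they are not ORT's implementation. Plain classes stand in for `torch.autograd.Function` subclasses so the sketch runs without PyTorch:

```python
import os
from typing import Mapping, Set


def skipped_autograd_functions(environ: Mapping[str, str] = os.environ) -> Set[str]:
    """Parse ORTMODULE_SKIPPED_AUTOGRAD_FUNCTIONS (comma-separated) into a set of names."""
    raw = environ.get("ORTMODULE_SKIPPED_AUTOGRAD_FUNCTIONS", "")
    return {name.strip() for name in raw.split(",") if name.strip()}


def fully_qualified_name(cls: type) -> str:
    return f"{cls.__module__}.{cls.__qualname__}"


def is_skipped(cls: type, environ: Mapping[str, str] = os.environ) -> bool:
    return fully_qualified_name(cls) in skipped_autograd_functions(environ)


# Two classes with the same short name from different libraries (stand-ins for
# torch.autograd.Function subclasses).
MegatronGELU = type("GELUFunction", (), {"__module__": "megatron.fp16.fp16.fused_kernels"})
OtherGELU = type("GELUFunction", (), {"__module__": "other_lib.kernels"})

env = {"ORTMODULE_SKIPPED_AUTOGRAD_FUNCTIONS": "megatron.fp16.fp16.fused_kernels.GELUFunction"}
print(is_skipped(MegatronGELU, env), is_skipped(OtherGELU, env))  # True False
```
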

#### ORTMODULE_ENABLE_COMPUTE_OPTIMIZER

- **Feature Area**: *ORTMODULE/Optimizations*
- **Description**: By default, this is enabled so that some computation can be saved. This env var can be used to disable the optimization to guarantee exactly the same computation as the baseline (for example, PyTorch, when debugging convergence parity). Disable it with the following command:

```bash
export ORTMODULE_ENABLE_COMPUTE_OPTIMIZER=0
```

#### ORTMODULE_ENABLE_INPUT_DENSITY_INSPECTOR

- **Feature Area**: *ORTMODULE/Runtime inspector*
- **Description**: By default, this is disabled. This env var can be used to enable input data sparsity inspection. Training users or ORT developers can use this information to improve performance accordingly. Enable it with the following command:

```bash
export ORTMODULE_ENABLE_INPUT_DENSITY_INSPECTOR=1
```

### 2.2 Memory Optimization

Q: *Want to run a bigger batch size?*

Q: *The model training hits OOM, even with the minimum required batch size?*

Check [Memory Optimizer for ONNX Runtime Training](Memory_Optimizer.md) for how to leverage ORT's recomputation techniques.

## 3. Use `FusedAdam` to Accelerate Parameter Update

Parameter update is done by optimizers (for example, AdamW) with many elementwise operations. `FusedAdam` launches the elementwise update kernels with multi-tensor apply, allowing a batch of gradients to be applied to their corresponding parameters in each kernel launch.
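
The multi-tensor apply idea can be illustrated without a GPU: instead of one update launch per parameter tensor, tensors are processed in groups, one "launch" per group. The sketch below applies PyTorch's AdamW update rule (decoupled weight decay) to plain Python lists and counts the simulated launches; the function names and `tensors_per_launch` are hypothetical, and ORT's `FusedAdam` CUDA kernels may differ in details such as chunking and the weight-decay mode.

```python
import math
from typing import List, Sequence, Tuple

Tensor = List[float]  # stand-in for a parameter tensor


def multi_tensor_adamw_step(params: List[Tensor], grads: Sequence[Tensor],
                            exp_avgs: List[Tensor], exp_avg_sqs: List[Tensor], step: int,
                            lr: float = 1e-3, betas: Tuple[float, float] = (0.9, 0.999),
                            eps: float = 1e-8, weight_decay: float = 0.0,
                            tensors_per_launch: int = 4) -> int:
    """Update all tensors in place; return the number of simulated kernel launches."""
    beta1, beta2 = betas
    bias_correction1 = 1 - beta1 ** step
    bias_correction2 = 1 - beta2 ** step
    launches = 0
    for start in range(0, len(params), tensors_per_launch):
        launches += 1  # one fused launch covers this whole group of tensors
        for t in range(start, min(start + tensors_per_launch, len(params))):
            p, g, m, v = params[t], grads[t], exp_avgs[t], exp_avg_sqs[t]
            for i in range(len(p)):
                p[i] *= 1 - lr * weight_decay                     # decoupled weight decay
                m[i] = beta1 * m[i] + (1 - beta1) * g[i]          # first moment
                v[i] = beta2 * v[i] + (1 - beta2) * g[i] * g[i]   # second moment
                m_hat = m[i] / bias_correction1
                v_hat = v[i] / bias_correction2
                p[i] -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return launches


def make_state(n_tensors: int, size: int):
    params = [[1.0] * size for _ in range(n_tensors)]
    grads = [[0.5 * (t + 1)] * size for t in range(n_tensors)]
    exp_avgs = [[0.0] * size for _ in range(n_tensors)]
    exp_avg_sqs = [[0.0] * size for _ in range(n_tensors)]
    return params, grads, exp_avgs, exp_avg_sqs


fused = make_state(10, 3)
per_tensor = make_state(10, 3)
print(multi_tensor_adamw_step(*fused, step=1, lr=0.1, tensors_per_launch=4))       # 3
print(multi_tensor_adamw_step(*per_tensor, step=1, lr=0.1, tensors_per_launch=1))  # 10
print(fused[0] == per_tensor[0])  # True: grouping changes the launch count, not the results
```
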

Here is a sample switch from the torch `AdamW` optimizer to `FusedAdam`:

```diff
model = build_model()

- optimizer = AdamW(model.parameters(), lr=1)
+ from onnxruntime.training.optim import FusedAdam
+ optimizer = FusedAdam(model.parameters(), lr=1)
```

Check [FusedAdam implementation](../orttraining/orttraining/python/training/optim/fused_adam.py) for more details.

## 4. Use `FP16_Optimizer` to Complement DeepSpeed/APEX

If user models utilize the DeepSpeed or Apex libraries, ORT's `FP16_Optimizer` can be used to compensate for some inefficiencies introduced by them.

Use `FP16_Optimizer` with the DeepSpeed ZeRO optimizer:

```diff
optimizer = AdamW(model.parameters(), lr=1)
model, optimizer, _, lr_scheduler = deepspeed.initialize(
        model=model,
        optimizer=optimizer,
        args=args,
        lr_scheduler=lr_scheduler,
        mpu=mpu,
        dist_init_required=False)

+ from onnxruntime.training.optim.fp16_optimizer import FP16_Optimizer
+ optimizer = FP16_Optimizer(optimizer)
```

Use `FP16_Optimizer` with the Apex optimizer:

```diff
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
model, optimizer = amp.initialize(model, optimizer, opt_level="O2")

+ from onnxruntime.training.optim.fp16_optimizer import FP16_Optimizer as ORT_FP16_Optimizer
+ optimizer = ORT_FP16_Optimizer(optimizer)
```

Check [FP16_Optimizer implementation](../orttraining/orttraining/python/training/optim/fp16_optimizer.py) for more details.

## 5. Putting It All Together: `ORTModule` + `FusedAdam` + `FP16_Optimizer`

```diff
model = build_model()

+ from onnxruntime.training.ortmodule import ORTModule
+ model = ORTModule(model)

- optimizer = AdamW(model.parameters(), lr=1)
+ from onnxruntime.training.optim import FusedAdam
+ optimizer = FusedAdam(model.parameters(), lr=1)

model, optimizer, _, lr_scheduler = deepspeed.initialize(
        model=model,
        optimizer=optimizer,
        args=args,
        lr_scheduler=lr_scheduler,
        mpu=mpu,
        dist_init_required=False)

+ from onnxruntime.training.optim.fp16_optimizer import FP16_Optimizer
+ optimizer = FP16_Optimizer(optimizer)
```

## 6. One More Thing - `LoadBalancingDistributedBatchSampler`

`LoadBalancingDistributedBatchSampler` balances the data load across workers based on each sample's complexity.
This is useful in scenarios like speech and NLP, where each batch has variable length and distributed training suffers from the **straggler problem**. In such scenarios, the complexity function could be defined to return the length of the input sample sequence. The usage is similar to `torch.utils.data.DistributedSampler`, where each process loads a subset of the original dataset that is exclusive to it.
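
To illustrate the idea of balancing by complexity (not the exact algorithm used by ORT's sampler, which is not described here), below is a classic greedy scheme: sort samples by `complexity_fn` in descending order and always give the next sample to the worker with the smallest total so far. `balance_across_workers` and `worker_loads` are hypothetical helpers, and `complexity_fn=len` follows the suggestion above of using the input sequence length:

```python
import heapq
from typing import Callable, List, Sequence


def balance_across_workers(dataset: Sequence, complexity_fn: Callable[[object], int],
                           world_size: int) -> List[List[int]]:
    """Assign each sample index to exactly one worker, balancing total complexity (greedy LPT)."""
    order = sorted(range(len(dataset)), key=lambda i: complexity_fn(dataset[i]), reverse=True)
    heap = [(0, rank) for rank in range(world_size)]  # (total complexity so far, rank)
    assignments: List[List[int]] = [[] for _ in range(world_size)]
    for idx in order:
        load, rank = heapq.heappop(heap)  # least-loaded worker
        assignments[rank].append(idx)
        heapq.heappush(heap, (load + complexity_fn(dataset[idx]), rank))
    return assignments


def worker_loads(dataset, complexity_fn, assignments):
    return [sum(complexity_fn(dataset[i]) for i in indices) for indices in assignments]


# Toy "sequences" whose complexity is their length.
dataset = [[0] * n for n in (9, 7, 6, 5, 4, 3, 2, 2)]
balanced = balance_across_workers(dataset, complexity_fn=len, world_size=2)
strided = [list(range(rank, len(dataset), 2)) for rank in range(2)]  # round-robin split, no shuffling
print(balanced, worker_loads(dataset, len, balanced))  # [[0, 3, 5, 6], [1, 2, 4, 7]] [19, 19]
print(worker_loads(dataset, len, strided))             # [21, 17]
```

With the round-robin split, the worker holding 21 units of work becomes the straggler; the greedy assignment evens both workers out at 19.
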

A sample is shown below:

```python
import torch
from onnxruntime.training.utils.data import (
    LoadBalancingDistributedBatchSampler,
    LoadBalancingDistributedSampler,
)

# `dataset`, `complexity_fn`, `batch_fn`, `start_epoch`, `n_epochs` and `train` are
# user-provided; e.g. `complexity_fn` may return the length of a sample's input sequence.
sampler = LoadBalancingDistributedSampler(dataset, complexity_fn=complexity_fn)
batch_sampler = LoadBalancingDistributedBatchSampler(sampler, batch_fn=batch_fn)
loader = torch.utils.data.DataLoader(dataset, batch_sampler=batch_sampler)

for epoch in range(start_epoch, n_epochs):
    # Update the epoch every iteration, as with torch's DistributedSampler.set_epoch.
    batch_sampler.set_epoch(epoch)
    train(loader)
```

Check [LoadBalancingDistributedBatchSampler implementation](../orttraining/orttraining/python/training/utils/data/sampler.py) for more details.