# ONNX Runtime Training Guidelines
## 1. Installation and Configuration
Note: this document mainly covers setup steps for development. For the end-user setup experience, check [Torch-ORT](https://github.com/pytorch/ort).
Refer to [https://onnxruntime.ai/](https://onnxruntime.ai/) to download the training wheel, or build from source:
```bash
export CUDA_HOME=/usr/local/cuda
export CUDNN_HOME=/usr/local/cuda
export CUDACXX=$CUDA_HOME/bin/nvcc
./build.sh --config RelWithDebInfo --use_cuda --enable_training --build_wheel --skip_tests --cuda_version=11.8 --parallel 8 --use_mpi
```
Install the Python wheel.
Configure the ORTModule torch cpp extensions (**avoid** running this from the ORT *repo root directory*):
```bash
python -m onnxruntime.training.ortmodule.torch_cpp_extensions.install
```
## 2. Use `ORTModule` to Accelerate Forward/Backward
Wrap your `torch.nn.Module` model with `ORTModule` to leverage ONNX Runtime's fast training backend.
Sample usage:
```diff
model = build_model()
+ from onnxruntime.training.ortmodule import ORTModule
+ model = ORTModule(model)
```
> It is strongly recommended to wrap the model with `ORTModule` before any other module wrapper (for example, DeepSpeed or `torch.nn.parallel.DistributedDataParallel`), as this ordering is validated in more scenarios.
> Also note that `ORTModule` is **NOT** compatible with `torch.nn.DataParallel` (which is not recommended in PyTorch usage anyway). Please use `torch.nn.parallel.DistributedDataParallel` instead.
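Following the diff style above, the recommended wrap order looks like this (the `DistributedDataParallel` arguments shown, such as `device_ids=[local_rank]`, are illustrative):

```diff
model = build_model()
+ from onnxruntime.training.ortmodule import ORTModule
+ model = ORTModule(model)  # wrap with ORTModule first
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])  # then apply DDP (or DeepSpeed, etc.)
```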
More options for **developers**:
```diff
model = build_model()
+ from onnxruntime.training.ortmodule import ORTModule, DebugOptions, LogLevel
+ model = ORTModule(model, DebugOptions(save_onnx=True, log_level=LogLevel.VERBOSE, onnx_prefix="model_name"))
```
Check [DebugOptions implementation](../orttraining/orttraining/python/training/ortmodule/options.py) for more details.
#### Log Level Explanations
<table>
<tr>
<th style="width:20%">Log Level</th>
<th style="width:80%">Description</th>
</tr>
<tr>
<td>
`FATAL` | `ERROR` | `WARNING` (For Users)
<sup>`WARNING` is the default and recommended level for
<br>users.</sup>
</td>
<td>
- ONNX Runtime backend log level - `FATAL` | `ERROR` | `WARNING`.
- ORTModule log level - `FATAL` | `ERROR` | `WARNING`.
- Rank-0 log filtering is `ON` (i.e. logging on rank 0 only).
- PyTorch exporter export logs filtering is `ON`.
- PyTorch exporter verbose logs (including tracing graph) filtering is `ON`.
</td>
</tr>
<tr>
<td>
`INFO` (For Users | ORT Developers)
<sup>`INFO` is used for collecting experimental
<br>feature stats, or slightly more detailed error messages.</sup>
</td>
<td>
- ONNX Runtime backend log level - `WARNING`.
- ORTModule log level - `INFO`.
- Rank-0 log filtering is `ON` (i.e. logging on rank 0 only).
- PyTorch exporter export logs filtering is `ON`.
- PyTorch exporter verbose logs (including tracing graph) filtering is `OFF`.
</td>
</tr>
<tr>
<td>
`DEVINFO` (For ORT Developers)
<sup>`DEVINFO` is the recommended level for
<br>debugging purposes.</sup>
</td>
<td>
- ONNX Runtime backend log level - `INFO`.
- ORTModule log level - `INFO`.
- Rank-0 log filtering is `OFF` (i.e. logging on all ranks).
- PyTorch exporter export logs filtering is `OFF`.
- PyTorch exporter verbose logs (including tracing graph) filtering is `OFF`.
</td>
</tr>
<tr>
<td>
`VERBOSE` (For ORT Developers)
<sup>`VERBOSE` is the last resort for debugging
<br>hard problems.</sup>
</td>
<td>
- ONNX Runtime backend log level - `VERBOSE`.
- ORTModule log level - `VERBOSE`.
- Rank-0 log filtering is `OFF` (i.e. logging on all ranks).
- PyTorch exporter export logs filtering is `OFF`.
- PyTorch exporter verbose logs (including tracing graph) filtering is `OFF`.
</td>
</tr>
</table>
### 2.1 Environment Variables
`ORTModule` provides environment variables targeting different use cases.
#### ORTMODULE_ONNX_OPSET_VERSION
- **Feature Area**: *ORTMODULE/ONNXOPSET*
- **Description**: By default, the ONNX opset version used for export is updated periodically as ONNX Runtime releases. Some customers prefer to stick to a fixed opset where both performance and accuracy are well validated; this env variable can be used to control that.
```bash
export ORTMODULE_ONNX_OPSET_VERSION=14
```
#### ORTMODULE_FALLBACK_POLICY
- **Feature Area**: *ORTMODULE/FallbackToPytorch*
- **Description**: By default, if `ORTModule` fails to run the model using the ONNX Runtime backend, it falls back to PyTorch to continue training. When developers are optimizing models and running benchmarks, they may want the ORT backend to run the model explicitly. To disable the fallback:
```bash
export ORTMODULE_FALLBACK_POLICY="FALLBACK_DISABLE"
```
#### ORTMODULE_LOG_LEVEL
- **Feature Area**: *ORTMODULE/DebugOptions*
- **Description**: Configure the `ORTModule` log level. Defaults to `LogLevel.WARNING`; can be set to one of `VERBOSE`, `INFO`, `WARNING`, `ERROR`, `FATAL`. The environment variable takes precedence if `DebugOptions` also sets `log_level`.
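For example (the level shown here is illustrative):

```bash
export ORTMODULE_LOG_LEVEL="INFO"
```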
#### ORTMODULE_SAVE_ONNX_PATH
- **Feature Area**: *ORTMODULE/DebugOptions*
- **Description**: Configure `ORTModule` to save the exported ONNX models. Defaults to False.
By default, the ONNX models are saved to the current working directory. To change the output directory, set the environment variable `ORTMODULE_SAVE_ONNX_PATH` to the destination directory path.
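For example (the directory path below is illustrative; any writable directory works):

```bash
export ORTMODULE_SAVE_ONNX_PATH=/tmp/ortmodule_onnx_models
```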
#### ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT
- **Feature Area**: *ORTMODULE/PythonOp (torch.autograd.Function)*
- **Description**: By default, `ORTModule` fails with an exception when handling PythonOp export for some `torch.autograd.Function`s (one example is torch's `CheckpointFunction`). Set this env variable to `1` to explicitly allow it.
```bash
export ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT=1
```
> Take `torch.utils.checkpoint.CheckpointFunction` as an example: if it is exported as a PythonOp, the checkpointed computation may be computed by PyTorch, not ORT. This matters especially for big models such as GPT-2, where every few layers are wrapped for re-computation, so a large number of computations would be done by PyTorch. Currently a failure is reported to notify users that `ORTModule` may have fewer opportunities to optimize further.
> On the other hand, if the wrapped computation graph is small, it is reasonable to allow it.
> Overall, users should be aware that the ORT performance boost might be trivial when they explicitly allow it.
#### ORTMODULE_ENABLE_CUSTOM_AUTOGRAD
- **Feature Area**: *ORTMODULE/PythonOp (torch.autograd.Function)*
- **Description**: By default, all `torch.autograd.Function` classes are exported as ORT PythonOps. There are cases where you might consider disabling this: for example, if you have confirmed that those `torch.autograd.Function` classes define computations that can be inline-exported by PyTorch, and it is safe to train with the inline-exported ONNX graph, then disabling it gives ORT more opportunities to optimize.
```bash
export ORTMODULE_ENABLE_CUSTOM_AUTOGRAD=1 # Enable
export ORTMODULE_ENABLE_CUSTOM_AUTOGRAD=0 # Disable
```
An alternative way to disable it without using the environment variable:
```python
from onnxruntime.training.ortmodule._custom_autograd_function import enable_custom_autograd_support
enable_custom_autograd_support(False)
```
#### ORTMODULE_ENABLE_COMPUTE_OPTIMIZER
- **Feature Area**: *ORTMODULE/Optimizations*
- **Description**: Enabled by default, allowing some computation to be saved. This env var can be used to disable
the optimization to guarantee exactly the same compute as the baseline (for example PyTorch, when debugging
convergence parity).
```bash
export ORTMODULE_ENABLE_COMPUTE_OPTIMIZER=1 # Enable
export ORTMODULE_ENABLE_COMPUTE_OPTIMIZER=0 # Disable
```
#### ORTMODULE_ENABLE_SPARSE_OPTIMIZER
- **Feature Area**: *ORTMODULE/Optimizations*
- **Description**: Enabled by default. This env var can be used to enable or disable the input-data-sparsity-based
performance optimizations, including embedding sparsity and label sparsity.
This optimization is applicable when using Optimum, whose `ModuleWithLoss` class wraps the HuggingFace Trainer and allows loss computation inside ONNX Runtime (ORT).
If you are not using Optimum but want to implement a similar wrapper in your codebase to compute the loss inside ONNX Runtime (ORT), refer to this [link](ORTModule_ModuleWithLoss_Wrapper.md) for detailed steps and guidelines.
```bash
export ORTMODULE_ENABLE_SPARSE_OPTIMIZER=1 # Enable
export ORTMODULE_ENABLE_SPARSE_OPTIMIZER=0 # Disable
```
#### ORTMODULE_PRINT_INPUT_DENSITY
- **Feature Area**: *ORTMODULE/RuntimeInspector*
- **Description**: Disabled by default. This env var can be used to print the input data sparsity
inspection results to standard output.
```bash
export ORTMODULE_PRINT_INPUT_DENSITY=1 # Enable
export ORTMODULE_PRINT_INPUT_DENSITY=0 # Disable
```
#### ORTMODULE_PRINT_MEMORY_STATS
- **Feature Area**: *ORTMODULE/RuntimeInspector*
- **Description**: By default, this is disabled. This env var can be used for printing the memory inspection results
to standard outputs.
```bash
export ORTMODULE_PRINT_MEMORY_STATS=1 # Enable
export ORTMODULE_PRINT_MEMORY_STATS=0 # Disable
```
#### ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER
- **Feature Area**: *ORTMODULE/Optimizations*
- **Description**: By default, this is disabled. This env var can be used for enabling or disabling the embedding-input
data-sparsity-based performance optimizations.
```bash
export ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER=1 # Enable
export ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER=0 # Disable
```
#### ORTMODULE_CACHE_DIR
- **Feature Area**: *ORTMODULE/RuntimeOptions*
- **Description**: By default, this is disabled. This env var can be used to cache the exported model for future runs. This optimization is intended to reduce experimentation time by re-using the PyTorch->ONNX exported model architecture when available.
```bash
export ORTMODULE_CACHE_DIR="/path/to/cache_dir" # Enable
unset ORTMODULE_CACHE_DIR # Disable
```
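These env vars can also be set from Python before the model is wrapped; below is a minimal sketch (the cache path is illustrative, and it assumes `ORTModule` reads the variable at wrap/export time, so it must be set beforehand):

```python
import os

# Assumption: ORTModule reads ORTMODULE_* env vars when the model is
# wrapped/exported, so set them before constructing ORTModule.
os.environ["ORTMODULE_CACHE_DIR"] = "/tmp/ortmodule_cache"

# from onnxruntime.training.ortmodule import ORTModule
# model = ORTModule(build_model())  # exported ONNX model is cached for reuse
```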
#### ORTMODULE_USE_EFFICIENT_ATTENTION
- **Feature Area**: *ORTMODULE/Optimizations*
- **Description**: By default, this is disabled. This env var can be used for enabling attention fusion and falling back to PyTorch's efficient_attention ATen kernel for execution. NOTE that it requires torch version 2.1.1 or above. There are some built-in patterns for attention fusion; if none of them works for your model, you can manually add a custom one in your user script.
```bash
export ORTMODULE_USE_EFFICIENT_ATTENTION=1
```
#### ORTMODULE_DEEPCOPY_BEFORE_MODEL_EXPORT
- **Feature Area**: *ORTMODULE/Optimizations*
- **Description**: By default, this is enabled. This env var can be used for enabling or disabling the module deep copy when preparing output data which will be used by the ONNX export.
A typical reason to disable the deep copy: when the deep copy before module export raises the memory peak, disable it and try again.
```bash
export ORTMODULE_DEEPCOPY_BEFORE_MODEL_EXPORT=1 # Enable
export ORTMODULE_DEEPCOPY_BEFORE_MODEL_EXPORT=0 # Disable
```
#### ORTMODULE_MEMORY_OPT_LEVEL
- **Feature Area**: *ORTMODULE/Optimizations*
- **Description**: By default, the level is 0. This env var can be used for enabling recomputation to reduce the memory peak requirement. Setting the level to 1 means all detected subgraphs within each transformer-based model layer generating stashed activations will be recomputed. This is conceptually equivalent to PyTorch's gradient checkpointing. When the level is not 1, check [Memory Optimizer for ONNX Runtime Training](Memory_Optimizer.md) for more details.
```bash
export ORTMODULE_MEMORY_OPT_LEVEL=1 # Enable transformer layer-wise recomputation
export ORTMODULE_MEMORY_OPT_LEVEL=0 # Disable (default)
```
#### ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT
- **Feature Area**: *ORTMODULE/Optimizations*
- **Description**: By default, the memory-efficient gradient management is turned off. When enabled, each gradient, once computed in ONNX Runtime, triggers the corresponding parameter's backward function through the `PythonOpGrad` operator. This helps release the gradient buffer managed in ONNX Runtime earlier, which originally is released only once all backward computation finishes.
```bash
export ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT=1 # Enable
export ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT=0 # Disable
```
### 2.2 Memory Optimization
Q: *Want to run a bigger batch size?*
Q: *The model training hits OOM, even with minimum required batch size?*
Check [Memory Optimizer for ONNX Runtime Training](Memory_Optimizer.md) for how to leverage ORT's recomputation techniques.
## 3. Use `FusedAdam` to Accelerate Parameter Update
Parameter updates are done by optimizers (for example, AdamW) with many elementwise operations. `FusedAdam` launches the elementwise update kernels with multi-tensor apply, allowing batches of gradients to be applied to their corresponding parameters in each kernel launch.
Here is a sample switch from torch `AdamW` optimizer to `FusedAdam`.
```diff
model = build_model()
- optimizer = AdamW(model.parameters(), lr=1)
+ from onnxruntime.training.optim import FusedAdam
+ optimizer = FusedAdam(model.parameters(), lr=1)
```
Check [FusedAdam implementation](../orttraining/orttraining/python/training/optim/fused_adam.py) for more details.
## 4. Use `FP16_Optimizer` to Complement DeepSpeed/APEX
If user models utilize DeepSpeed or Apex libraries, ORT's `FP16_Optimizer` can be used to mitigate some inefficiencies they introduce.
Use `FP16_Optimizer` with DeepSpeed ZeRO Optimizer:
```diff
optimizer = AdamW(model.parameters(), lr=1)
model, optimizer, _, lr_scheduler = deepspeed.initialize(
model=model,
optimizer=optimizer,
args=args,
lr_scheduler=lr_scheduler,
mpu=mpu,
dist_init_required=False)
+ from onnxruntime.training.optim.fp16_optimizer import FP16_Optimizer
+ optimizer = FP16_Optimizer(optimizer)
```
Use `FP16_Optimizer` with Apex Optimizer:
```diff
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
model, optimizer = amp.initialize(model, optimizer, opt_level="O2")
+ from onnxruntime.training.optim.fp16_optimizer import FP16_Optimizer as ORT_FP16_Optimizer
+ optimizer = ORT_FP16_Optimizer(optimizer)
```
Check [FP16_Optimizer implementation](../orttraining/orttraining/python/training/optim/fp16_optimizer.py) for more details.
## 5. Putting All Together `ORTModule` + `FusedAdam` + `FP16_Optimizer`
```diff
model = build_model()
+ from onnxruntime.training.ortmodule import ORTModule
+ model = ORTModule(model)
- optimizer = AdamW(model.parameters(), lr=1)
+ from onnxruntime.training.optim import FusedAdam
+ optimizer = FusedAdam(model.parameters(), lr=1)
model, optimizer, _, lr_scheduler = deepspeed.initialize(
model=model,
optimizer=optimizer,
args=args,
lr_scheduler=lr_scheduler,
mpu=mpu,
dist_init_required=False)
+ from onnxruntime.training.optim.fp16_optimizer import FP16_Optimizer
+ optimizer = FP16_Optimizer(optimizer)
```
## 6. Use OpenAI Triton to Compute ONNX Sub-graph
`ORTModule` provides a way to switch to OpenAI Triton for executing some Ops to further accelerate training.
### 6.1 Environment Variables
#### ORTMODULE_USE_TRITON
- **Feature Area**: *ORTMODULE/TritonOp*
- **Description**: By default, this is disabled. This env var can be used for enabling Triton optimization.
```bash
export ORTMODULE_USE_TRITON=1
```
#### ORTMODULE_TRITON_CONFIG_FILE
- **Feature Area**: *ORTMODULE/TritonOp*
- **Description**: Triton codegen currently supports some Ops, such as certain elementwise and reduction Ops. If Triton optimization is enabled, all supported Ops will be optimized by default when possible. Users can provide a customized JSON config file to control which Ops to optimize and how to optimize them. Below is a sample config JSON. For each Op, the opset version list and the domain are needed. Currently the "conditions" field can be used to constrain an axis/axes attribute or input, either by specifying the real value, or with "single" meaning it contains only one dimension, or "constant" meaning it must be a constant tensor. Save the JSON as a file somewhere and assign its path to the env variable below to enable the customized config.
```json
{
"ops": {
"Add": {"versions": [13, 14]},
"Sub": {"versions": [13, 14]},
"Identity": {"versions": [13], "is_no_op": True},
"ReduceSum": {"versions": [13], "conditions": {"axes": "[-1]"}},
"Softmax": {"versions": [13]},
"SoftmaxGrad_13": {"domain": "com.microsoft", "versions": [1]}
},
"initializer": "scalar",
"min_nodes": 2
}
```
```bash
export ORTMODULE_TRITON_CONFIG_FILE=triton_config.json
```
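The config file can also be generated programmatically; the sketch below writes an illustrative subset of the sample config above (the file location is an assumption) and points the env var at it. Note that `json.dump` emits valid JSON booleans (`true`/`false`) from Python's `True`/`False`.

```python
import json
import os
import tempfile

# Illustrative subset of the sample Triton config above.
triton_config = {
    "ops": {
        "Add": {"versions": [13, 14]},
        "Sub": {"versions": [13, 14]},
        "Identity": {"versions": [13], "is_no_op": True},
    },
    "initializer": "scalar",
    "min_nodes": 2,
}

config_path = os.path.join(tempfile.gettempdir(), "triton_config.json")
with open(config_path, "w") as f:
    json.dump(triton_config, f, indent=2)  # Python True becomes JSON true

os.environ["ORTMODULE_TRITON_CONFIG_FILE"] = config_path
```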
#### ORTMODULE_ENABLE_TUNING
- **Feature Area**: *ORTMODULE/TritonOp*
- **Description**: By default, this is disabled. This env var can be used for enabling online Op tuning for those Ops that have multiple implementations on the target EP.
```bash
export ORTMODULE_ENABLE_TUNING=1
```
#### ORTMODULE_MAX_TUNING_DURATION_MS
- **Feature Area**: *ORTMODULE/TritonOp*
- **Description**: When `ORTMODULE_ENABLE_TUNING` is enabled, this env var can be used to set max tuning duration in ms to avoid long tuning time.
```bash
export ORTMODULE_MAX_TUNING_DURATION_MS=9999
```
#### ORTMODULE_TUNING_RESULTS_PATH
- **Feature Area**: *ORTMODULE/TritonOp*
- **Description**: When `ORTMODULE_ENABLE_TUNING` is enabled, this env var can be used to specify where the online Op tuning results should be saved for later use. By default the results will not be saved. When `ORTMODULE_ENABLE_TUNING` is NOT enabled, this env var can be used to specify where previously saved Op tuning results can be fetched as offline tuning results.
```bash
export ORTMODULE_TUNING_RESULTS_PATH=/tmp/tuning_results
```
#### ORTMODULE_USE_FLASH_ATTENTION
- **Feature Area**: *ORTMODULE/TritonOp*
- **Description**: By default, this is disabled. This env var can be used for enabling attention fusion and using Flash Attention's Triton version as the kernel. NOTE that it requires `ORTMODULE_USE_TRITON` to be enabled and the CUDA device's compute capability to be 8.0 or above. There are some built-in patterns for attention fusion; if none of them works for your model, you can manually add a custom one in your user script.
```bash
export ORTMODULE_USE_FLASH_ATTENTION=1
```
#### ORTMODULE_TRITON_DEBUG
- **Feature Area**: *ORTMODULE/TritonOp*
- **Description**: By default, this is disabled. This env var can be used for enabling Triton debug mode. All original and processed sub-graphs and the corresponding generated Triton code will be saved into a `triton_debug` folder under the working directory.
```bash
export ORTMODULE_TRITON_DEBUG=1
```
## 7. One More Thing - `LoadBalancingDistributedBatchSampler`
`LoadBalancingDistributedBatchSampler` balances the data load across workers based on the sample's complexity.
This is useful in scenarios like speech and NLP, where batches have variable lengths and distributed training suffers from the **straggler problem**. In such scenarios, the complexity function could be defined to return the length of the input sample sequence. The usage is similar to `torch.utils.data.DistributedSampler`, where each process loads a subset of the original dataset that is exclusive to it.
A sample shown below:
```python
from onnxruntime.training.utils.data import LoadBalancingDistributedSampler, \
LoadBalancingDistributedBatchSampler
sampler = LoadBalancingDistributedSampler(dataset, complexity_fn=complexity_fn)
batch_sampler = LoadBalancingDistributedBatchSampler(sampler, batch_fn=batch_fn)
loader = torch.utils.data.DataLoader(dataset, batch_sampler=batch_sampler)
for epoch in range(start_epoch, n_epochs):
batch_sampler.set_epoch(epoch)
train(loader)
```
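The `complexity_fn` and `batch_fn` above are user-provided callables. A hypothetical pair for a tokenized text dataset might look like the sketch below (the `input_ids` key and the fixed batch size are assumptions for illustration):

```python
def complexity_fn(sample):
    # Use the input sequence length as the sample's complexity measure.
    return len(sample["input_ids"])


def batch_fn(indices):
    # Group the sampler's (complexity-ordered) indices into fixed-size batches.
    batch_size = 8
    return [indices[i : i + batch_size] for i in range(0, len(indices), batch_size)]
```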
Check [LoadBalancingDistributedBatchSampler implementation](../orttraining/orttraining/python/training/utils/data/sampler.py) for more details.