Add guidelines for ORTModule (#13553)

### Add guidelines for ORTModule

As title.

Feel free to let me know if I missed something. 

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
pengwa 2022-11-04 19:42:10 +08:00 committed by GitHub
parent 433f262dd5
commit ab9ac2acc4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 250 additions and 18 deletions

View file

@ -2,25 +2,25 @@
## Introduction
ONNX Runtime Training provides a capability trading node/subgraph recomputations for better memory efficiency.
Specifically, a list of recomputable operators is pre-defined, with which memory optimizer graph transformer will iterate the graph to find all recomputable subgraph candidates.
ONNX Runtime Training provides a capability trading node/subgraph re-computations for better memory efficiency.
Specifically, a list of re-computable operators is pre-defined, with which memory optimizer graph transformer will iterate the graph to find all re-computable subgraph candidates.
When training with ORTModule, by default, the graph transformer will scan the execution graph to find all eligible subgraphs to recompute, along with sizes that can save. Users can pick up some of the subgraphs to enable them by environment variables.
When training with `ORTModule`, by default, the graph transformer will scan the execution graph to find all eligible subgraphs to recompute, along with sizes that can be saved. Users can pick up some of the subgraphs to enable by environment variables.
## When memory optimizer can help?
Classical scenarios include:
- ORTModule run a model with batch size B (for example 2^N), the memory bandwidth and compute are not fully saturated, while it hits OOM to run a bigger batch size (for example 2^(N+1)).
- `ORTModule` runs a model with batch size B (for example 2^N), the memory bandwidth and compute are not fully saturated, while it hits OOM to run a bigger batch size (for example 2^(N+1)).
- For big models, ORTModule fails to run the minimum allowed batch size, so performance can be compromised for a successful run.
- For big models, `ORTModule` fails to run the minimum allowed batch size, so performance can be compromised for a successful run.
Not all models and recipes need this optimizer technique. Imagine if your training recipe is using a batch size 6 (GPU compute and memory are fully saturated), and you don't need bump it to 8 to maintain a fixed global batch size. Enabling recompute maybe not bring better throughput on batch size 8 than the original batch size 6.
Not all models and recipes need this optimizer technique. Imagine if your training recipe uses a batch size 6 (GPU compute and memory are fully saturated), and you don't need bump it to 8 to maintain a fixed global batch size. Enabling recompute maybe not bring better throughput on batch size 8 than the original batch size 6.
## Quick trial
1. Make sure ONNX Runtime training wheel is installed and correctly configured.
2. Integrate models using ORTModule, be noted log_level should be equal or lower than INFO.
2. Integrate models using `ORTModule`, be noted log_level should be equal or lower than INFO.
> ort_model = ORTModule(pt_model, DebugOptions(log_level=LogLevel.INFO))
3. Run the training as usual and redirect all outputs into log file; then stop it after training few steps.
4. Check the logging file, search "Summary", you could possibly find something like this:
@ -48,8 +48,8 @@ Not all models and recipes need this optimizer technique. Imagine if your traini
--------------------------------
=================================
```
5. As shown above, 'Subgraph' shows 1) a string representative for a recomputable subgraph; and 2) current status of memory optimization. All are disabled for recompute in this case.
6. Set environment variable `ORTMODULE_MEMORY_OPT_CONFIG` to enable some of the subgraph to do recompute. In this sample, 12 FastGelu related subgraphs are allowed to recompute.
5. As shown above, 'Subgraph' shows 1) a string representative for a re-computable subgraph; and 2) current status of memory optimization. All are disabled for recompute in this case.
6. Set environment variable `ORTMODULE_MEMORY_OPT_CONFIG` to enable some of the subgraph to do recompute. In below example, 12 FastGelu related subgraphs are allowed to recompute.
`FastGelu+` is the subgraph string representative; `1` in the middle indicates 'Recompute' is enabled (0, on the contrary indicates it's disabled); `12` means the initial 12 subgraph occurrences will be recomputed, all others are left as it is, filling `-1` will make all occurrences be recomputed.
```
export ORTMODULE_MEMORY_OPT_CONFIG="FastGelu+:1:12"

View file

@ -0,0 +1,229 @@
# ONNX Runtime Training Guidelines
## 1. Installation and Configuration
Be noted: this mainly demonstrates set up steps for development, check [Torch-ORT](https://github.com/pytorch/ort) for end user set up experience.
Refer [https://onnxruntime.ai/](https://onnxruntime.ai/) to download training wheel. Or build from source:
```bash
export CUDA_HOME=/usr/local/cuda
export CUDNN_HOME=/usr/local/cuda
export CUDACXX=$CUDA_HOME/bin/nvcc
./build.sh --config RelWithDebInfo --use_cuda --enable_training --build_wheel --skip_tests --cuda_version=11.6 --parallel 8 --use_mpi --enable_training_torch_interop
```
Install the Python wheel.
Configure ORTModule torch cpp extensions (**avoid** doing this in ORT code *repo root directory*):
```bash
python -m onnxruntime.training.ortmodule.torch_cpp_extensions.install
```
## 2. Use `ORTModule` to Accelerate Forward/Backward
Plug in your `torch.nn.Module` model with `ORTModule` to leverage ONNX Runtime fast training backend.
Sample usage as below:
```diff
model = build_model()
+ from onnxruntime.training import ORTModule
+ model = ORTModule(model)
```
> It is strongly recommended to wrap model with `ORTModule` before other module wrapper (for example, DeepSpeed, `torch.nn.parallel.DistributedDataParallel`, etc), which is validated in more scenarios.
> Be also noticed that, `ORTModule` is **NOT** compatible with `torch.nn.DataParallel` (not recommended to use in PyTorch usage). Please use `torch.nn.parallel.DistributedDataParallel` instead.
More options for **developers**.
```diff
model = build_model()
+ from onnxruntime.training import ORTModule, DebugOptions, LogLevel
+ model = ORTModule(model, DebugOptions(save_onnx=True, log_level=LogLevel.VERBOSE, onnx_prefix="model_name"))
```
Check [DebugOptions implementation](../orttraining/orttraining/python/training/ortmodule/debug_options.py) for more details.
### 2.1 Environment Variables
`ORTModule` provides environment variables targeting different use cases.
#### ORTMODULE_ONNX_OPSET_VERSION
- **Feature Area**: *ORTMODULE/ONNXOPSET*
- **Description**: By default, as ONNX Runtime released, the ONNX OPSET version to use will be updated periodically. For some customers, they want to stick to fixed OPSET where both performance and accuracy are well validated, this env variable can be used to control that.
```bash
export ORTMODULE_ONNX_OPSET_VERSION=14
```
#### ORTMODULE_FALLBACK_POLICY
- **Feature Area**: *ORTMODULE/FallbackToPytorch*
- **Description**: By default, if `ORTModule` fails to run the model using ONNX Runtime backend, it will fallback to use PyTorch to continue the training. At some point developers are optimizing the models and doing benchmarking, we want explicitly let ORT backend to run the model. The way we disable the retry:
```bash
export ORTMODULE_FALLBACK_POLICY="FALLBACK_DISABLE"
```
#### ORTMODULE_LOG_LEVEL
- **Feature Area**: *ORTMODULE/DebugOptions*
- **Description**: Configure `ORTModule` log level. Defaults to LogLevel.WARNING, can be set one of "VERBOSE", "INFO", "WARNING", "ERROR", "FATAL". The environment variable takes precedence if DebugOptions also sets log_level.
#### ORTMODULE_SAVE_ONNX_PATH
- **Feature Area**: *ORTMODULE/DebugOptions*
- **Description**: Configure `ORTModule` to save onnx models. Defaults to False.
The output directory of the onnx models by default is set to the current working directory. To change the output directory, the environment variable "ORTMODULE_SAVE_ONNX_PATH" can be set to the destination directory path.
#### ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT
- **Feature Area**: *ORTMODULE/PythonOp (torch.autograd.Function)*
- **Description**: By default `ORTModule` will fail with exception when handling PythonOp export for some `'autograd.Function'`s (One example is torch CheckpointFunction). Set
this env variable to be `1` to explicitly allow it.
```bash
export ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT=1
```
> Take the example of torch.utils.checkpoint.CheckpointFunction, if it is exported as PythonOp, the checkpointed computation may be computed by PyTorch, not ORT. This situation is especially important for big models such as GPT-2 where every few layers are wrapped to do re-computation, large number of computations are done by PyTorch. Currently a failure is reported to notify users it is possible `ORTModule` has less opportunities to optimize further.
> On the other hand, if the wrapped computation graph is small, it is reasonable to allow it.
> Overall users should be aware that ORT performance boost might be trivial when they explicitly allow it.
#### ORTMODULE_DISABLE_CUSTOM_AUTOGRAD_SUPPORT
- **Feature Area**: *ORTMODULE/PythonOp (torch.autograd.Function)*
- **Description**: By default, all torch.autograd.Function classes will be exported to ORT PythonOp. There are some cases where you might consider disable it. For example, if you confirmed those torch.autograd.Function classes defined computations that could be inline exported by PyTorch, and it is safe to use the inline exported ONNX graph to train, then you can disable it, as a result, ORT has more opportunities to optimize more.
```bash
export ORTMODULE_DISABLE_CUSTOM_AUTOGRAD_SUPPORT=1
```
An alternative to disable without using environment variable:
```python
from onnxruntime.training.ortmodule._custom_autograd_function import enable_custom_autograd_support
enable_custom_autograd_support(False)
```
#### ORTMODULE_SKIPPED_AUTOGRAD_FUNCTIONS
- **Feature Area**: *ORTMODULE/PythonOp (torch.autograd.Function)*
- **Description**: By default, this is empty. When user model's setup depends on libraries who might define multiple torch.autograd.Function classes of same name, though their python import paths (e.g. 'namespace') are different, while due to limitation of PyTorch exporter (https://github.com/microsoft/onnx-converters-private/issues/115), ORT backend cannot infer which one to call. So an exception will be thrown for this case.
Before full qualified name can be got from exporter, this environment variables can be used to specify which torch.autograd.Function classes can be ignored. An example as below, be noted, full qualified name is needed here. If there are multiple classes to be ignored, use comma as the separator.
```bash
export ORTMODULE_SKIPPED_AUTOGRAD_FUNCTIONS="megatron.fp16.fp16.fused_kernels.GELUFunction"
```
### 2.2 Memory Optimization
Q: *Want to run a bigger batch size?*
Q: *The model training hits OOM, even with minimum required batch size?*
Check [Memory Optimizer for ONNX Runtime Training](Memory_Optimizer.md) for how to leverage ORT's recomputation techniques.
## 3. Use `FusedAdam` to Accelerate Parameter Update
Parameter update is done by optimizers (for example AdamW) with many elementwise operations. `FusedAdam` launches the elementwise update kernels with multi-tensor apply, allowing batches of gradients applied to corresponding parameters for each time kernel launch.
Here is a sample switch from torch `AdamW` optimizer to `FusedAdam`.
```diff
model = build_model()
- optimizer = AdamW(model.parameters(), lr=1)
+ from onnxruntime.training.optim import FusedAdam
+ optimizer = FusedAdam(model.parameters(), lr=1)
```
Check [FusedAdam implementation](../orttraining/orttraining/python/training/optim/fused_adam.py) for more details.
## 4. Use `FP16_Optimizer` to Complement DeepSpeed/APEX
If user models utilize DeepSpeed or Apex libraries, ORT's `FP16_Optimizer` can be used to complement some inefficiencies introduced by them.
Use `FP16_Optimizer` with DeepSpeed ZeRO Optimizer:
```diff
optimizer = AdamW(model.parameters(), lr=1)
model, optimizer, _, lr_scheduler = deepspeed.initialize(
model=model,
optimizer=optimizer,
args=args,
lr_scheduler=lr_scheduler,
mpu=mpu,
dist_init_required=False)
+ from onnxruntime.training.optim.fp16_optimizer import FP16_Optimizer
+ optimizer = FP16_Optimizer(optimizer)
```
Use `FP16_Optimizer` with Apex Optimizer:
```diff
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
model, optimizer = amp.initialize(model, optimizer, opt_level="O2")
+ from onnxruntime.training.optim.fp16_optimizer import FP16_Optimizer as ORT_FP16_Optimizer
+ optimizer = ORT_FP16_Optimizer(optimizer)
```
Check [FP16_Optimizer implementation](../orttraining/orttraining/python/training/optim/fp16_optimizer.py) for more details.
## 5. Putting All Together `ORTModule` + `FusedAdam` + `FP16_Optimizer`
```diff
model = build_model()
+ from onnxruntime.training import ORTModule
+ model = ORTModule(model)
- optimizer = AdamW(model.parameters(), lr=1)
+ from onnxruntime.training.optim import FusedAdam
+ optimizer = FusedAdam(model.parameters(), lr=1)
model, optimizer, _, lr_scheduler = deepspeed.initialize(
model=model,
optimizer=optimizer,
args=args,
lr_scheduler=lr_scheduler,
mpu=mpu,
dist_init_required=False)
+ from onnxruntime.training.optim.fp16_optimizer import FP16_Optimizer
+ optimizer = FP16_Optimizer(optimizer)
```
## 6. One More Thing - `LoadBalancingDistributedBatchSampler`
`LoadBalancingDistributedBatchSampler` balances the data load across workers based on the sample's complexity.
This is useful in scenarios like speech and NLP, where each batch has variable length and distributed training suffers from **straggler problem**. In such scenarios, the complexity function could be defined to return the length of the input sample sequence. The usage is similar to `torch.utils.data.DistributedSampler`, where each process loads a subset of the original dataset that is exclusive to it.
A sample shown below:
```python
from onnxruntime.training.utils.data import LoadBalancingDistributedSampler, \
LoadBalancingDistributedBatchSampler
sampler = LoadBalancingDistributedSampler(dataset, complexity_fn=complexity_fn)
batch_sampler = LoadBalancingDistributedBatchSampler(sampler, batch_fn=batch_fn)
loader = torch.utils.data.DataLoader(dataset, batch_sampler=batch_sampler)
for epoch in range(start_epoch, n_epochs):
batch_sampler.set_epoch(epoch)
train(loader)
```
Check [LoadBalancingDistributedBatchSampler implementation](../orttraining/orttraining/python/training/utils/data/sampler.py) for more details.

View file

@ -490,7 +490,7 @@ void MemoryOptimizer::RegisterAllowedRecomputeOps() {
}
}
Status MemoryOptimizer::SelectRecomputeSubgraph(const Node& node,
Status MemoryOptimizer::SelectRecomputeSubgraph(const Node& entry_node,
const InlinedVector<size_t>& node_output_index_candidates,
const ActivationUsedMap& fw_op_output_arg_used_map,
const InlinedHashMap<NodeIndex, size_t>&
@ -501,12 +501,12 @@ Status MemoryOptimizer::SelectRecomputeSubgraph(const Node& node,
bool& can_compromise_stashed_activation) const {
can_compromise_stashed_activation = false;
LOGS(logger, VERBOSE) << "Enter SelectRecomputeSubgraph for Node " << node.Name() << "(" << node.OpType() << ")";
LOGS(logger, VERBOSE) << "Enter SelectRecomputeSubgraph for Node " << entry_node.Name() << "(" << entry_node.OpType() << ")";
nodes.clear();
std::deque<NodeOutputPort> q;
for (auto output_index : node_output_index_candidates) {
q.push_back(NodeOutputPort(&node, static_cast<int>(output_index)));
q.push_back(NodeOutputPort(&entry_node, static_cast<int>(output_index)));
}
bool early_stop = false;
@ -564,14 +564,16 @@ Status MemoryOptimizer::SelectRecomputeSubgraph(const Node& node,
if (op_recompute_config_it == recomputable_op_type_to_input_arg_index_map_.end()) {
if (fw_op_output_arg_used_map.at(cur_output_arg_name).second) {
LOGS(logger, VERBOSE) << "Node " << curr_node->Name() << "(" << curr_node->OpType() << ") is **NOT** in "
<< "recompute op list, but its output [" << cur_output_arg_name
<< "] is used in backward, we don't need trace bottom-up further";
<< "recompute op list, but its output [" << cur_output_arg_name << "] is used in "
<< "backward, we don't need trace bottom-up further. Entry node: "
<< entry_node.Name() << "(" << entry_node.OpType() << ")";
continue;
} else {
early_stop = true;
LOGS(logger, VERBOSE) << "Node " << curr_node->Name() << "(" << curr_node->OpType() << ") is **NOT** in "
<< "recompute op list, and its output [" << cur_output_arg_name
<< "] does not exist in backward, search terminates.";
<< "] does not exist in backward, search terminates. Entry node: "
<< entry_node.Name() << "(" << entry_node.OpType() << ")";
break;
}
}
@ -579,7 +581,8 @@ Status MemoryOptimizer::SelectRecomputeSubgraph(const Node& node,
if (fw_op_output_arg_used_map.at(cur_output_arg_name).second) {
LOGS(logger, VERBOSE) << "Node " << curr_node->Name() << "(" << curr_node->OpType() << ") "
<< "is in recompute op list, while its output [" << cur_output_arg_name
<< "] is used in backward, we don't need trace bottom-up further";
<< "] is used in backward, we don't need trace bottom-up further. Entry node: "
<< entry_node.Name() << "(" << entry_node.OpType() << ")";
continue;
}
}

View file

@ -67,8 +67,8 @@ class MemoryOptimizer : public GraphTransformer {
int total_frequency{0}; // The occurrence of this subgraph pattern in the graph.
int applied_count{0}; // The number of times this subgraph pattern has been really applied in this transformer.
int skip_count{0}; // The number of times this subgraph instances will skipped in reversed topological order.
float saving_ratio{1.0f};
int skip_count{0}; // The number of times this subgraph instance has been skipped in reversed topological order.
float saving_ratio{1.0f}; // For compromised memory saving, the ratio of memory saving.
};
/**