> It is strongly recommended to wrap the model with `ORTModule` before applying other module wrappers (for example, DeepSpeed or `torch.nn.parallel.DistributedDataParallel`); this ordering has been validated in more scenarios.
> Note that `ORTModule` is **NOT** compatible with `torch.nn.DataParallel` (PyTorch itself recommends `DistributedDataParallel` over it). Please use `torch.nn.parallel.DistributedDataParallel` instead.
```diff
+ from onnxruntime.training.ortmodule import ORTModule, DebugOptions, LogLevel
+ model = ORTModule(model, DebugOptions(save_onnx=True, log_level=LogLevel.VERBOSE, onnx_prefix="model_name"))
```
Check [DebugOptions implementation](../orttraining/orttraining/python/training/ortmodule/debug_options.py) for more details.
### 2.1 Environment Variables
`ORTModule` provides environment variables targeting different use cases.
#### ORTMODULE_ONNX_OPSET_VERSION
- **Feature Area**: *ORTMODULE/ONNXOPSET*
- **Description**: By default, the ONNX opset version used by `ORTModule` is updated periodically as new ONNX Runtime versions are released. Customers who want to stay on a fixed opset, where both performance and accuracy have already been validated, can use this environment variable to pin it.
```bash
export ORTMODULE_ONNX_OPSET_VERSION=14
```
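When the training script itself should pin the opset, the same setting can be applied from Python. The following is a minimal sketch: `pin_onnx_opset` is a hypothetical helper, not part of ONNX Runtime. It sets the variable before `ORTModule` is imported, on the assumption that the value may be read at import time.

```python
import os

OPSET_ENV_VAR = "ORTMODULE_ONNX_OPSET_VERSION"


def pin_onnx_opset(version: int) -> None:
    """Hypothetical helper: Python equivalent of `export ORTMODULE_ONNX_OPSET_VERSION=<version>`."""
    # Reject anything that is not a positive integer before it reaches ORTModule.
    if isinstance(version, bool) or not isinstance(version, int) or version <= 0:
        raise ValueError(f"opset version must be a positive int, got {version!r}")
    os.environ[OPSET_ENV_VAR] = str(version)


# Set the variable first, in case ORTModule reads it when the module is imported.
pin_onnx_opset(14)
# from onnxruntime.training.ortmodule import ORTModule
```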
#### ORTMODULE_FALLBACK_POLICY
- **Feature Area**: *ORTMODULE/FallbackToPytorch*
- **Description**: By default, if `ORTModule` fails to run the model using the ONNX Runtime backend, it falls back to PyTorch to continue training. When developers are optimizing models and running benchmarks, however, they may want to force the model to run on the ORT backend. This environment variable can be used to disable the fallback.
- **Description**: Configures the `ORTModule` log level. Defaults to `LogLevel.WARNING`; can be set to one of "VERBOSE", "INFO", "WARNING", "ERROR", or "FATAL". If `DebugOptions` also sets `log_level`, the environment variable takes precedence.
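The precedence rule above (environment variable first, then `DebugOptions`, then the `WARNING` default) can be sketched as follows. This is an illustration under assumptions, not ORTModule's code. `LogLevel` here is a hypothetical stand-in enum that reuses only the level names listed above. `resolve_log_level` is also hypothetical; it takes the raw environment-variable string as an argument because the variable's name is not shown in this section.

```python
from enum import Enum
from typing import Optional


class LogLevel(Enum):
    """Hypothetical stand-in for ORTModule's LogLevel; only the member names come from the text."""
    VERBOSE = "VERBOSE"
    INFO = "INFO"
    WARNING = "WARNING"
    ERROR = "ERROR"
    FATAL = "FATAL"


def resolve_log_level(env_value: Optional[str],
                      debug_options_level: Optional[LogLevel]) -> LogLevel:
    """Pick the effective level: environment variable > DebugOptions > default."""
    # An unset (None) or empty environment variable does not override anything.
    if env_value:
        name = env_value.strip()
        if name not in LogLevel.__members__:
            raise ValueError(f"unknown log level {name!r}; expected one of {list(LogLevel.__members__)}")
        return LogLevel[name]
    if debug_options_level is not None:
        return debug_options_level
    # Documented default when neither source sets a level.
    return LogLevel.WARNING
```

For example, `resolve_log_level("INFO", LogLevel.VERBOSE)` yields `LogLevel.INFO`, because the environment variable wins over `DebugOptions`.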
#### ORTMODULE_SAVE_ONNX_PATH
- **Feature Area**: *ORTMODULE/DebugOptions*
- **Description**: Configures where `ORTModule` saves ONNX models. Saving ONNX models is disabled by default (it can be enabled with `DebugOptions(save_onnx=True)`).
When saving is enabled, the ONNX models are written to the current working directory by default. To change the output directory, set the environment variable `ORTMODULE_SAVE_ONNX_PATH` to the destination directory path.
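A minimal sketch of preparing the output directory from Python follows. `configure_onnx_output_dir` is a hypothetical helper. Creating the directory ahead of time is a precaution assumed here, not a documented requirement. Saving still has to be switched on through `DebugOptions(save_onnx=True)`, as noted in the trailing comment.

```python
import os
import tempfile
from pathlib import Path


def configure_onnx_output_dir(path: str) -> Path:
    """Hypothetical helper: create `path` if needed and point ORTMODULE_SAVE_ONNX_PATH at it."""
    out_dir = Path(path).resolve()
    # Creating the directory up front is a precaution, not a documented requirement.
    out_dir.mkdir(parents=True, exist_ok=True)
    os.environ["ORTMODULE_SAVE_ONNX_PATH"] = str(out_dir)
    return out_dir


if __name__ == "__main__":
    # Demo with a throwaway directory; a real run would use a persistent path.
    with tempfile.TemporaryDirectory() as tmp:
        print(configure_onnx_output_dir(os.path.join(tmp, "onnx_dumps")))
        # Saving itself must still be enabled, e.g.:
        # model = ORTModule(model, DebugOptions(save_onnx=True, onnx_prefix="model_name"))
```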
#### ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT
- **Description**: By default, `ORTModule` raises an exception when exporting certain `autograd.Function`s as PythonOp (one example is PyTorch's `CheckpointFunction`). Set this environment variable to `1` to explicitly allow them.
```bash
export ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT=1
```
> Take `torch.utils.checkpoint.CheckpointFunction` as an example: if it is exported as PythonOp, the checkpointed computation may be run by PyTorch rather than ORT. This matters most for large models such as GPT-2, where every few layers are wrapped for recomputation, so a large share of the computation ends up running in PyTorch. A failure is currently reported to warn users that `ORTModule` may have fewer opportunities to optimize further.
> On the other hand, if the wrapped computation graph is small, it is reasonable to allow it.
> Overall, users should be aware that the ORT performance gain may be negligible when they explicitly allow it.
- **Description**: By default, all `torch.autograd.Function` classes are exported as ORT PythonOp. In some cases you may want to disable this. For example, if you have confirmed that the computations defined by those `torch.autograd.Function` classes can be exported inline by PyTorch, and that it is safe to train with the inline-exported ONNX graph, disabling it gives ORT more opportunities to optimize.
- **Description**: By default, this is empty. When a user model depends on libraries that define multiple `torch.autograd.Function` classes with the same name but different Python import paths (namespaces), the ORT backend cannot infer which one to call, due to a limitation of the PyTorch exporter (https://github.com/microsoft/onnx-converters-private/issues/115). In this case, an exception is thrown.
Until the fully qualified name can be obtained from the exporter, this environment variable can be used to specify which `torch.autograd.Function` classes to ignore. Note that fully qualified names are required. To ignore multiple classes, separate them with commas.
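Building the comma-separated value by hand is error-prone, so a small helper can derive it from the classes themselves. This is a sketch under assumptions: it takes "fully qualified name" to mean the defining module path plus the class's `__qualname__`, which is an assumption and not confirmed by this guide. `fully_qualified_name` and `ignore_list_value` are hypothetical helpers. The environment variable's name is not shown in this section, so the sketch only builds the value string. Plain stand-in classes replace `torch.autograd.Function` subclasses so that it runs without PyTorch.

```python
from typing import Iterable


def fully_qualified_name(cls: type) -> str:
    """Return '<module>.<qualname>', e.g. 'my_pkg.layers.Outer.ScaleFunction'."""
    return f"{cls.__module__}.{cls.__qualname__}"


def ignore_list_value(classes: Iterable[type]) -> str:
    """Join fully qualified names with commas, the separator described above."""
    return ",".join(fully_qualified_name(c) for c in classes)


class Outer:
    class ScaleFunction:  # stand-in for a torch.autograd.Function subclass
        pass


class ClipFunction:  # another stand-in, defined at module level
    pass


# When run as a script, the module part is "__main__".
print(ignore_list_value([Outer.ScaleFunction, ClipFunction]))
```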
Q: *The model training hits OOM, even with minimum required batch size?*
Check [Memory Optimizer for ONNX Runtime Training](Memory_Optimizer.md) for how to leverage ORT's recomputation techniques.
## 3. Use `FusedAdam` to Accelerate Parameter Update
Parameter updates are performed by optimizers (for example, AdamW) and consist of many elementwise operations. `FusedAdam` launches its elementwise update kernels with multi-tensor apply, so each kernel launch applies a batch of gradients to their corresponding parameters.
Here is a sample switch from PyTorch's `AdamW` optimizer to `FusedAdam`.
```diff
model = build_model()
- optimizer = AdamW(model.parameters(), lr=1)
+ from onnxruntime.training.optim import FusedAdam
+ optimizer = FusedAdam(model.parameters(), lr=1)
```
Check [FusedAdam implementation](../orttraining/orttraining/python/training/optim/fused_adam.py) for more details.
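The batching idea behind multi-tensor apply can be illustrated without a GPU. The following is a pure-Python sketch, not `FusedAdam`'s implementation. `adamw_step_multi_tensor` is a hypothetical function that applies one standard AdamW step (decoupled weight decay, bias-corrected moments) to many parameter "tensors" (plain lists of floats). It processes `chunk_size` tensors per simulated launch and returns the number of launches, which shows how grouping reduces launch count without changing the math.

```python
import math
from typing import List

Tensor = List[float]  # a flat "tensor" is just a list of floats in this sketch


def adamw_step_multi_tensor(params: List[Tensor], grads: List[Tensor],
                            exp_avg: List[Tensor], exp_avg_sq: List[Tensor],
                            step: int, lr: float = 1e-3, beta1: float = 0.9,
                            beta2: float = 0.999, eps: float = 1e-8,
                            weight_decay: float = 0.01, chunk_size: int = 4) -> int:
    """Apply one AdamW step in place, `chunk_size` tensors per simulated launch.

    Returns the number of simulated kernel launches.
    """
    if step < 1 or chunk_size < 1:
        raise ValueError("step and chunk_size must be >= 1")
    bias_c1 = 1.0 - beta1 ** step
    bias_c2 = 1.0 - beta2 ** step
    launches = 0
    for start in range(0, len(params), chunk_size):
        launches += 1  # one "launch" updates a whole batch of tensors
        for t in range(start, min(start + chunk_size, len(params))):
            p, g, m, v = params[t], grads[t], exp_avg[t], exp_avg_sq[t]
            for i in range(len(p)):
                p[i] *= 1.0 - lr * weight_decay                    # decoupled weight decay
                m[i] = beta1 * m[i] + (1.0 - beta1) * g[i]         # first moment
                v[i] = beta2 * v[i] + (1.0 - beta2) * g[i] * g[i]  # second moment
                m_hat = m[i] / bias_c1
                v_hat = v[i] / bias_c2
                p[i] -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return launches
```

With five parameter tensors, `chunk_size=2` needs three launches instead of five, and the updated parameters are identical. On a GPU, the saving comes from amortizing per-launch overhead; in this Python sketch, the launch count is only bookkeeping.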
## 4. Use `FP16_Optimizer` to Complement DeepSpeed/APEX
If user models use the DeepSpeed or Apex libraries, ORT's `FP16_Optimizer` can be used to address some of the inefficiencies they introduce.
Use `FP16_Optimizer` with the DeepSpeed ZeRO optimizer:
```diff
+ from onnxruntime.training.optim.fp16_optimizer import FP16_Optimizer
+ optimizer = FP16_Optimizer(optimizer)
```
## 6. One More Thing - `LoadBalancingDistributedBatchSampler`
`LoadBalancingDistributedBatchSampler` balances the data load across workers based on the sample's complexity.
This is useful in scenarios like speech and NLP, where samples have variable lengths and distributed training suffers from the **straggler problem**. In such scenarios, the complexity function can be defined to return the length of the input sample sequence. Usage is similar to `torch.utils.data.DistributedSampler`: each process loads a subset of the original dataset that is exclusive to it.
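To make "balancing by complexity" concrete, here is a generic greedy sketch. It places the most complex samples first, always onto the currently least-loaded worker (longest-processing-time-first). This is a swapped-in textbook technique for illustration only. `LoadBalancingDistributedSampler` may use a different strategy internally, and `balance_by_complexity` is a hypothetical function. As suggested above, samples are sequences and the complexity function is their length.

```python
import heapq
from typing import Callable, List, Sequence, TypeVar

T = TypeVar("T")


def balance_by_complexity(samples: Sequence[T], num_workers: int,
                          complexity: Callable[[T], float]) -> List[List[int]]:
    """Split sample indices into `num_workers` exclusive shards with similar total complexity."""
    if num_workers < 1:
        raise ValueError("num_workers must be >= 1")
    costs = [complexity(s) for s in samples]
    # Place the most expensive samples first (greedy longest-processing-time-first).
    order = sorted(range(len(samples)), key=lambda i: costs[i], reverse=True)
    heap = [(0.0, w) for w in range(num_workers)]  # (current load, worker id); already a valid heap
    shards: List[List[int]] = [[] for _ in range(num_workers)]
    for i in order:
        load, w = heapq.heappop(heap)  # least-loaded worker so far
        shards[w].append(i)
        heapq.heappush(heap, (load + costs[i], w))
    return shards


# Sequences of different lengths; complexity is the sequence length.
samples = [[0] * n for n in (9, 8, 7, 3, 2, 1)]
shards = balance_by_complexity(samples, num_workers=2, complexity=len)
print(shards)  # [[0, 3, 4, 5], [1, 2]]
```

Both shards end up with a total length of 15, whereas a naive split of the first three and last three samples would give 24 versus 6 and leave one worker straggling.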
A usage sample is shown below:
```python
from onnxruntime.training.utils.data import LoadBalancingDistributedSampler, \