### Tune the logging experience a bit
After the last update to the ORTModule logging experience, we found a few
issues:
1. The `INFO` level outputs too many things, including PyTorch exporter
verbose logs (traced graphs) on every rank. At this level, we only
want to:
- Give users a bit more information than the `WARNING` level, for example
memory recompute recommendations or other not-fully-ready features.
- Collect a bit more information for quick diagnostics, on rank 0 only.
2. The ONNX Runtime logging filter applied during graph build and session
initialization can hide issues (for example a segmentation fault), leaving
no useful information at `WARNING`/`INFO` for users to report to us. This
is not good!
3. Some of our devs like using `pdb` to debug Python code, but adding
`import pdb; pdb.set_trace()` in the model's code can hang at `INFO` or
`WARNING`, because that is where the export happens and all output gets
redirected by the log filtering. The only workaround is to switch to
`VERBOSE`, which outputs far too many logs.
The corresponding changes proposed here are:
1. For `INFO` logging,
- We only log on rank 0.
- We restrict the ORT backend logging level to `WARNING` in this case,
because the ORT backend emits far too many logs that should be verbose,
and we cannot guarantee they get cleaned up immediately once they are
added.
- We output the PyTorch exporter verbose log (including the traced graph),
which is useful for a quick diagnostic when an issue happens.
2. Remove all log filtering on the ORT backend, so segmentation-fault
details will no longer be hidden if one happens again.
3. Introduce a `DEVINFO` logging level, which
- logs on all ranks,
- sets the ORT backend logging level to `INFO`, and
- turns all PyTorch exporter log filtering OFF (to unblock `pdb`
debugging).
4. Currently, using the Memory Optimizer requires `DEVINFO` (which
outputs the ORT backend `INFO` log), so the memory optimizer document is
updated to reflect this. https://github.com/microsoft/onnxruntime/pull/17481 will
move the requirement back to `INFO` for showing memory optimization info.
See
https://github.com/microsoft/onnxruntime/blob/pengwa/devinfo_level/docs/ORTModule_Training_Guidelines.md#log-level-explanations
for a better view of the different log levels; a small usage sketch follows below.
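As a minimal sketch of how a level is selected (assuming the standard `onnxruntime.training.ortmodule` import path; the `torch.nn.Linear` model is just a placeholder for any existing `torch.nn.Module`):

```python
import torch
from onnxruntime.training.ortmodule import ORTModule, DebugOptions, LogLevel

pt_model = torch.nn.Linear(4, 4)  # placeholder for any torch.nn.Module

# LogLevel.INFO: rank-0 only, ORT backend capped at WARNING,
#                exporter verbose log included.
# LogLevel.DEVINFO: all ranks, ORT backend at INFO,
#                   exporter log filtering fully OFF.
ort_model = ORTModule(pt_model, DebugOptions(log_level=LogLevel.DEVINFO))
```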
This PR also extracts some changes from a bigger one,
https://github.com/microsoft/onnxruntime/pull/17481, to reduce its
review complexity.
---------
Co-authored-by: mindest <30493312+mindest@users.noreply.github.com>
# Memory Optimizer for ONNX Runtime Training
## Introduction
ONNX Runtime Training provides a capability that trades node/subgraph re-computation for better memory efficiency. Specifically, a list of re-computable operators is pre-defined, and the memory optimizer graph transformer iterates the graph to find all re-computable subgraph candidates.
When training with ORTModule, by default, the graph transformer scans the execution graph for all eligible subgraphs to recompute, along with the sizes that can be saved. Users can then enable some of the subgraphs via environment variables.
## When can the memory optimizer help?
Classical scenarios include:
- `ORTModule` runs a model with batch size B (for example 2^N); memory bandwidth and compute are not fully saturated, but it hits OOM when running a bigger batch size (for example 2^(N+1)).
- For big models, `ORTModule` fails to run even the minimum allowed batch size, so performance can be compromised for a successful run.
Not all models and recipes need this optimization technique. Imagine your training recipe uses batch size 6 (GPU compute and memory fully saturated) and you don't need to bump it to 8 to maintain a fixed global batch size; enabling recompute at batch size 8 may not bring better throughput than the original batch size 6.
## Quick trial
- Make sure the ONNX Runtime training wheel is installed and correctly configured.
- Integrate your model with `ORTModule`; note that `log_level` should be equal to or lower than `DEVINFO`:
  ```python
  ort_model = ORTModule(pt_model, DebugOptions(log_level=LogLevel.DEVINFO))
  ```
- Run the training as usual and redirect all output into a log file; stop it after training a few steps.
- Check the log file and search for "Summary"; you should find something like this:
  ```
  MemoryOptimizer Summary:
  User config:
  =================================
  ########Recompute########
  Subgraph: CumSum+Sub+Mul+Unsqueeze+Cast+Mul+Cast+Reshape+Mul+FusedMatMul+Add+Reshape+Cast+Where+Softmax+
     OptimizationType: Disabled
     Patterns:
        PatternShape:input_ids_dim0 x 16 x input_ids_dim1 x input_ids_dim1 x Frequency:23
  --------------------------------
  Subgraph: FastGelu+
     OptimizationType: Disabled
     Patterns:
        PatternShape:input_ids_dim0 x input_ids_dim1 x 4096 x Frequency:24
  =================================
  ########RecomputeWithCompromise########
  Subgraph: Cast+Where+Softmax+
     OptimizationType: Disabled
     Patterns:
        PatternShape:input_ids_dim0 x 16 x input_ids_dim1 x input_ids_dim1 x Frequency:24
  --------------------------------
  =================================
  ```
- As shown above, `Subgraph` shows 1) a string representation of a re-computable subgraph and 2) the current status of memory optimization. All are disabled for recompute in this case.
- Set the environment variable `ORTMODULE_MEMORY_OPT_CONFIG` to enable some of the subgraphs to recompute. In the example below, 12 FastGelu-related subgraphs are allowed to recompute. `FastGelu+` is the subgraph string representation; the `1` in the middle indicates 'Recompute' is enabled (`0`, on the contrary, indicates it is disabled); `12` means the first 12 subgraph occurrences will be recomputed and all others are left as they are; filling in `-1` makes all occurrences recompute. (A Python alternative for setting this variable is sketched after this list.)
  ```bash
  export ORTMODULE_MEMORY_OPT_CONFIG="FastGelu+:1:12"
  ```
- Then run the training again; you will see logs like this:
  ```
  MemoryOptimizer Summary:
  User config:
  **FastGelu+:1:12**
  =================================
  ########Recompute########
  Subgraph: CumSum+Sub+Mul+Unsqueeze+Cast+Mul+Cast+Reshape+Mul+FusedMatMul+Add+Reshape+Cast+Where+Softmax+
     OptimizationType: Disabled
     Patterns:
        PatternShape:input_ids_dim0 x 16 x input_ids_dim1 x input_ids_dim1 x Frequency:23
  --------------------------------
  Subgraph: FastGelu+
     OptimizationType: **Recompute (requested_count=12, actual applied_count=12)**
     Patterns:
        PatternShape:input_ids_dim0 x input_ids_dim1 x 4096 x Frequency:24
  =================================
  ########RecomputeWithCompromise########
  Subgraph: Cast+Where+Softmax+
     OptimizationType: Disabled
     Patterns:
        PatternShape:input_ids_dim0 x 16 x input_ids_dim1 x input_ids_dim1 x Frequency:24
  --------------------------------
  =================================
  ```
- You may need to iterate the last two steps a few times until you find a config that lets this model run a bigger batch size, or you may fail to find one if memory optimization does not apply well to the model.
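If you prefer setting the config from Python instead of the shell, a minimal sketch follows; it assumes the variable is read when `ORTModule` builds its training graph, so it must be set before the first training step:

```python
import os

# Same config as the shell `export` above; assumption: this runs before
# ORTModule builds the training graph (e.g. before the first forward pass).
os.environ["ORTMODULE_MEMORY_OPT_CONFIG"] = "FastGelu+:1:12"
```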
## Compromised Recompute
If you check the logs above, there is a separate section called "RecomputeWithCompromise". Recomputing the subgraphs under it usually saves part of the activations (for example half of them), not all of them. Follow the same approach to enable it, as in the sketch below.
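As a sketch, assuming the same `subgraph:enable:count` config format shown earlier also applies to this section, enabling the `Cast+Where+Softmax+` subgraph from the summary above for all occurrences could look like:

```python
import os

# Assumption: compromised-recompute subgraphs use the same config format;
# -1 requests recompute for every occurrence, as described earlier.
os.environ["ORTMODULE_MEMORY_OPT_CONFIG"] = "Cast+Where+Softmax+:1:-1"
```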
## Notes
This feature is in an experimental stage; we will tune and refine it according to real use cases.