mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-17 21:10:43 +00:00
3 commits
| Author | SHA1 | Message | Date | |
|---|---|---|---|---|
|
|
2151c79bf1
|
Tune ORTModule logging experience a bit (#18298)
### Tune logging experience a bit
After last time we update the ORTModule log experience, we found few
issues:
1. `INFO` level output too many things, including PyTorch exporter
verbose logs (tracing graphs) on every ranks. On this level, we only
want to
- Output a little bit more information to Users than `WARNING` level,
for example the memory recomputation recommendations or other
not-fully-ready features.
- Output a little bit more information for a quick diagnostic, collected
on rank-0 only.
2. ONNX Runtime logging filter during graph build, session init
sometimes will hide the issues (for example segement fault), there is no
useful information in `WARNING`/`INFO` for users to report to us. This
is not good!
3. Some of our devs like using `pdb` to debug Python code, but if we add
`import pdb; pdb.set_trace()` in models' code might hang when they use
`INFO` or `WARNING`, where exporter happens and all output got
redirected due to log filtering. The only workaround is to switch to
VERBOSE, which output toooooooooooo many logs.
The corresponding changes proposed here are:
1. For `INFO` logging,
- We only logs rank-0.
- We restricted the ORT backend logging level to be WARNING in this
case, because ORT backend code output way too many logs that should be
under verbose, while we cannot guarantee we can get them cleaned up
immediately once they are added.
- We output the PyTorch exporter verbose log (including tracing graph),
which is useful for a quick diagnostic when an issue happens.
2. Remove all logging filtering on ORT backend, then the segment fault
issue details will not be hidden once it happens again.
3. Introduced a `DEVINFO` logging,
- Log logs on all ranks
- Log ORT backend logging level INFO
- PyTorch exporter logging filtering are all turned OFF (to unblock the
pdb debugging).
4. Currently, to use Memory Optimizer, need use DEVINFO (which will
output ORT backend INFO log). So update memory optimizer document to
reflect this. https://github.com/microsoft/onnxruntime/pull/17481 will
update the requirement back to INFO for show memory optimization infos.
You can check
https://github.com/microsoft/onnxruntime/blob/pengwa/devinfo_level/docs/ORTModule_Training_Guidelines.md#log-level-explanations
for a better view of different log levels.
This PR also extract some changes from a bigger one
https://github.com/microsoft/onnxruntime/pull/17481, to reduce its
complexity for review.
### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
---------
Co-authored-by: mindest <30493312+mindest@users.noreply.github.com>
|
||
|
|
ab9ac2acc4
|
Add guidelines for ORTModule (#13553)
### Add guidelines for ORTModule As title. Feel free to let me know if I missed something. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> |
||
|
|
a3e7da60e7
|
Trade subgraph recompute for memory (#12852)
**Description**: Subgraph-level recompute This PR adds an optional capability trading additional re-computation for better memory efficiency. Specifically, a pre-defined operator list used to iterate the Graph to find some subgraphs for recompute, to reduce some stashed activations whose lifetime across forward and backward pass. When training with ORTModule, by default, the graph transformer will scan the execution graph to find all eligible subgraph to recompute, along with sizes that can save. An example looks like below. If we want to enable some of them to recompute, we can define env variable this way: `export ORTMODULE_ENABLE_MEMORY_ALLEVIATION="Mul+FusedMatMul+Cast+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Add+BiasSoftmaxDropout+Cast+:1:-1,BiasGelu+:1:-1,BitmaskDropout+Cast+:1:-1,FusedMatMul+:1:-1,Cast+:1:-1,Mul+Add+:1:-1,Mul+Sub+:1:-1"` ``` [1,0]<stderr>:2,022-10-12 14:47:39.302,954,530 [W:onnxruntime:, memory_alleviation.cc:595 PrintSummary] [1,0]<stderr>:MemoryAlleviation Summary: [1,0]<stderr>: User config: [1,0]<stderr>: Mul+FusedMatMul+Cast+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Add+BiasSoftmaxDropout+Cast+:1,BiasGelu+:1,BitmaskDropout+Cast+:1,FusedMatMul+:1,Cast+:1,Mul+Add+:1,Mul+Sub+:1 [1,0]<stderr>: ================================= [1,0]<stderr>: Subgraph: BitmaskDropout+ [1,0]<stderr>: AlleviationType: Disabled [1,0]<stderr>: Patterns: [1,0]<stderr>: PatternShape:input_ids_dim0 x 1,024 x Frequency:1 [1,0]<stderr>: -------------------------------- [1,0]<stderr>: Subgraph: BiasGelu+ [1,0]<stderr>: AlleviationType: Recompute [1,0]<stderr>: Patterns: [1,0]<stderr>: PatternShape:input_ids_dim0 x input_ids_dim1 x 4,096 x Frequency:24 [1,0]<stderr>: -------------------------------- [1,0]<stderr>: Subgraph: Reshape[1,0]<stderr>:+ [1,0]<stderr>: AlleviationType: Disabled [1,0]<stderr>: Patterns: [1,0]<stderr>: PatternShape:labels_dim0 x Frequency:1 [1,0]<stderr>: -------------------------------- [1,0]<stderr>: Subgraph: Unsqueeze+Unsqueeze+Cast+Sub+Mul+Mul+FusedMatMul+Cast+Add+BiasSoftmaxDropout+Cast+ [1,0]<stderr>: AlleviationType: Disabled [1,0]<stderr>: Patterns: [1,0]<stderr>: PatternShape:input_ids_dim0 x 16 x input_ids_dim1 x input_ids_dim1 x Frequency:23 [1,0]<stderr>: -------------------------------- [1,0]<stderr>: Subgraph: Mul+FusedMatMul+Cast+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Add+BiasSoftmaxDropout+Cast+ [1,0]<stderr>: AlleviationType: Recompute [1,0]<stderr>: Patterns: [1,0]<stderr>: PatternShape:input_ids_dim0 x 16 x input_ids_dim1 x input_ids_dim1 x Frequency:1 [1,0]<stderr>: -------------------------------- [1,0]<stderr>: Subgraph: Mul+Add+ [1,0]<stderr>: AlleviationType: Recompute [1,0]<stderr>: Patterns: [1,0]<stderr>: PatternShape:input_ids_dim0 x 16 x input_ids_dim1 x 1 x Frequency:24 [1,0]<stderr>: -------------------------------- [1,0]<stderr>: Subgraph: FusedMatMul+Cast+Add+Reshape+Cast+ [1,0]<stderr>: AlleviationType: Disabled [1,0]<stderr>: Patterns: [1,0]<stderr>: PatternShape:input_ids_dim0 x 16 x input_ids_dim1 x 2 x 4 x Frequency:24 [1,0]<stderr>: -------------------------------- [1,0]<stderr>: Subgraph: Mul+Sub+ [1,0]<stderr>: AlleviationType: Recompute [1,0]<stderr>: Patterns: [1,0]<stderr>: PatternShape:input_ids_dim0 x 16 x input_ids_dim1 x 1 x Frequency:24 [1,0]<stderr>: -------------------------------- [1,0]<stderr>: Subgraph: Cast+ [1,0]<stderr>: AlleviationType: Recompute [1,0]<stderr>: Patterns: [1,0]<stderr>: PatternShape:1,024 x 1,024 x Frequency:97 [1,0]<stderr>: PatternShape:3 x 1,024 x Frequency:1 [1,0]<stderr>: PatternShape:8 x 64 x Frequency:24 [1,0]<stderr>: PatternShape:1,024 x 4,096 x Frequency:24 [1,0]<stderr>: PatternShape:4,096 x Frequency:24 [1,0]<stderr>: PatternShape:4,096 x 1,024 x Frequency:24 [1,0]<stderr>: -------------------------------- [1,0]<stderr>: Subgraph: FusedMatMul+ [1,0]<stderr>: AlleviationType: Recompute [1,0]<stderr>: Patterns: [1,0]<stderr>: PatternShape:input_ids_dim0 x input_ids_dim1 x 4,096 x Frequency:24 [1,0]<stderr>: -------------------------------- [1,0]<stderr>: ================================= ``` "Type config:" whether recompute is enabled by users. 0 - disable, 1- enable. "Subgraph" means what kind of subgraph will be recomputed, in this case, it is a single node "Gelu", and it will be "Recompute". "Shape && Frequency" means, for this recompute, one tensor of size (batch size, 500) will be saved because it will be recomputed. **Baseline** On a 1P model (DEBERTA V2), sequence length 256, training with 16 A100 GPUs. With latest main branch, we can run batch size 16, and the maximum batch size < 32. So 16 is usually chosen by data scientists. 65% of 40GB memory is used during training. The SamplesPerSec=479.2543353561354.  **With this PR** Gelu is recomputed for saving memory peak, batch size 32 can be run. The 97% of 40GB A100 is used, the SamplesPerSec=562.041593991271 (**1.17X** of baseline).  **Motivation and Context** - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. |