mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
8 commits
| Author | SHA1 | Message | Date | |
|---|---|---|---|---|
|
|
8a98874e7e
|
Flash attention recompute (#20603)
### Flash attn recompute 1. Allow PythonOp(FlashAttn) can be recomputed correctly. |
||
|
|
280b2634c5
|
Prompt layer-wise recompute when applicable (#20126)
### Prompt layer-wise when applicable Give explicit prompts in export failures to users to enable layer-wise memory optimization if we found the checkpoint function is used. - Using checkpoint function is a strong indicator that the model is too large to fit in GPU memory. - If we don't override the checkpoint function here, mostly ONNX export will be failed. 1. For old version PyTorch, when handling gradient checkpoint feature, we just throw an exception. 2. For new version PyTorch, an export failure happens. - But both failures did not give users explicitly "HOW" to mitigate. This PR did that. ``  ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> |
||
|
|
d102569755
|
Fix seed for recomputed Dropout (#19715)
### Fix seed for recomputed Dropout If Dropout node is recomputed in the backward, we should make sure its execution is same as the run in the forward. If we don't set seed attribute, then this cannot be guaranteed. Add ` export ORTMODULE_MEMORY_OPT_LEVEL=2` to enabled per layer recompute with compromised recomputable subgraphs. |
||
|
|
ccf3b2054b
|
Allow layer-wise recompute (#18566)
### Allow layer-wise recompute Early, we need users/developers to specify the subgraphs to recompute, now we introduced a more user-friendly way to enable recompute for all detected stashed activation recomputation subgraphs. This scarifies getting the best configs while makes it easier to support user requirements when they switches from PyTorch per-layer gradient checkpoint to ORTModule. `ORTMODULE_MEMORY_OPT_LEVEL` is introduced to control the usage, by default, it is 0, e.g. `USER_SPECIFIED`, all subgraphs definedin `ORTMODULE_MEMORY_OPT_CONFIG` will be recomputed. So this is compatible to existing recompute usage in ORTModule integrated models. Using `ORTMODULE_MEMORY_OPT_LEVEL=1`, we will enable all recompute plans detected, so those configs in `ORTMODULE_MEMORY_OPT_CONFIG` will not be respected any more. Add Unit Tests using 3 layer blooms. https://github.com/microsoft/onnxruntime/blob/pengwa/add_aggresive_recompute/docs/Memory_Optimizer.md |
||
|
|
43a5147e01
|
Memory optimization refactor and refinement (#17481)
### Memory optimization refactor and refinement Currently memory optimizer runs graph transformations and print recompute opportunities in INFO level, while ORT backend has many many INFO level logs making users hard to find those information. So we are looking for a Python binding API to retrieve the memory optimization opportunities instead of depending on the MemoryOptimizer's default logging. Then we can print ORTModule feature statistics using this information. Also, with such an API, we can create an ORT session created, where allocation plan is done, the analysis will consider buffer reuse as well. This can void giving some recomputation subgraphs that are reusing other subgraphs' output buffers. Check https://github.com/microsoft/onnxruntime/blob/pengwa/add_devinfo_level/docs/Memory_Optimizer.md for the new flow using `MemoryOptimizer`. This pull requests made following refactoring: 1. Print the log in ORTModule Python script, along with ORTModule feature enabling stats. This is implemented by exposing an API `get_serialized_ortmodule_memory_stat` to retrieve the memory optimization opportunities. 2. We are analyzing memory optimization opportunities considering ORT memory planning. This is done by firstly creating the execution graph without enabling MemoryOptimizer, then we call `execution_agent.get_serialized_ortmodule_memory_stat` which internally will consider the session memory allocation planner when analyzing memory optimization opportunity. As a direct result, the memory optimization opportunities can show those stashed activations that are reusing other buffers. 3. Move recompute analysis logic from memory_optimizer.h/cc to recompute_analysis.h/cc. 4. Abstract optimization strategies for their own implementation. This will make introducing new strategies (for example compression and decompression ) easier. New logging matrix (INFO Level), in WARNING level, the details will NOT show. ``` 2023-09-13 13:25:09,249 orttraining.rank-0 [WARNING] - ***** ONNX Runtime Training (ORTModule) is accelerating your model ***** ORTModule is enabled with following features ON/OFF for [training] mode: ATen Executor : ON : Dispatch ATen operators to ORT's ATen executor Cast Propagation : ON : Level 1 enabled Custom Function : ON : Support custom torch.autograd.Function export and execution Memory Optimizer : ON : RecomputeConfig: Reshape+Where+BiasSoftmax+:1:-1,Cast+:1:-1, ProbeLevel: 1, available configs: Config Freq Saving(B) Saving Symbolic(Bytes) - Plan 1 : ON : Reshape+Where+BiasSoftmax+:1:-1 5 671,088,640 640.0*inputs_input_ids_dim0*inputs_input_ids_dim1**2 - Plan 2 : ON : Cast+:1:-1 6 402,587,648 inputs_input_ids_dim0*inputs_input_ids_dim1*(384.0*inputs_input_ids_dim1 - 64.0) - Plan 3 : OFF : Reshape+Where+:1:-1 1 134,217,728 128.0*inputs_input_ids_dim0*inputs_input_ids_dim1**2 - Plan 4 : OFF : BiasSoftmax+:1:-1 1 134,086,656 128.0*inputs_input_ids_dim0*inputs_input_ids_dim1*(inputs_input_ids_dim1 - 1) - Plan 5 : OFF : BiasGelu+:1:-1 6 125,808,640 inputs_input_ids_dim0*(122880.0*inputs_input_ids_dim1 - 20480.0) - Plan 6 : OFF : FusedMatMul+:1:-1 6 125,808,640 inputs_input_ids_dim0*(122880.0*inputs_input_ids_dim1 - 20480.0) - Plan 7 : OFF : FusedMatMul+Add+FusedMatMul+Add+Add+Add+:1:-1 5 26,214,400 25600.0*inputs_input_ids_dim0*inputs_input_ids_dim1 - Plan 8 : OFF : Add+:1:-1 1 5,237,760 5120.0*inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1) - Plan 9 : OFF : Reshape+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Cast+:1:-1 1 4,096 4.0*inputs_input_ids_dim0*inputs_input_ids_dim1 - Plan 10 : OFF : Cast+:2:-1 1 2,048 2.0*inputs_input_ids_dim0*inputs_input_ids_dim1 Compute Optimizer : ON : Enable/Disable with env ORTMODULE_ENABLE_COMPUTE_OPTIMIZER=1/0 - FLOPReduction : ON : Reduce FLOPs by upstreaming shrinking-sized ops Auto Fallback : ON : Fallback to PyTorch when encountering unsupported ops TritonOp Enabled : OFF : ORT will switch to Triton for executing some ops to further accelerate training. ZeRO Stage3 Support : OFF : Enable/Disable with env ORTMODULE_ENABLE_ZERO_STAGE3=1/0 Total ORT initialization overhead is 10.73s where export takes 8.39s. Other overhead details: graph builder init takes 0.06s, runtime detection takes 0.01s, graph building takes 0.31s, session creation takes 1.96s Versions: ONNX Runtime - 1.16.0+cu118, ONNX - 1.11.0 Note 1: use comma to enable multiple plans at the same time. export ORTMODULE_MEMORY_OPT_CONFIG=<plan1 config>,<plan2 config>,... Note 2: saving is calculated based on the 1st batch symbolic dim values: inputs_input_ids_dim0=1, inputs_input_ids_dim1=1024, inputs_attention_mask_dim0=1, inputs_attention_mask_dim1=1024, inputs_labels_dim0=1, inputs_labels_dim1=1024, ************************************************************************ ``` If DEVINFO level is enabled, then more details about the memory optimizations are printed. ``` MemoryInsight Summary - User config: BiasGelu+:1:-1,Cast+:2:-1 ========================================================================================================================================== |Freq | Memory Optimization Opportunities (Clustered by node-level activation patterns) | |_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |3 |For each row options are mutually exclusive, only one of them can be enabled. | | | | | |>>Option 1 : Recompute subgraph FusedMatMul+Add+Reshape+ | | | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=FusedMatMul+Add+Reshape+:1:-1 | | | Stashed Activations: | | | - ReuseFreq : Output 0(3), | | | - Output 0 : [inputs_input_ids_dim0 x inputs_input_ids_dim1 x 32 x 240 x ], byte/elem: 2, 100% saved | |_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |2 |For each row options are mutually exclusive, only one of them can be enabled. | | | | | |>>Option 1 : Recompute subgraph Reshape+ | | | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=Reshape+:1:-1 | | | Stashed Activations: | | | - ReuseFreq : Output 0(2), | | | - Output 0 : [ x 2560 x ], byte/elem: 2, 100% saved | |_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |2 |For each row options are mutually exclusive, only one of them can be enabled. | | | | | |>>Option 1 : Recompute subgraph FusedMatMul+ | | | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=FusedMatMul+:1:-1 | | | Stashed Activations: | | | - Output 0 : [inputs_input_ids_dim0 x inputs_input_ids_dim1 x 10240 x ], byte/elem: 2, 100% saved | |_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |2 |For each row options are mutually exclusive, only one of them can be enabled. | | | | | |>>Option 1 : Recompute subgraph Cast+ | | | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=Cast+:1:-1 | | | Stashed Activations: | | | - Output 0 : [inputs_input_ids_dim0 x 32 x inputs_input_ids_dim1 x inputs_input_ids_dim1 x ], byte/elem: 2, 100% saved | |_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |2 |For each row options are mutually exclusive, only one of them can be enabled. | | | | | |>>Option 1 : Recompute subgraph Reshape+Where+BiasSoftmax+ | | | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=Reshape+Where+BiasSoftmax+:1:-1 | | | Stashed Activations: | | | - Output 0 : [inputs_input_ids_dim0 x 32 x inputs_input_ids_dim1 x inputs_input_ids_dim1 x ], byte/elem: 4, 100% saved | |_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |2 |For each row options are mutually exclusive, only one of them can be enabled. | | | | | |>>Option 1 : Recompute subgraph BiasGelu+ | | | Status : Enabled, requested count=-1, actual applied count=2 | | | Stashed Activations: | | | - Output 0 : [inputs_input_ids_dim0 x inputs_input_ids_dim1 x 10240 x ], byte/elem: 2, 100% saved | |_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |2 |For each row options are mutually exclusive, only one of them can be enabled. | | | | | |>>Option 1 : Recompute subgraph FusedMatMul+Add+FusedMatMul+Add+Add+Add+ | | | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=FusedMatMul+Add+FusedMatMul+Add+Add+Add+:1:-1 | | | Stashed Activations: | | | - Output 0 : [inputs_input_ids_dim0 x inputs_input_ids_dim1 x 2560 x ], byte/elem: 2, 100% saved | |_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |1 |For each row options are mutually exclusive, only one of them can be enabled. | | | | | |>>Option 1 : Recompute subgraph Reshape+Where+ | | | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=Reshape+Where+:1:-1 | | | Stashed Activations: | | | - Output 0 : [inputs_input_ids_dim0 x 32 x inputs_input_ids_dim1 x inputs_input_ids_dim1 x ], byte/elem: 4, 100% saved | |_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |1 |For each row options are mutually exclusive, only one of them can be enabled. | | | | | |>>Option 1 : Recompute subgraph FusedMatMul+ | | | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=FusedMatMul+:1:-1 | | | Stashed Activations: | | | - Output 0 : [inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1) x 10240 x ], byte/elem: 2, 100% saved | |_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |1 |For each row options are mutually exclusive, only one of them can be enabled. | | | | | |>>Option 1 : Recompute subgraph Cast+ | | | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=Cast+:1:-1 | | | Stashed Activations: | | | - Output 0 : [inputs_input_ids_dim0 x 32 x inputs_input_ids_dim1 - 1 x inputs_input_ids_dim1 x ], byte/elem: 2, 100% saved | |_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |1 |For each row options are mutually exclusive, only one of them can be enabled. | | | | | |>>Option 1 : Recompute subgraph Reshape+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Cast+ | | | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=Reshape+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Cast+:1:-1 | | | Stashed Activations: | | | - Output 0 : [inputs_input_ids_dim0 x 1 x 1 x inputs_input_ids_dim1 x ], byte/elem: 4, 100% saved | | | | | |>>Option 2 : RecomputeWithCompromise subgraph Cast+ | | | Status : Enabled, requested count=-1, actual applied count=1 | | | Stashed Activations: | | | - Output 0 : [inputs_input_ids_dim0 x 1 x 1 x inputs_input_ids_dim1 x ], byte/elem: 4, 50% saved | |_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |1 |For each row options are mutually exclusive, only one of them can be enabled. | | | | | |>>Option 1 : Recompute subgraph BiasSoftmax+ | | | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=BiasSoftmax+:1:-1 | | | Stashed Activations: | | | - Output 0 : [inputs_input_ids_dim0 x 32 x inputs_input_ids_dim1 - 1 x inputs_input_ids_dim1 x ], byte/elem: 4, 100% saved | |_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |1 |For each row options are mutually exclusive, only one of them can be enabled. | | | | | |>>Option 1 : Recompute subgraph BiasGelu+ | | | Status : Enabled, requested count=-1, actual applied count=1 | | | Stashed Activations: | | | - Output 0 : [inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1) x 10240 x ], byte/elem: 2, 100% saved | |_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |1 |For each row options are mutually exclusive, only one of them can be enabled. | | | | | |>>Option 1 : Recompute subgraph Add+ | | | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=Add+:1:-1 | | | Stashed Activations: | | | - Output 0 : [inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1) x 2560 x ], byte/elem: 2, 100% saved | |_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | ========================================================================================================================================== Note: use comma as a separator for enabling more than one subgraphs. ************************************************************************ ``` ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> |
||
|
|
2151c79bf1
|
Tune ORTModule logging experience a bit (#18298)
### Tune logging experience a bit
After last time we update the ORTModule log experience, we found few
issues:
1. `INFO` level output too many things, including PyTorch exporter
verbose logs (tracing graphs) on every ranks. On this level, we only
want to
- Output a little bit more information to Users than `WARNING` level,
for example the memory recomputation recommendations or other
not-fully-ready features.
- Output a little bit more information for a quick diagnostic, collected
on rank-0 only.
2. ONNX Runtime logging filter during graph build, session init
sometimes will hide the issues (for example segement fault), there is no
useful information in `WARNING`/`INFO` for users to report to us. This
is not good!
3. Some of our devs like using `pdb` to debug Python code, but if we add
`import pdb; pdb.set_trace()` in models' code might hang when they use
`INFO` or `WARNING`, where exporter happens and all output got
redirected due to log filtering. The only workaround is to switch to
VERBOSE, which output toooooooooooo many logs.
The corresponding changes proposed here are:
1. For `INFO` logging,
- We only logs rank-0.
- We restricted the ORT backend logging level to be WARNING in this
case, because ORT backend code output way too many logs that should be
under verbose, while we cannot guarantee we can get them cleaned up
immediately once they are added.
- We output the PyTorch exporter verbose log (including tracing graph),
which is useful for a quick diagnostic when an issue happens.
2. Remove all logging filtering on ORT backend, then the segment fault
issue details will not be hidden once it happens again.
3. Introduced a `DEVINFO` logging,
- Log logs on all ranks
- Log ORT backend logging level INFO
- PyTorch exporter logging filtering are all turned OFF (to unblock the
pdb debugging).
4. Currently, to use Memory Optimizer, need use DEVINFO (which will
output ORT backend INFO log). So update memory optimizer document to
reflect this. https://github.com/microsoft/onnxruntime/pull/17481 will
update the requirement back to INFO for show memory optimization infos.
You can check
https://github.com/microsoft/onnxruntime/blob/pengwa/devinfo_level/docs/ORTModule_Training_Guidelines.md#log-level-explanations
for a better view of different log levels.
This PR also extract some changes from a bigger one
https://github.com/microsoft/onnxruntime/pull/17481, to reduce its
complexity for review.
### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
---------
Co-authored-by: mindest <30493312+mindest@users.noreply.github.com>
|
||
|
|
ab9ac2acc4
|
Add guidelines for ORTModule (#13553)
### Add guidelines for ORTModule As title. Feel free to let me know if I missed something. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> |
||
|
|
a3e7da60e7
|
Trade subgraph recompute for memory (#12852)
**Description**: Subgraph-level recompute This PR adds an optional capability trading additional re-computation for better memory efficiency. Specifically, a pre-defined operator list used to iterate the Graph to find some subgraphs for recompute, to reduce some stashed activations whose lifetime across forward and backward pass. When training with ORTModule, by default, the graph transformer will scan the execution graph to find all eligible subgraph to recompute, along with sizes that can save. An example looks like below. If we want to enable some of them to recompute, we can define env variable this way: `export ORTMODULE_ENABLE_MEMORY_ALLEVIATION="Mul+FusedMatMul+Cast+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Add+BiasSoftmaxDropout+Cast+:1:-1,BiasGelu+:1:-1,BitmaskDropout+Cast+:1:-1,FusedMatMul+:1:-1,Cast+:1:-1,Mul+Add+:1:-1,Mul+Sub+:1:-1"` ``` [1,0]<stderr>:2,022-10-12 14:47:39.302,954,530 [W:onnxruntime:, memory_alleviation.cc:595 PrintSummary] [1,0]<stderr>:MemoryAlleviation Summary: [1,0]<stderr>: User config: [1,0]<stderr>: Mul+FusedMatMul+Cast+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Add+BiasSoftmaxDropout+Cast+:1,BiasGelu+:1,BitmaskDropout+Cast+:1,FusedMatMul+:1,Cast+:1,Mul+Add+:1,Mul+Sub+:1 [1,0]<stderr>: ================================= [1,0]<stderr>: Subgraph: BitmaskDropout+ [1,0]<stderr>: AlleviationType: Disabled [1,0]<stderr>: Patterns: [1,0]<stderr>: PatternShape:input_ids_dim0 x 1,024 x Frequency:1 [1,0]<stderr>: -------------------------------- [1,0]<stderr>: Subgraph: BiasGelu+ [1,0]<stderr>: AlleviationType: Recompute [1,0]<stderr>: Patterns: [1,0]<stderr>: PatternShape:input_ids_dim0 x input_ids_dim1 x 4,096 x Frequency:24 [1,0]<stderr>: -------------------------------- [1,0]<stderr>: Subgraph: Reshape[1,0]<stderr>:+ [1,0]<stderr>: AlleviationType: Disabled [1,0]<stderr>: Patterns: [1,0]<stderr>: PatternShape:labels_dim0 x Frequency:1 [1,0]<stderr>: -------------------------------- [1,0]<stderr>: Subgraph: Unsqueeze+Unsqueeze+Cast+Sub+Mul+Mul+FusedMatMul+Cast+Add+BiasSoftmaxDropout+Cast+ [1,0]<stderr>: AlleviationType: Disabled [1,0]<stderr>: Patterns: [1,0]<stderr>: PatternShape:input_ids_dim0 x 16 x input_ids_dim1 x input_ids_dim1 x Frequency:23 [1,0]<stderr>: -------------------------------- [1,0]<stderr>: Subgraph: Mul+FusedMatMul+Cast+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Add+BiasSoftmaxDropout+Cast+ [1,0]<stderr>: AlleviationType: Recompute [1,0]<stderr>: Patterns: [1,0]<stderr>: PatternShape:input_ids_dim0 x 16 x input_ids_dim1 x input_ids_dim1 x Frequency:1 [1,0]<stderr>: -------------------------------- [1,0]<stderr>: Subgraph: Mul+Add+ [1,0]<stderr>: AlleviationType: Recompute [1,0]<stderr>: Patterns: [1,0]<stderr>: PatternShape:input_ids_dim0 x 16 x input_ids_dim1 x 1 x Frequency:24 [1,0]<stderr>: -------------------------------- [1,0]<stderr>: Subgraph: FusedMatMul+Cast+Add+Reshape+Cast+ [1,0]<stderr>: AlleviationType: Disabled [1,0]<stderr>: Patterns: [1,0]<stderr>: PatternShape:input_ids_dim0 x 16 x input_ids_dim1 x 2 x 4 x Frequency:24 [1,0]<stderr>: -------------------------------- [1,0]<stderr>: Subgraph: Mul+Sub+ [1,0]<stderr>: AlleviationType: Recompute [1,0]<stderr>: Patterns: [1,0]<stderr>: PatternShape:input_ids_dim0 x 16 x input_ids_dim1 x 1 x Frequency:24 [1,0]<stderr>: -------------------------------- [1,0]<stderr>: Subgraph: Cast+ [1,0]<stderr>: AlleviationType: Recompute [1,0]<stderr>: Patterns: [1,0]<stderr>: PatternShape:1,024 x 1,024 x Frequency:97 [1,0]<stderr>: PatternShape:3 x 1,024 x Frequency:1 [1,0]<stderr>: PatternShape:8 x 64 x Frequency:24 [1,0]<stderr>: PatternShape:1,024 x 4,096 x Frequency:24 [1,0]<stderr>: PatternShape:4,096 x Frequency:24 [1,0]<stderr>: PatternShape:4,096 x 1,024 x Frequency:24 [1,0]<stderr>: -------------------------------- [1,0]<stderr>: Subgraph: FusedMatMul+ [1,0]<stderr>: AlleviationType: Recompute [1,0]<stderr>: Patterns: [1,0]<stderr>: PatternShape:input_ids_dim0 x input_ids_dim1 x 4,096 x Frequency:24 [1,0]<stderr>: -------------------------------- [1,0]<stderr>: ================================= ``` "Type config:" whether recompute is enabled by users. 0 - disable, 1- enable. "Subgraph" means what kind of subgraph will be recomputed, in this case, it is a single node "Gelu", and it will be "Recompute". "Shape && Frequency" means, for this recompute, one tensor of size (batch size, 500) will be saved because it will be recomputed. **Baseline** On a 1P model (DEBERTA V2), sequence length 256, training with 16 A100 GPUs. With latest main branch, we can run batch size 16, and the maximum batch size < 32. So 16 is usually chosen by data scientists. 65% of 40GB memory is used during training. The SamplesPerSec=479.2543353561354.  **With this PR** Gelu is recomputed for saving memory peak, batch size 32 can be run. The 97% of 40GB A100 is used, the SamplesPerSec=562.041593991271 (**1.17X** of baseline).  **Motivation and Context** - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. |