onnxruntime/docs
pengwa a3e7da60e7
Trade subgraph recompute for memory (#12852)
**Description**: Subgraph-level recompute

This PR adds an optional capability that trades additional recomputation
for better memory efficiency. Specifically, a predefined operator list
is used to traverse the graph and find subgraphs that can be recomputed.
This reduces the stashed activations whose lifetimes span the forward
and backward passes.
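
To make the trade concrete, here is a toy sketch in plain Python. It is an illustration only, not ORT code. A layer computing `gelu(x) * w` either stashes the GELU output `h` for the backward pass, or keeps only `x` and recomputes `h` during backward. The class, its names, and the float counting are assumptions made for this example. ORT applies the same idea to subgraphs of the training graph, not to a Python class.

```python
import math
from typing import Dict, List, Tuple

C = math.sqrt(2.0 / math.pi)


def gelu(x: float) -> float:
    # tanh approximation of GELU; stands in for a recomputable subgraph such as BiasGelu+
    return 0.5 * x * (1.0 + math.tanh(C * (x + 0.044715 * x ** 3)))


def gelu_grad(x: float) -> float:
    # derivative of the tanh approximation above
    t = math.tanh(C * (x + 0.044715 * x ** 3))
    return 0.5 * (1.0 + t) + 0.5 * x * (1.0 - t * t) * C * (1.0 + 3 * 0.044715 * x * x)


class ToyLayer:
    """Computes z[i] = gelu(x[i]) * w with a scalar weight w (hypothetical toy)."""

    def __init__(self, w: float, recompute: bool) -> None:
        self.w = w
        self.recompute = recompute
        self.stash: Dict[str, List[float]] = {}  # tensors kept alive from forward to backward
        self.recompute_calls = 0

    def forward(self, x: List[float]) -> List[float]:
        h = [gelu(v) for v in x]
        self.stash["x"] = list(x)  # dL/dx needs gelu'(x), so x is stashed in both modes
        if not self.recompute:
            self.stash["h"] = h  # extra activation held across forward and backward
        return [v * self.w for v in h]

    def stashed_floats(self) -> int:
        # number of floats currently held for the backward pass
        return sum(len(t) for t in self.stash.values())

    def backward(self, grad_z: List[float]) -> Tuple[List[float], float]:
        x = self.stash["x"]
        if self.recompute:
            self.recompute_calls += 1
            h = [gelu(v) for v in x]  # rebuild h from x instead of reading a stashed copy
        else:
            h = self.stash["h"]
        grad_w = sum(g * v for g, v in zip(grad_z, h))
        grad_x = [g * self.w * gelu_grad(v) for g, v in zip(grad_z, x)]
        self.stash.clear()  # activations are released once backward is done
        return grad_x, grad_w


if __name__ == "__main__":
    x, grad_z = [-1.0, 0.5, 2.0], [1.0, 1.0, 1.0]
    for recompute in (False, True):
        layer = ToyLayer(w=0.3, recompute=recompute)
        layer.forward(x)
        held = layer.stashed_floats()
        grads = layer.backward(grad_z)
        print(f"recompute={recompute}: stashed floats={held}, grads={grads}")
```

Both modes produce identical gradients. Between forward and backward, the recompute mode holds 3 floats instead of 6, at the cost of one extra GELU evaluation per element.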

When training with ORTModule, the graph transformer by default scans the
execution graph to find all subgraphs eligible for recompute, along with
the memory each one could save. An example summary is shown below.
To enable recompute for some of these subgraphs, define an environment
variable like this:
`export ORTMODULE_ENABLE_MEMORY_ALLEVIATION="Mul+FusedMatMul+Cast+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Add+BiasSoftmaxDropout+Cast+:1:-1,BiasGelu+:1:-1,BitmaskDropout+Cast+:1:-1,FusedMatMul+:1:-1,Cast+:1:-1,Mul+Add+:1:-1,Mul+Sub+:1:-1"`
```

[1,0]<stderr>:2022-10-12 14:47:39.302954530 [W:onnxruntime:, memory_alleviation.cc:595 PrintSummary]
[1,0]<stderr>:MemoryAlleviation Summary:
[1,0]<stderr>:  User config:
[1,0]<stderr>:  Mul+FusedMatMul+Cast+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Add+BiasSoftmaxDropout+Cast+:1,BiasGelu+:1,BitmaskDropout+Cast+:1,FusedMatMul+:1,Cast+:1,Mul+Add+:1,Mul+Sub+:1
[1,0]<stderr>:  =================================
[1,0]<stderr>:  Subgraph: BitmaskDropout+
[1,0]<stderr>:          AlleviationType: Disabled
[1,0]<stderr>:          Patterns:
[1,0]<stderr>:                  PatternShape:input_ids_dim0 x 1,024 x   Frequency:1
[1,0]<stderr>:  --------------------------------
[1,0]<stderr>:  Subgraph: BiasGelu+
[1,0]<stderr>:          AlleviationType: Recompute
[1,0]<stderr>:          Patterns:
[1,0]<stderr>:                  PatternShape:input_ids_dim0 x input_ids_dim1 x 4,096 x  Frequency:24
[1,0]<stderr>:  --------------------------------
[1,0]<stderr>:  Subgraph: Reshape+
[1,0]<stderr>:          AlleviationType: Disabled
[1,0]<stderr>:          Patterns:
[1,0]<stderr>:                  PatternShape:labels_dim0 x      Frequency:1
[1,0]<stderr>:  --------------------------------
[1,0]<stderr>:  Subgraph: Unsqueeze+Unsqueeze+Cast+Sub+Mul+Mul+FusedMatMul+Cast+Add+BiasSoftmaxDropout+Cast+
[1,0]<stderr>:          AlleviationType: Disabled
[1,0]<stderr>:          Patterns:
[1,0]<stderr>:                  PatternShape:input_ids_dim0 x 16 x input_ids_dim1 x input_ids_dim1 x    Frequency:23
[1,0]<stderr>:  --------------------------------
[1,0]<stderr>:  Subgraph: Mul+FusedMatMul+Cast+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Add+BiasSoftmaxDropout+Cast+
[1,0]<stderr>:          AlleviationType: Recompute
[1,0]<stderr>:          Patterns:
[1,0]<stderr>:                  PatternShape:input_ids_dim0 x 16 x input_ids_dim1 x input_ids_dim1 x    Frequency:1
[1,0]<stderr>:  --------------------------------
[1,0]<stderr>:  Subgraph: Mul+Add+
[1,0]<stderr>:          AlleviationType: Recompute
[1,0]<stderr>:          Patterns:
[1,0]<stderr>:                  PatternShape:input_ids_dim0 x 16 x input_ids_dim1 x 1 x         Frequency:24
[1,0]<stderr>:  --------------------------------
[1,0]<stderr>:  Subgraph: FusedMatMul+Cast+Add+Reshape+Cast+
[1,0]<stderr>:          AlleviationType: Disabled
[1,0]<stderr>:          Patterns:
[1,0]<stderr>:                  PatternShape:input_ids_dim0 x 16 x input_ids_dim1 x 2 x 4 x     Frequency:24
[1,0]<stderr>:  --------------------------------
[1,0]<stderr>:  Subgraph: Mul+Sub+
[1,0]<stderr>:          AlleviationType: Recompute
[1,0]<stderr>:          Patterns:
[1,0]<stderr>:                  PatternShape:input_ids_dim0 x 16 x input_ids_dim1 x 1 x         Frequency:24
[1,0]<stderr>:  --------------------------------
[1,0]<stderr>:  Subgraph: Cast+
[1,0]<stderr>:          AlleviationType: Recompute
[1,0]<stderr>:          Patterns:
[1,0]<stderr>:                  PatternShape:1,024 x 1,024 x    Frequency:97
[1,0]<stderr>:                  PatternShape:3 x 1,024 x        Frequency:1
[1,0]<stderr>:                  PatternShape:8 x 64 x   Frequency:24
[1,0]<stderr>:                  PatternShape:1,024 x 4,096 x    Frequency:24
[1,0]<stderr>:                  PatternShape:4,096 x    Frequency:24
[1,0]<stderr>:                  PatternShape:4,096 x 1,024 x    Frequency:24
[1,0]<stderr>:  --------------------------------
[1,0]<stderr>:  Subgraph: FusedMatMul+
[1,0]<stderr>:          AlleviationType: Recompute
[1,0]<stderr>:          Patterns:
[1,0]<stderr>:                  PatternShape:input_ids_dim0 x input_ids_dim1 x 4,096 x  Frequency:24
[1,0]<stderr>:  --------------------------------
[1,0]<stderr>:  =================================
```
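
The `User config` line echoes the environment variable value. It is a comma-separated list of entries, each a subgraph string followed by `:<type>` (0 disables recompute, 1 enables it). In the `export` example, each entry also carries one more integer, `-1`. The helper below is a hypothetical sketch for sanity-checking such a value before launching training. It is not ORT's own parser, and it keeps the trailing integer without interpreting it, because its meaning is not described here.

```python
from typing import List, NamedTuple, Optional


class RecomputeRequest(NamedTuple):
    subgraph: str         # e.g. "BiasGelu+", as printed in the summary
    enabled: bool         # type field: "1" enables recompute, "0" disables it
    extra: Optional[int]  # optional trailing integer (e.g. -1); not interpreted here


def parse_user_config(value: str) -> List[RecomputeRequest]:
    """Hypothetical parser for a value such as "BiasGelu+:1:-1,Cast+:1:-1"."""
    requests: List[RecomputeRequest] = []
    for entry in value.split(","):
        entry = entry.strip()
        if not entry:
            continue  # tolerate a trailing comma
        fields = entry.split(":")
        if len(fields) not in (2, 3) or not fields[0]:
            raise ValueError(f"unexpected entry: {entry!r}")
        if fields[1] not in ("0", "1"):
            raise ValueError(f"type must be 0 or 1 in {entry!r}")
        extra = int(fields[2]) if len(fields) == 3 else None
        requests.append(RecomputeRequest(fields[0], fields[1] == "1", extra))
    return requests


if __name__ == "__main__":
    for req in parse_user_config("BiasGelu+:1:-1,Mul+Add+:1:-1,Cast+:0"):
        print(req)
```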


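"User config" shows whether the user has enabled recompute for each
subgraph: 0 means disabled, 1 means enabled.
"Subgraph" is the kind of subgraph that can be recomputed. For example,
`BiasGelu+` is a single BiasGelu node, and its "AlleviationType" is
"Recompute" because the user config enables it.
"PatternShape" and "Frequency" describe the stashed tensors that recompute
avoids keeping. For `BiasGelu+`, 24 tensors of shape (input_ids_dim0,
input_ids_dim1, 4096) no longer need to be stashed because they will be
recomputed.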
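
A PatternShape/Frequency entry can be turned into a rough byte count by multiplying the dimensions, the frequency, and the element size. The sketch below is illustrative only and makes three assumptions: `input_ids_dim0` is the per-GPU batch size (32), `input_ids_dim1` is the sequence length (256), and activations are fp16 (2 bytes). The summary itself does not state any of these bindings. The estimate also ignores any temporary memory used while recomputing.

```python
from typing import Dict, List


def parse_pattern_shape(text: str) -> List[str]:
    # "input_ids_dim0 x input_ids_dim1 x 4,096 x" -> ["input_ids_dim0", "input_ids_dim1", "4,096"]
    return [tok for tok in text.split() if tok != "x"]


def resolve_dim(dim: str, bindings: Dict[str, int]) -> int:
    token = dim.replace(",", "")  # the summary prints 4096 as "4,096"
    return int(token) if token.isdigit() else bindings[token]


def pattern_bytes(shape_text: str, frequency: int, bindings: Dict[str, int], bytes_per_element: int) -> int:
    """Bytes held by `frequency` stashed tensors of the given shape."""
    numel = 1
    for dim in parse_pattern_shape(shape_text):
        numel *= resolve_dim(dim, bindings)
    return numel * frequency * bytes_per_element


if __name__ == "__main__":
    # Hypothetical bindings: per-GPU batch 32, sequence length 256.
    bindings = {"input_ids_dim0": 32, "input_ids_dim1": 256}
    gelu_bytes = pattern_bytes("input_ids_dim0 x input_ids_dim1 x 4,096 x", 24, bindings, 2)
    print(f"BiasGelu+: {gelu_bytes / 2**30:.2f} GiB")  # 1.50 GiB under these assumptions
```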

**Baseline**

On a 1P model (DeBERTa V2) with sequence length 256, trained on 16 A100
GPUs, the latest main branch can run batch size 16, and the maximum
batch size is below 32, so data scientists usually choose 16. During
training, 65% of the 40 GB memory is used, and SamplesPerSec=479.2543353561354.


![image](https://user-images.githubusercontent.com/10530022/188320941-13dde5e7-c32b-4399-a64b-6803fbb9dcda.png)

**With this PR**

Recomputing Gelu lowers the peak memory, so batch size 32 can be run.
97% of the 40 GB A100 memory is used, and SamplesPerSec=562.041593991271
(**1.17X** the baseline).


![image](https://user-images.githubusercontent.com/10530022/188321081-f64811bf-9637-4873-8095-349de8d498cc.png)
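
The reported figures can be cross-checked with a few lines of arithmetic. The values below are copied from the two runs above, and the 40 GB capacity is the A100 variant named in the text. The `Run` record is just a convenience for this example.

```python
from dataclasses import dataclass

A100_MEMORY_GB = 40


@dataclass
class Run:
    batch_size: int
    samples_per_sec: float
    memory_fraction: float  # share of the 40 GB A100 memory in use


baseline = Run(batch_size=16, samples_per_sec=479.2543353561354, memory_fraction=0.65)
with_recompute = Run(batch_size=32, samples_per_sec=562.041593991271, memory_fraction=0.97)

speedup = with_recompute.samples_per_sec / baseline.samples_per_sec

if __name__ == "__main__":
    for name, run in (("baseline", baseline), ("with recompute", with_recompute)):
        used_gb = run.memory_fraction * A100_MEMORY_GB
        print(f"{name}: batch {run.batch_size}, ~{used_gb:.1f} GB used")
    print(f"throughput ratio: {speedup:.2f}X")  # prints "throughput ratio: 1.17X"
```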


2022-11-03 13:49:41 +08:00

| Name | Last commit | Date |
|---|---|---|
| c_cxx | Document C/C++ API documentation version info conventions. (#10396) | 2022-01-27 10:20:13 -08:00 |
| execution_providers/images | Remove docs that have been migrated to https://onnxruntime.ai/docs (#6225) | 2021-02-05 18:09:27 -08:00 |
| images | API Documentation (#8948) | 2021-09-09 22:04:51 -07:00 |
| python | Bumping up version number to 1.14.0 on main branch (#13401) | 2022-10-21 19:16:44 -04:00 |
| ABI_Dev_Notes.md | skip windows GPU check if changes only in doc (#13248) | 2022-10-11 13:51:44 +08:00 |
| Android_testing.md | Removed BUILD.md from master as source now lives in gh-pages (#6709) | 2021-02-19 11:34:21 -08:00 |
| C_API_Guidelines.md | Replace 'master' branch ref to 'main' in the code (#12547) | 2022-08-22 10:48:12 -07:00 |
| cmake_guideline.md | fix some typo in docs (#13212) | 2022-10-07 15:58:18 -07:00 |
| Coding_Conventions_and_Standards.md | Fixed a minor typo (#13194) | 2022-10-05 12:10:14 -07:00 |
| ContribOperators.md | QuickGelu Fusion (#12417) | 2022-10-28 18:12:07 +08:00 |
| FAQ.md | Fix typo enviroment => environment (#13195) | 2022-10-03 17:02:26 -07:00 |
| How_To_Update_ONNX_Dev_Notes.md | Update script to find optimizers that potentially need supported opset updates (#12330) | 2022-08-04 07:37:27 +10:00 |
| Memory_Optimizer.md | Trade subgraph recompute for memory (#12852) | 2022-11-03 13:49:41 +08:00 |
| Model_Test.md | Renaming MKL-DNN as DNNL (#2515) | 2019-12-03 07:34:23 -08:00 |
| NotesOnThreading.md | Replace 'master' branch ref to 'main' in the code (#12547) | 2022-08-22 10:48:12 -07:00 |
| ONNX_Runtime_Server_Usage.md | Update docs/ONNX_Runtime_Server_Usage.md (#7818) | 2021-05-26 16:17:20 -07:00 |
| onnxruntime_dependencies.dot | Update dependencies graph | 2020-04-17 07:38:45 -07:00 |
| onnxruntime_dependencies.png | Update dependencies graph | 2020-04-17 07:38:45 -07:00 |
| onnxruntime_extensions.md | replace 'master' branch ref to 'main' for onnx repo (#12678) | 2022-08-30 13:41:42 -07:00 |
| OperatorKernels.md | QuickGelu Fusion (#12417) | 2022-10-28 18:12:07 +08:00 |
| ORT_Format_Update_in_1.13.md | Update kernel matching logic: decouple from op schemas and remove kernel def hashes (#12791) | 2022-09-20 14:24:59 -07:00 |
| ORTMobilePackageOperatorTypeSupport.md | Replace 'master' branch ref to 'main' in the code (#12547) | 2022-08-22 10:48:12 -07:00 |
| PR_Guidelines.md | Add guidelines for writing a good PR. (#3830) | 2020-05-05 16:28:21 -07:00 |
| Privacy.md | [C# and Python APIs] Expose knobs to enable/disable platform telemetry collection (#5481) | 2020-10-21 10:32:13 -07:00 |
| Python_Dev_Notes.md | Changes related to the release binaries requiring Visual C++ 2019 runtime (#3871) | 2020-05-12 17:07:06 -07:00 |
| Reduced_Operator_Kernel_build.md | replace 'master' branch ref to 'main' for onnx repo (#12678) | 2022-08-30 13:41:42 -07:00 |
| ReleaseManagement.md | Updated TPN for OpenMPI and cleanup (#3932) | 2020-05-14 11:42:44 -07:00 |
| Roadmap.md | Replace 'master' branch ref to 'main' in the code (#12547) | 2022-08-22 10:48:12 -07:00 |
| Server.md | Update documentation for contributing a PR and add deprecation notices for PyOp and ORT server. (#6172) | 2020-12-18 02:00:42 -08:00 |
| TVM_EP.md | [C#][TVM EP] Fix issues related to using TVM EP in C# front-end (#12958) | 2022-09-16 16:04:59 +02:00 |
| Versioning.md | replace 'master' branch ref to 'main' for onnx repo (#12678) | 2022-08-30 13:41:42 -07:00 |
| WinML_principles.md | Replace 'master' branch ref to 'main' in the code (#12547) | 2022-08-22 10:48:12 -07:00 |