Trade subgraph recompute for memory (#12852)
**Description**: Subgraph-level recompute
This PR adds an optional capability that trades additional re-computation
for better memory efficiency. Specifically, a pre-defined operator list is
used when iterating the graph to find subgraphs eligible for recompute,
reducing the number of stashed activations whose lifetime spans the
forward and backward passes.
When training with ORTModule, by default, the graph transformer will
scan the execution graph to find all eligible subgraphs to recompute,
along with the sizes that can be saved. An example is shown below.
To enable some of them for recompute, we can define the environment
variable like this:
`export ORTMODULE_ENABLE_MEMORY_ALLEVIATION="Mul+FusedMatMul+Cast+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Add+BiasSoftmaxDropout+Cast+:1:-1,BiasGelu+:1:-1,BitmaskDropout+Cast+:1:-1,FusedMatMul+:1:-1,Cast+:1:-1,Mul+Add+:1:-1,Mul+Sub+:1:-1"`
```
[1,0]<stderr>:2022-10-12 14:47:39.302954530 [W:onnxruntime:, memory_alleviation.cc:595 PrintSummary]
[1,0]<stderr>:MemoryAlleviation Summary:
[1,0]<stderr>: User config:
[1,0]<stderr>: Mul+FusedMatMul+Cast+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Add+BiasSoftmaxDropout+Cast+:1,BiasGelu+:1,BitmaskDropout+Cast+:1,FusedMatMul+:1,Cast+:1,Mul+Add+:1,Mul+Sub+:1
[1,0]<stderr>: =================================
[1,0]<stderr>: Subgraph: BitmaskDropout+
[1,0]<stderr>: AlleviationType: Disabled
[1,0]<stderr>: Patterns:
[1,0]<stderr>: PatternShape:input_ids_dim0 x 1024 x Frequency:1
[1,0]<stderr>: --------------------------------
[1,0]<stderr>: Subgraph: BiasGelu+
[1,0]<stderr>: AlleviationType: Recompute
[1,0]<stderr>: Patterns:
[1,0]<stderr>: PatternShape:input_ids_dim0 x input_ids_dim1 x 4096 x Frequency:24
[1,0]<stderr>: --------------------------------
[1,0]<stderr>: Subgraph: Reshape+
[1,0]<stderr>: AlleviationType: Disabled
[1,0]<stderr>: Patterns:
[1,0]<stderr>: PatternShape:labels_dim0 x Frequency:1
[1,0]<stderr>: --------------------------------
[1,0]<stderr>: Subgraph: Unsqueeze+Unsqueeze+Cast+Sub+Mul+Mul+FusedMatMul+Cast+Add+BiasSoftmaxDropout+Cast+
[1,0]<stderr>: AlleviationType: Disabled
[1,0]<stderr>: Patterns:
[1,0]<stderr>: PatternShape:input_ids_dim0 x 16 x input_ids_dim1 x input_ids_dim1 x Frequency:23
[1,0]<stderr>: --------------------------------
[1,0]<stderr>: Subgraph: Mul+FusedMatMul+Cast+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Add+BiasSoftmaxDropout+Cast+
[1,0]<stderr>: AlleviationType: Recompute
[1,0]<stderr>: Patterns:
[1,0]<stderr>: PatternShape:input_ids_dim0 x 16 x input_ids_dim1 x input_ids_dim1 x Frequency:1
[1,0]<stderr>: --------------------------------
[1,0]<stderr>: Subgraph: Mul+Add+
[1,0]<stderr>: AlleviationType: Recompute
[1,0]<stderr>: Patterns:
[1,0]<stderr>: PatternShape:input_ids_dim0 x 16 x input_ids_dim1 x 1 x Frequency:24
[1,0]<stderr>: --------------------------------
[1,0]<stderr>: Subgraph: FusedMatMul+Cast+Add+Reshape+Cast+
[1,0]<stderr>: AlleviationType: Disabled
[1,0]<stderr>: Patterns:
[1,0]<stderr>: PatternShape:input_ids_dim0 x 16 x input_ids_dim1 x 2 x 4 x Frequency:24
[1,0]<stderr>: --------------------------------
[1,0]<stderr>: Subgraph: Mul+Sub+
[1,0]<stderr>: AlleviationType: Recompute
[1,0]<stderr>: Patterns:
[1,0]<stderr>: PatternShape:input_ids_dim0 x 16 x input_ids_dim1 x 1 x Frequency:24
[1,0]<stderr>: --------------------------------
[1,0]<stderr>: Subgraph: Cast+
[1,0]<stderr>: AlleviationType: Recompute
[1,0]<stderr>: Patterns:
[1,0]<stderr>: PatternShape:1024 x 1024 x Frequency:97
[1,0]<stderr>: PatternShape:3 x 1024 x Frequency:1
[1,0]<stderr>: PatternShape:8 x 64 x Frequency:24
[1,0]<stderr>: PatternShape:1024 x 4096 x Frequency:24
[1,0]<stderr>: PatternShape:4096 x Frequency:24
[1,0]<stderr>: PatternShape:4096 x 1024 x Frequency:24
[1,0]<stderr>: --------------------------------
[1,0]<stderr>: Subgraph: FusedMatMul+
[1,0]<stderr>: AlleviationType: Recompute
[1,0]<stderr>: Patterns:
[1,0]<stderr>: PatternShape:input_ids_dim0 x input_ids_dim1 x 4096 x Frequency:24
[1,0]<stderr>: --------------------------------
[1,0]<stderr>: =================================
```
"User config:" shows, for each subgraph pattern, whether recompute is
enabled by the user: 0 - disabled, 1 - enabled.
"Subgraph" means what kind of subgraph will be recomputed; in the simplest
case it is a single node such as "Gelu", whose AlleviationType is then "Recompute".
"Shape && Frequency" means that, for this recompute, tensors of the given
shape (for example, one tensor of size (batch size, 500)) will be saved,
because they will be recomputed instead of stashed; Frequency counts how
often the pattern occurs.
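The user-config string follows a comma-separated "pattern:enable[:count]" format. The sketch below is illustrative only, not the actual onnxruntime parser; in particular, the meaning of the trailing `-1` is assumed here to be "apply to every occurrence of the pattern":

```python
def parse_alleviation_config(config: str):
    """Parse comma-separated "pattern:enable[:count]" entries.

    Sketch of the assumed ORTMODULE_ENABLE_MEMORY_ALLEVIATION format;
    not the real onnxruntime implementation.
    """
    entries = {}
    for item in config.split(","):
        parts = item.split(":")
        pattern = parts[0]                        # e.g. "BiasGelu+"
        enabled = parts[1] == "1"                 # 0 - disabled, 1 - enabled
        # Trailing field (e.g. -1) assumed to mean "all occurrences".
        count = int(parts[2]) if len(parts) > 2 else -1
        entries[pattern] = (enabled, count)
    return entries

cfg = parse_alleviation_config("BiasGelu+:1:-1,BitmaskDropout+Cast+:0:-1")
print(cfg["BiasGelu+"])  # (True, -1)
```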
**Baseline**
On a 1P model (DeBERTa V2), sequence length 256, training with 16 A100
GPUs: with the latest main branch we can run batch size 16, but the
maximum batch size is below 32, so data scientists usually choose 16. 65%
of the 40 GB memory is used during training, and SamplesPerSec = 479.25.

**With this PR**
With Gelu recomputed to lower the memory peak, batch size 32 can be run.
97% of the 40 GB A100 memory is used, and SamplesPerSec = 562.04
(**1.17X** of baseline).
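The reported speedup can be sanity-checked directly from the two SamplesPerSec numbers:

```python
baseline = 479.25   # SamplesPerSec, batch size 16, main branch
with_pr = 562.04    # SamplesPerSec, batch size 32 with recompute
speedup = with_pr / baseline
print(f"{speedup:.2f}X")  # 1.17X
```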

2022-11-03 05:49:41 +00:00
# Memory Optimizer for ONNX Runtime Training
## Introduction
2022-11-04 11:42:10 +00:00
ONNX Runtime Training provides a capability that trades node/subgraph re-computation for better memory efficiency.
Specifically, a list of re-computable operators is pre-defined, with which the memory optimizer graph transformer iterates the graph to find all re-computable subgraph candidates.
When training with `ORTModule`, by default, the graph transformer will scan the execution graph to find all eligible subgraphs to recompute, along with the sizes that can be saved. Users can then enable selected subgraphs via environment variables.
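For example, to enable recompute only for the `BiasGelu+` and `FusedMatMul+` subgraphs reported by the transformer, the variable can be set before `ORTModule` builds its graph (the pattern names below are taken from the sample summary in this document; substitute the ones reported for your own model):

```python
import os

# Comma-separated "pattern:enable:count" entries; 1 enables recompute,
# and -1 is assumed to apply it to every occurrence of the pattern.
os.environ["ORTMODULE_ENABLE_MEMORY_ALLEVIATION"] = (
    "BiasGelu+:1:-1,FusedMatMul+:1:-1"
)
```

This must take effect before the graph transformation runs, e.g. at process start or via `export` in the launcher script.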
## When can the memory optimizer help?
Classical scenarios include:
- `ORTModule` runs a model with batch size B (for example, 2^N); memory bandwidth and compute are not fully saturated, but running a bigger batch size (for example, 2^(N+1)) hits OOM.
- For big models, `ORTModule` fails to run even the minimum allowed batch size, so performance must be compromised to achieve a successful run.
Not all models and recipes need this optimizer technique. Imagine your training recipe uses batch size 6 (GPU compute and memory are fully saturated) and you don't need to bump it to 8 to maintain a fixed global batch size. Enabling recompute to run batch size 8 may not bring better throughput than the original batch size 6.
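A back-of-envelope check makes this concrete. With purely hypothetical numbers (not measurements), suppose batch size 6 finishes a step in 100 ms, while batch size 8 with recompute needs 140 ms because of the extra forward re-computation:

```python
# Hypothetical step times -- illustrative only, not measured values.
samples_per_sec_b6 = 6 / 0.100   # 60.0 samples/sec, no recompute
samples_per_sec_b8 = 8 / 0.140   # ~57.1 samples/sec, with recompute
print(samples_per_sec_b8 > samples_per_sec_b6)  # False: recompute loses here
```

In such a case, staying at the smaller batch size without recompute is the better choice.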
## Quick trial
1. Make sure the ONNX Runtime training wheel is installed and correctly configured.
2022-11-04 11:42:10 +00:00
2. Integrate the model with `ORTModule`; note that `log_level` should be INFO or lower.
```
ort_model = ORTModule(pt_model, DebugOptions(log_level=LogLevel.INFO))
```
3. Run the training as usual, redirecting all output into a log file; then stop it after training a few steps.
4. Check the log file and search for "Summary"; you will likely find something like this:
```
MemoryOptimizer Summary:
User config:
=================================
########Recompute########
Subgraph: CumSum+Sub+Mul+Unsqueeze+Cast+Mul+Cast+Reshape+Mul+FusedMatMul+Add+Reshape+Cast+Where+Softmax+
OptimizationType: Disabled
Patterns:
PatternShape:input_ids_dim0 x 16 x input_ids_dim1 x input_ids_dim1 x Frequency:23
--------------------------------
Subgraph: FastGelu+
OptimizationType: Disabled
Patterns:
PatternShape:input_ids_dim0 x input_ids_dim1 x 4096 x Frequency:24
=================================
########RecomputeWithCompromise########
Subgraph: Cast+Where+Softmax+
OptimizationType: Disabled
Patterns:
PatternShape:input_ids_dim0 x 16 x input_ids_dim1 x input_ids_dim1 x Frequency:24
--------------------------------
=================================
```
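For large models the summary can contain many entries. As an illustrative, unsupported sketch, a few lines of Python can collect each subgraph's current optimization status from the log text, keyed to the `Subgraph:` and `OptimizationType:` line prefixes shown above:

```python
def parse_summary(log_text):
    """Collect (subgraph, optimization type) pairs from a MemoryOptimizer summary."""
    entries = []
    current = None
    for line in log_text.splitlines():
        line = line.strip()
        if line.startswith("Subgraph:"):
            current = line.split(":", 1)[1].strip()
        elif line.startswith("OptimizationType:") and current is not None:
            entries.append((current, line.split(":", 1)[1].strip()))
            current = None
    return entries

summary = """\
Subgraph: FastGelu+
OptimizationType: Disabled
Patterns:
PatternShape:input_ids_dim0 x input_ids_dim1 x 4096 x Frequency:24
"""
print(parse_summary(summary))  # [('FastGelu+', 'Disabled')]
```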
5. As shown above, each 'Subgraph' entry shows 1) a string representation of a re-computable subgraph; and 2) the current status of memory optimization. In this case, recompute is disabled for all of them.
6. Set the environment variable `ORTMODULE_MEMORY_OPT_CONFIG` to enable recompute for some of the subgraphs. In the example below, the first 12 FastGelu-related subgraph occurrences are allowed to recompute.
`FastGelu+` is the subgraph string representation; the `1` in the middle indicates that 'Recompute' is enabled (`0`, on the contrary, indicates it is disabled); `12` means the first 12 subgraph occurrences will be recomputed while all others are left as-is; filling in `-1` makes all occurrences recomputed.
```
export ORTMODULE_MEMORY_OPT_CONFIG="FastGelu+:1:12"
```
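The config value is a comma-separated list of `subgraph:enable:count` entries. As a sketch, a small helper can make the three fields explicit when building configs for several subgraphs; note `make_config` is a hypothetical illustration, not part of ONNX Runtime:

```python
def make_config(*entries):
    """Build an ORTMODULE_MEMORY_OPT_CONFIG value from (subgraph, enable, count) tuples.

    enable: 1 turns recompute on for the subgraph, 0 leaves it off.
    count: number of occurrences to recompute; -1 means all occurrences.
    """
    return ",".join(f"{subgraph}:{int(enable)}:{count}"
                    for subgraph, enable, count in entries)

# Recompute the first 12 FastGelu+ occurrences and every BiasGelu+ occurrence.
print(make_config(("FastGelu+", 1, 12), ("BiasGelu+", 1, -1)))
# FastGelu+:1:12,BiasGelu+:1:-1
```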
7. Then run the training again; you will see logs like this:
```
MemoryOptimizer Summary:
User config:
**FastGelu+:1:12**
=================================
########Recompute########
Subgraph: CumSum+Sub+Mul+Unsqueeze+Cast+Mul+Cast+Reshape+Mul+FusedMatMul+Add+Reshape+Cast+Where+Softmax+
OptimizationType: Disabled
Patterns:
PatternShape:input_ids_dim0 x 16 x input_ids_dim1 x input_ids_dim1 x Frequency:23
--------------------------------
Subgraph: FastGelu+
OptimizationType: **Recompute (requested_count=12, actual applied_count=12)**
Patterns:
PatternShape:input_ids_dim0 x input_ids_dim1 x 4096 x Frequency:24
=================================
########RecomputeWithCompromise########
Subgraph: Cast+Where+Softmax+
OptimizationType: Disabled
Patterns:
PatternShape:input_ids_dim0 x 16 x input_ids_dim1 x input_ids_dim1 x Frequency:24
--------------------------------
=================================
```
8. You may need to iterate on steps 6 and 7 a few times until you find a good config that lets the model run a bigger batch size. You may also fail to find one, if memory optimization does not apply well to the model.
## Compromised Recompute
If you check the logs above, there is a separate section called "RecomputeWithCompromise". Recomputing the subgraphs under it usually saves only part of each activation (for example, half of it), not all of it. Enable it the same way.
## Notes
This feature is in an experimental stage; we will tune and refine it based on real use cases.