Commit graph

21 commits

Author SHA1 Message Date
pengwa
88336ffa92
Fix typos - 1st Wave (#21278)
### Description

There are so many typos reported by the review dog, [Optional Lint]
actions (example:
https://github.com/microsoft/onnxruntime/actions/runs/9864564489/job/27239732367),
this PR is to fix some of them.



### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->

---------

Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com>
2024-07-11 13:35:08 +08:00
pengwa
8a98874e7e
Flash attention recompute (#20603)
### Flash attn recompute

1. Allow PythonOp(FlashAttn) can be recomputed correctly.
45879ff5c2
2. Use JSON to pass the selected-to-recompute subgraphs.
3c374da678

#### Better Memory Efficiency 

Customer model can run both PyTorch SPDA and Flash Attn, this PR make it
possible to let the Flash Attn path work with ORTModule layerwise
recompute. The peak drop from 45.xGB to 32.xGB if we only compare the
layers (not including other pieces, BTW there are few more optimization
targeting other pieces as well later).

#### Better Perf

Using Flash ATTN bring additionally 16% end to end time reduction, with
highly aligned loss curve.


![image](https://github.com/microsoft/onnxruntime/assets/10530022/bb63894a-f281-49bc-a8e6-ff818439be38)

#### Use JSON File to pass Recompute Plans

To overcome the limitation of max length of the strings defined in
session options.

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
2024-05-21 13:38:19 +08:00
guyang3532
cfe830b248
Generalize label input sparsity check and refactor (#20636)
### Description
The InsertGatherBeforeSceLoss optimization is enabled when the density
of label padding less than 90%. We need to check the density of the
label padding to decide whether enable the optimization.

Before this pr, we just check the inputs of graph and correlate one with
the SCE node by iterate graph from the SCE node back to one graph input.
This is hard to be general because there may be complicated pattern
between graph input and SCE node.

This pr check padding density by the direct input of SCE module rather
than the input of graph at the first graph execution when exporting onnx
graph.
And if the density < 90%, insert a flag PythonOp after the SCE node as:
```
           SoftmaxCrossEntropy
		  |
            PythonOp (func_name: FlagAndPrintDensity)   (insert if density < 90%)
		  |
            Following graph
```

When the InsertGatherBeforeSceLoss is invoked, it check if there is the
flag PythonOp(func_name: FlagAndPrintDensity) after the SCE node and if
it is, remove it and do the padding elimination optimization.

If the env of ORTMODULE_PRINT_INPUT_DENSITY is 1, we will print input
density each step by the PythonOp (func_name: FlagAndPrintDensity). In
this case the PythonOp will not be removed.
2024-05-10 21:55:43 +08:00
Justin Chu
faea42af95
Bump ruff to 0.3.2 and black to 24 (#19878)
### Motivation and Context

Routing updates
2024-03-13 10:00:32 -07:00
pengwa
d102569755
Fix seed for recomputed Dropout (#19715)
### Fix seed for recomputed Dropout

If Dropout node is recomputed in the backward, we should make sure its
execution is same as the run in the forward.
If we don't set seed attribute, then this cannot be guaranteed. 

Add ` export ORTMODULE_MEMORY_OPT_LEVEL=2` to enabled per layer
recompute with compromised recomputable subgraphs.
2024-03-06 10:06:25 +08:00
guyang3532
cd56ea4a74
enable embedding sparse optimization by default (#19714) 2024-03-05 13:15:30 +08:00
jingyanwangms
319481898c
Give a triton library missing warning instead of silently turn off (#19276)
### Description
When USE_ORTMODULE_TRITON is set to 1 but there's no triton library,
triton function is silently turned off. This adds a warning
2024-02-01 15:25:33 -08:00
pengwa
1150b1f81e
ORTModule memory improvement (#18924)
## Dependency

https://github.com/microsoft/onnxruntime/pull/19007

## ORTModule memory efficient gradient management

Previously I have tried to solve the coarsed-grained gradient
accumulation/update problem in ORTModule with
https://github.com/microsoft/onnxruntime/pull/8979, while that
resolution somehow is not fully validated with DDP or there is user
hooks on the gradient accumulation on torch parameter.

This PR is addressing the problem in the similar approach as PR 8979,
e.g. trigger gradient accumulation once ORT computed the grad, but
instead of use a AccumulateGrad op, this time with a ONNX operator
PythonOp, internally it will call param.backward(grad), which will help
handle all related hooks correctly.


## Design

Check the details from


https://microsoftapc-my.sharepoint.com/:p:/g/personal/pengwa_microsoft_com/EaaBq4EzsFhOmsDEXCG7Ba4Bb9bwd0O2sFV_JXJ4jBLYLA?e=7Sz2g8&nav=eyJzSWQiOjI3MSwiY0lkIjozMjE4NzI1NDIzfQ

## Convergence Validation:


![image](https://github.com/microsoft/onnxruntime/assets/10530022/ccf3a213-e815-4b23-b759-165033b2d9fe)

differences are on mostly 0.000x, sometimes 0.00x, which may comes from
the different order gradient apply happens before or after this change
(on deepspeed zero stage 2)


## TODO

Consolidate the logic with Stage3's similar logic.
2024-01-16 08:57:37 +08:00
pengwa
d03e477b90
Fix missing subgraph candidates for recompute (#19077)
### Fix missing subgraph candidates for recompute

For subgraphs for example `MatMul+Transpose+Reshape`, since the ending
node is a Reshape, in ORT, it is reusing input buffers.

Currently, the subgraph detection logic has defect, as a result, those
subgraphs will be missing as recompute candidates.

Also append a few more node types for recompute support. 

TODO: add unit test later. This PR is needed for a customer model now.
2024-01-11 12:50:55 +08:00
pengwa
ccf3b2054b
Allow layer-wise recompute (#18566)
### Allow layer-wise recompute 

Early, we need users/developers to specify the subgraphs to recompute,
now we introduced a more user-friendly way to enable recompute for all
detected stashed activation recomputation subgraphs. This scarifies
getting the best configs while makes it easier to support user
requirements when they switches from PyTorch per-layer gradient
checkpoint to ORTModule.

`ORTMODULE_MEMORY_OPT_LEVEL` is introduced to control the usage, by
default, it is 0, e.g. `USER_SPECIFIED`, all subgraphs definedin
`ORTMODULE_MEMORY_OPT_CONFIG` will be recomputed. So this is compatible
to existing recompute usage in ORTModule integrated models.

Using `ORTMODULE_MEMORY_OPT_LEVEL=1`, we will enable all recompute plans
detected, so those configs in `ORTMODULE_MEMORY_OPT_CONFIG` will not be
respected any more.


Add Unit Tests using 3 layer blooms. 



https://github.com/microsoft/onnxruntime/blob/pengwa/add_aggresive_recompute/docs/Memory_Optimizer.md
2023-12-12 08:44:05 +08:00
pengwa
4bfa84487c
Skip module clone for preparing large model export (#18663)
### Skip module clone for preparing large model export

For LLAMA2 13B, when running with Lora, DeepSpeed stage2 on 8 GPUs . It
failed during preparing outputs which will be used for
torch.onnx.export. The reason, we deep copy all the params including
both big sizes of frozen weights, + a little bit of Lora trainable
weight.

This PR will firstly check whether the GPU memmory is enough for a
cloned module, if not, skip the copy.

Copying the module is to guarantee the fw path run may change the
weight, while this case should be rare. But for now, Not-Able-To-Run is
worse than Runnable-with-A-little-bit-different-initial-weight,
especially for large models.
2023-12-05 12:41:17 -08:00
pengwa
43a5147e01
Memory optimization refactor and refinement (#17481)
### Memory optimization refactor and refinement

Currently memory optimizer runs graph transformations and print
recompute opportunities in INFO level, while ORT backend has many many
INFO level logs making users hard to find those information. So we are
looking for a Python binding API to retrieve the memory optimization
opportunities instead of depending on the MemoryOptimizer's default
logging.
Then we can print ORTModule feature statistics using this information. 
Also, with such an API, we can create an ORT session created, where
allocation plan is done, the analysis will consider buffer reuse as
well. This can void giving some recomputation subgraphs that are reusing
other subgraphs' output buffers.

Check
https://github.com/microsoft/onnxruntime/blob/pengwa/add_devinfo_level/docs/Memory_Optimizer.md
for the new flow using `MemoryOptimizer`.

This pull requests made following refactoring:
1. Print the log in ORTModule Python script, along with ORTModule
feature enabling stats. This is implemented by exposing an API
`get_serialized_ortmodule_memory_stat` to retrieve the memory
optimization opportunities.
2. We are analyzing memory optimization opportunities considering ORT
memory planning. This is done by firstly creating the execution graph
without enabling MemoryOptimizer, then we call
`execution_agent.get_serialized_ortmodule_memory_stat` which internally
will consider the session memory allocation planner when analyzing
memory optimization opportunity. As a direct result, the memory
optimization opportunities can show those stashed activations that are
reusing other buffers.
3. Move recompute analysis logic from memory_optimizer.h/cc to
recompute_analysis.h/cc.
4. Abstract optimization strategies for their own implementation. This
will make introducing new strategies (for example compression and
decompression ) easier.

New logging matrix (INFO Level), in WARNING level, the details will NOT
show.
```
2023-09-13 13:25:09,249 orttraining.rank-0 [WARNING] -
***** ONNX Runtime Training (ORTModule) is accelerating your model *****

ORTModule is enabled with following features ON/OFF for [training] mode:

  ATen Executor         :   ON    :   Dispatch ATen operators to ORT's ATen executor
  Cast Propagation      :   ON    :   Level 1 enabled
  Custom Function       :   ON    :   Support custom torch.autograd.Function export and execution
  Memory Optimizer      :   ON    :   RecomputeConfig: Reshape+Where+BiasSoftmax+:1:-1,Cast+:1:-1, ProbeLevel: 1, available configs:
                                      Config                                                      Freq    Saving(B)       Saving Symbolic(Bytes)
   - Plan 1             :   ON    :   Reshape+Where+BiasSoftmax+:1:-1                             5       671,088,640     640.0*inputs_input_ids_dim0*inputs_input_ids_dim1**2
   - Plan 2             :   ON    :   Cast+:1:-1                                                  6       402,587,648     inputs_input_ids_dim0*inputs_input_ids_dim1*(384.0*inputs_input_ids_dim1 - 64.0)
   - Plan 3             :   OFF   :   Reshape+Where+:1:-1                                         1       134,217,728     128.0*inputs_input_ids_dim0*inputs_input_ids_dim1**2
   - Plan 4             :   OFF   :   BiasSoftmax+:1:-1                                           1       134,086,656     128.0*inputs_input_ids_dim0*inputs_input_ids_dim1*(inputs_input_ids_dim1 - 1)
   - Plan 5             :   OFF   :   BiasGelu+:1:-1                                              6       125,808,640     inputs_input_ids_dim0*(122880.0*inputs_input_ids_dim1 - 20480.0)
   - Plan 6             :   OFF   :   FusedMatMul+:1:-1                                           6       125,808,640     inputs_input_ids_dim0*(122880.0*inputs_input_ids_dim1 - 20480.0)
   - Plan 7             :   OFF   :   FusedMatMul+Add+FusedMatMul+Add+Add+Add+:1:-1               5       26,214,400      25600.0*inputs_input_ids_dim0*inputs_input_ids_dim1
   - Plan 8             :   OFF   :   Add+:1:-1                                                   1       5,237,760       5120.0*inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1)
   - Plan 9             :   OFF   :   Reshape+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Cast+:1:-1         1       4,096           4.0*inputs_input_ids_dim0*inputs_input_ids_dim1
   - Plan 10            :   OFF   :   Cast+:2:-1                                                  1       2,048           2.0*inputs_input_ids_dim0*inputs_input_ids_dim1
  Compute Optimizer     :   ON    :   Enable/Disable with env ORTMODULE_ENABLE_COMPUTE_OPTIMIZER=1/0
   - FLOPReduction      :   ON    :   Reduce FLOPs by upstreaming shrinking-sized ops
  Auto Fallback         :   ON    :   Fallback to PyTorch when encountering unsupported ops
  TritonOp Enabled      :   OFF   :   ORT will switch to Triton for executing some ops to further accelerate training.
  ZeRO Stage3 Support   :   OFF   :   Enable/Disable with env ORTMODULE_ENABLE_ZERO_STAGE3=1/0

Total ORT initialization overhead is 10.73s where export takes 8.39s.
Other overhead details:  graph builder init takes 0.06s, runtime detection takes 0.01s, graph building takes 0.31s, session creation takes 1.96s

Versions: ONNX Runtime - 1.16.0+cu118, ONNX - 1.11.0

Note 1: use comma to enable multiple plans at the same time.
  export ORTMODULE_MEMORY_OPT_CONFIG=<plan1 config>,<plan2 config>,...
Note 2: saving is calculated based on the 1st batch symbolic dim values:
  inputs_input_ids_dim0=1,
  inputs_input_ids_dim1=1024,
  inputs_attention_mask_dim0=1,
  inputs_attention_mask_dim1=1024,
  inputs_labels_dim0=1,
  inputs_labels_dim1=1024,

************************************************************************
```

If DEVINFO level is enabled, then more details about the memory
optimizations are printed.
```

MemoryInsight Summary - User config: BiasGelu+:1:-1,Cast+:2:-1
==========================================================================================================================================
|Freq   | Memory Optimization Opportunities (Clustered by node-level activation patterns)                                                |
|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ |
|3      |For each row options are mutually exclusive, only one of them can be enabled.                                                   |
|       |                                                                                                                                |
|       |>>Option 1     : Recompute subgraph FusedMatMul+Add+Reshape+                                                                    |
|       |  Status       : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=FusedMatMul+Add+Reshape+:1:-1                         |
|       |  Stashed Activations:                                                                                                          |
|       |   - ReuseFreq :  Output 0(3),                                                                                                  |
|       |   - Output 0  : [inputs_input_ids_dim0 x inputs_input_ids_dim1 x 32 x 240 x ], byte/elem: 2, 100% saved                        |
|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ |
|2      |For each row options are mutually exclusive, only one of them can be enabled.                                                   |
|       |                                                                                                                                |
|       |>>Option 1     : Recompute subgraph Reshape+                                                                                    |
|       |  Status       : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=Reshape+:1:-1                                         |
|       |  Stashed Activations:                                                                                                          |
|       |   - ReuseFreq :  Output 0(2),                                                                                                  |
|       |   - Output 0  : [ x 2560 x ], byte/elem: 2, 100% saved                                                                         |
|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ |
|2      |For each row options are mutually exclusive, only one of them can be enabled.                                                   |
|       |                                                                                                                                |
|       |>>Option 1     : Recompute subgraph FusedMatMul+                                                                                |
|       |  Status       : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=FusedMatMul+:1:-1                                     |
|       |  Stashed Activations:                                                                                                          |
|       |   - Output 0  : [inputs_input_ids_dim0 x inputs_input_ids_dim1 x 10240 x ], byte/elem: 2, 100% saved                           |
|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ |
|2      |For each row options are mutually exclusive, only one of them can be enabled.                                                   |
|       |                                                                                                                                |
|       |>>Option 1     : Recompute subgraph Cast+                                                                                       |
|       |  Status       : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=Cast+:1:-1                                            |
|       |  Stashed Activations:                                                                                                          |
|       |   - Output 0  : [inputs_input_ids_dim0 x 32 x inputs_input_ids_dim1 x inputs_input_ids_dim1 x ], byte/elem: 2, 100% saved      |
|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ |
|2      |For each row options are mutually exclusive, only one of them can be enabled.                                                   |
|       |                                                                                                                                |
|       |>>Option 1     : Recompute subgraph Reshape+Where+BiasSoftmax+                                                                  |
|       |  Status       : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=Reshape+Where+BiasSoftmax+:1:-1                       |
|       |  Stashed Activations:                                                                                                          |
|       |   - Output 0  : [inputs_input_ids_dim0 x 32 x inputs_input_ids_dim1 x inputs_input_ids_dim1 x ], byte/elem: 4, 100% saved      |
|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ |
|2      |For each row options are mutually exclusive, only one of them can be enabled.                                                   |
|       |                                                                                                                                |
|       |>>Option 1     : Recompute subgraph BiasGelu+                                                                                   |
|       |  Status       : Enabled, requested count=-1, actual applied count=2                                                            |
|       |  Stashed Activations:                                                                                                          |
|       |   - Output 0  : [inputs_input_ids_dim0 x inputs_input_ids_dim1 x 10240 x ], byte/elem: 2, 100% saved                           |
|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ |
|2      |For each row options are mutually exclusive, only one of them can be enabled.                                                   |
|       |                                                                                                                                |
|       |>>Option 1     : Recompute subgraph FusedMatMul+Add+FusedMatMul+Add+Add+Add+                                                    |
|       |  Status       : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=FusedMatMul+Add+FusedMatMul+Add+Add+Add+:1:-1         |
|       |  Stashed Activations:                                                                                                          |
|       |   - Output 0  : [inputs_input_ids_dim0 x inputs_input_ids_dim1 x 2560 x ], byte/elem: 2, 100% saved                            |
|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ |
|1      |For each row options are mutually exclusive, only one of them can be enabled.                                                   |
|       |                                                                                                                                |
|       |>>Option 1     : Recompute subgraph Reshape+Where+                                                                              |
|       |  Status       : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=Reshape+Where+:1:-1                                   |
|       |  Stashed Activations:                                                                                                          |
|       |   - Output 0  : [inputs_input_ids_dim0 x 32 x inputs_input_ids_dim1 x inputs_input_ids_dim1 x ], byte/elem: 4, 100% saved      |
|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ |
|1      |For each row options are mutually exclusive, only one of them can be enabled.                                                   |
|       |                                                                                                                                |
|       |>>Option 1     : Recompute subgraph FusedMatMul+                                                                                |
|       |  Status       : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=FusedMatMul+:1:-1                                     |
|       |  Stashed Activations:                                                                                                          |
|       |   - Output 0  : [inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1) x 10240 x ], byte/elem: 2, 100% saved                       |
|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ |
|1      |For each row options are mutually exclusive, only one of them can be enabled.                                                   |
|       |                                                                                                                                |
|       |>>Option 1     : Recompute subgraph Cast+                                                                                       |
|       |  Status       : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=Cast+:1:-1                                            |
|       |  Stashed Activations:                                                                                                          |
|       |   - Output 0  : [inputs_input_ids_dim0 x 32 x inputs_input_ids_dim1 - 1 x inputs_input_ids_dim1 x ], byte/elem: 2, 100% saved  |
|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ |
|1      |For each row options are mutually exclusive, only one of them can be enabled.                                                   |
|       |                                                                                                                                |
|       |>>Option 1     : Recompute subgraph Reshape+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Cast+                                              |
|       |  Status       : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=Reshape+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Cast+:1:-1   |
|       |  Stashed Activations:                                                                                                          |
|       |   - Output 0  : [inputs_input_ids_dim0 x 1 x 1 x inputs_input_ids_dim1 x ], byte/elem: 4, 100% saved                           |
|       |                                                                                                                                |
|       |>>Option 2     : RecomputeWithCompromise subgraph Cast+                                                                         |
|       |  Status       : Enabled, requested count=-1, actual applied count=1                                                            |
|       |  Stashed Activations:                                                                                                          |
|       |   - Output 0  : [inputs_input_ids_dim0 x 1 x 1 x inputs_input_ids_dim1 x ], byte/elem: 4, 50% saved                            |
|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ |
|1      |For each row options are mutually exclusive, only one of them can be enabled.                                                   |
|       |                                                                                                                                |
|       |>>Option 1     : Recompute subgraph BiasSoftmax+                                                                                |
|       |  Status       : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=BiasSoftmax+:1:-1                                     |
|       |  Stashed Activations:                                                                                                          |
|       |   - Output 0  : [inputs_input_ids_dim0 x 32 x inputs_input_ids_dim1 - 1 x inputs_input_ids_dim1 x ], byte/elem: 4, 100% saved  |
|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ |
|1      |For each row options are mutually exclusive, only one of them can be enabled.                                                   |
|       |                                                                                                                                |
|       |>>Option 1     : Recompute subgraph BiasGelu+                                                                                   |
|       |  Status       : Enabled, requested count=-1, actual applied count=1                                                            |
|       |  Stashed Activations:                                                                                                          |
|       |   - Output 0  : [inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1) x 10240 x ], byte/elem: 2, 100% saved                       |
|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ |
|1      |For each row options are mutually exclusive, only one of them can be enabled.                                                   |
|       |                                                                                                                                |
|       |>>Option 1     : Recompute subgraph Add+                                                                                        |
|       |  Status       : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=Add+:1:-1                                             |
|       |  Stashed Activations:                                                                                                          |
|       |   - Output 0  : [inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1) x 2560 x ], byte/elem: 2, 100% saved                        |
|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ |
==========================================================================================================================================
Note: use comma as a separator for enabling more than one subgraphs.

************************************************************************

```


### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
2023-11-23 11:39:00 +08:00
pengwa
2151c79bf1
Tune ORTModule logging experience a bit (#18298)
### Tune logging experience a bit

After last time we update the ORTModule log experience, we found few
issues:
1. `INFO` level output too many things, including PyTorch exporter
verbose logs (tracing graphs) on every ranks. On this level, we only
want to
- Output a little bit more information to Users than `WARNING` level,
for example the memory recomputation recommendations or other
not-fully-ready features.
- Output a little bit more information for a quick diagnostic, collected
on rank-0 only.
2. ONNX Runtime logging filter during graph build, session init
sometimes will hide the issues (for example segement fault), there is no
useful information in `WARNING`/`INFO` for users to report to us. This
is not good!
3. Some of our devs like using `pdb` to debug Python code, but if we add
`import pdb; pdb.set_trace()` in models' code might hang when they use
`INFO` or `WARNING`, where exporter happens and all output got
redirected due to log filtering. The only workaround is to switch to
VERBOSE, which output toooooooooooo many logs.

The corresponding changes proposed here are:
1. For `INFO` logging, 
    - We only logs rank-0. 
- We restricted the ORT backend logging level to be WARNING in this
case, because ORT backend code output way too many logs that should be
under verbose, while we cannot guarantee we can get them cleaned up
immediately once they are added.
- We output the PyTorch exporter verbose log (including tracing graph),
which is useful for a quick diagnostic when an issue happens.
2. Remove all logging filtering on ORT backend, then the segment fault
issue details will not be hidden once it happens again.
 3. Introduced a `DEVINFO` logging,
     - Log logs on all ranks
     - Log ORT backend logging level INFO
- PyTorch exporter logging filtering are all turned OFF (to unblock the
pdb debugging).
4. Currently, to use Memory Optimizer, need use DEVINFO (which will
output ORT backend INFO log). So update memory optimizer document to
reflect this. https://github.com/microsoft/onnxruntime/pull/17481 will
update the requirement back to INFO for show memory optimization infos.

You can check
https://github.com/microsoft/onnxruntime/blob/pengwa/devinfo_level/docs/ORTModule_Training_Guidelines.md#log-level-explanations
for a better view of different log levels.

This PR also extract some changes from a bigger one
https://github.com/microsoft/onnxruntime/pull/17481, to reduce its
complexity for review.

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->

---------

Co-authored-by: mindest <30493312+mindest@users.noreply.github.com>
2023-11-08 17:42:50 +08:00
pengwa
c8e1038eab
Optimize 4bit Qlora training (#18131)
### Optimize 4bit Qlora training

Extent existing `MatmulBnb4bit` to its usage in training scenarios. 

The PR includes following changes:
1. Add special `torch.autograd.Function` export logic for
`bitsandbytes.autograd._functions.MatMul4Bit` that is preferred before
common PythonOp exporter.
2. Add `training_mode` optional attribute for op `MatmulBnb4bit`, which
help skip some inference specific logic in implementation.
3. Add `transB` optional attribute, which is by default be 1; setting it
to be 0 is needed by backward usage.

Changing from `PythonOp` to this `MatmulBnb4bit` brings roughly ~2.9%
throughput gains. The reason is:
`bitsandbytes.autograd._functions.MatMul4Bit` has logic
`ctx.save_for_backward`, which would need an additional copy in
PythonOp, otherwise, the tensor might be released by ORT, while backward
op still references it.

Removing the clones also reduce the peak memory consumptions because
`bitsandbytes.autograd._functions.MatMul4Bit` saved tensors that are not
needed in backward compute.
2023-11-02 09:46:11 -07:00
pengwa
d90afc697b
Introduce ZeROOffloadSubscriber for ORTModule (#17006)
### Introduce ZeROOffloadSubscriber for ORTModule

As part of the work: integrate ORTModule with DeepSpeed stage3, this PR
mainly focus on moving original PyTorch-based (leveraging hooks) param
partition/offload implementation to ORTModule compatible implementation.

Changes include:
1. Refactor `SubscriberBase`/`SubcriberManager` to support
pre-forward/post_forward hooks.
2. Implement new `ZeROOffloadSubscriber` by re-using DeepSpeed hook
function as much as possible. Since all hook functions are defined in
`DeepSpeedZeRoOffload._register_hooks_recursively` and
`DeepSpeedZeRoOffload.setup_zero_stage3_hooks`, and the good thing is,
the closure is not complex, all hooks are referencing the owning
`DeepSpeedZeRoOffload` instance, so we can create new hook function with
`FunctionType` by binding the owning `DeepSpeedZeRoOffload` instance,
then call the new created function in subscriber's
`pre_forward_module_apply_impl` and `post_forward_module_apply_impl`
interfaces.
3. Monkey patch `DeepSpeedZeRoOffload.setup_zero_stage3_hooks` to
register the `ZeROOffloadSubscriber` for the model, then we don't need
change any code on the DeepSpeed repo (at least so far).
4. Fix the ATen embedding custom symbolic exporter function by
tolerating weights size be (0) (changed by DeepSpeed zero stage 3).

UT will be added once stage3 is fully supported. 

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
2023-08-25 00:15:22 +08:00
pengwa
cd7b3f54da
Allow defining customized PythonOp shape inferer (#17093)
### Allow defining customized PythonOp shape inferer

For `torch.autograd.Function`, we converted it to PythonOp in MSDomain,
there are two places to do shape inferencing for it:

1. in SymbolicShapeInfer, there is one. 
2. in PythonOp op definition. 

For common PythonOp, since we don't know the relation ship between
inputs and outputs, so we only infer the rank from output ranks, and
generate symbolic dimensions for each dim. While this will introduce
many meaningless symbolic dimensions, sometimes blocking our graph
transformers to do op fusion.

This PR provide a way to define custom shape inferencing for
`torch.autograd.Function` we defined, to propagate the original
dimensions across the PythonOp at the best efforts.

But the 2rd one is not covered yet, we could refine that later. Fixing
1st one is enough for ORTModule training/evaluation.

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
2023-08-14 09:13:32 +08:00
pengwa
3649376f09
Fix few small bugs (#17019)
### Fix few bugs

1. symbolic shape infer, there is no None check before get length. 
2. Rename PythonOp/PythonOpGrad's attribute `name` to `func_name`,
otherwise, when we use onnx.helper.make_node to create node, `name`
conflicts with node name.
3. Filter shape inference warnings for PythonOp for torch 2.0 or newer. 
4. Close file descriptor for log suppression. Without the fix, two extra
fd is left after the log suppression exit its context.
Before enter log suppression (left), Before exit log suppression (right)

![image](https://github.com/microsoft/onnxruntime/assets/10530022/3cd3057a-59f9-4c89-8359-d9b32c49a17e)
   With the fix, no fd added after context exit.

![image](https://github.com/microsoft/onnxruntime/assets/10530022/03454a8f-ab48-4552-bb9b-293a4f51be67)
2023-08-07 14:01:36 +08:00
Prathik Rao
779fba1666
ORT Cache (#16744)
### Description
<!-- Describe your changes. -->

This PR adds support to cache the exported training/evaluation ONNX
model in `ORTModule`. On future runs, instead of exporting the model
again, we can pick up the model from a location on disc and run
`ORTModule` training/evaluation.

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->

ORT Training DRI Contribution

---------

Co-authored-by: root <root@orttrainingdev8.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
Co-authored-by: Prathik Rao <prathikrao@microsoft.com@orttrainingdev8.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
Co-authored-by: Baiju Meswani <bmeswani@microsoft.com>
Co-authored-by: pengwa <pengwa@microsoft.com>
2023-07-27 09:00:43 -07:00
pengwa
39fca225ea
ORTModule log clean up (#16795)
### ORTModule log clean up

ORTModule log level - WARNING(Default) is for end users; INFO and
VERBOSE is for internal ORT training developers.

Few issues: 
1. ONNX export will output lots of WARNING error message like "The shape
inference of
com.microsoft::SoftmaxCrossEntropyLossInternal/ATen/PythonOp type is
missing", which is useless for us or end users.

![image](https://github.com/microsoft/onnxruntime/assets/10530022/f2409480-32e1-483d-bd18-f14149f0588d)

3. ORT also print some information like
""CleanUnusedInitializersAndNodeArgs] Removing
initializer","ReverseBFSWithStopGradient] Skip building gradient for",
which is also useless for us or end users most of the time.

![image](https://github.com/microsoft/onnxruntime/assets/10530022/ff3feaf1-3cb2-4392-b087-86b30b72994c)


5. Different ranks output logs and making ORT developers or end users
feels there are too many logs but usually not useful until we need
investigate.

Few improvements for the issues:
1. For ONNX export logs, there are two kinds of logs: a. export verbose
log; b. other logs printed by torch C++ backend. So this PR make
following change:
# VERBOSE -> FULL export verbose log + FULL torch other logs from stdout
and stderr (C++ backend)
# INFO -> FULL export verbose log + FILTERED torch other logs from
stdout and stderr (C++ backend)
# WARNING/ERROR -> [Rank 0] NO export verbose log + FILTERED torch other
logs from stdout and stderr (C++ backend)

e.g. for verbose level, print all logs as usually; for info level, print
verbose export log, and filtered logs from torch C++ backend (removing
messages like this "The shape inference of
com.microsoft::SoftmaxCrossEntropyLossInternal/ATen/PythonOp type is
missing") . For higher level, only log the info on rank 0.

2. For ORT gradient graph build and session creation, also suppress the
message and filtered out the message when log level >=INFO.

3. log level > INFO, then only logs on rank 0 is logged, to have a
cleaner user experience


This is the log for a BLOOM model training after the change: there are
limited of warnings.


![image](https://github.com/microsoft/onnxruntime/assets/10530022/f270b8d5-2944-49d2-a253-c07057d641a0)
2023-07-26 12:42:50 +08:00
Vincent Wang
c07a3b869c
Triton Codegen for ORTModule (#15831)
Fuse connected elementwise and reduce Ops to TritonOp and codegen triton
code to run the kernel.

This PR is co-edited by @wejoncy and @er3x3
2023-07-13 18:17:58 +08:00
pengwa
a49bb85cfe
Manage ORTModule configurations consistently (#16396)
### Manage ORTModule options

Move all env vars that used for feature ON/OFF into runtime options for
consistent managements.


Be noted: the features' switch are assigned in 2 phases: default values,
overwritten by env vars (if specified by users). So env vars take the
highest priority when all 2 phases both given value explicitly for one
feature.



### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
2023-06-27 19:19:36 +08:00