2023-06-27 11:19:36 +00:00
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# options.py
import os
from enum import IntFlag
from functools import reduce
from logging import Logger
ORTModule log clean up (#16795)
### ORTModule log clean up
ORTModule log levels: WARNING (the default) is for end users; INFO and
VERBOSE are for internal ORT training developers.
Few issues:
1. ONNX export will output lots of WARNING error message like "The shape
inference of
com.microsoft::SoftmaxCrossEntropyLossInternal/ATen/PythonOp type is
missing", which is useless for us or end users.

3. ORT also print some information like
""CleanUnusedInitializersAndNodeArgs] Removing
initializer","ReverseBFSWithStopGradient] Skip building gradient for",
which is also useless for us or end users most of the time.

5. Every rank outputs its own logs, which makes ORT developers and end users
feel there are too many logs; most of them are not useful until we need
to investigate a problem.
Few improvements for the issues:
1. For ONNX export logs, there are two kinds of logs: a. export verbose
log; b. other logs printed by torch C++ backend. So this PR make
following change:
# VERBOSE -> FULL export verbose log + FULL torch other logs from stdout
and stderr (C++ backend)
# INFO -> FULL export verbose log + FILTERED torch other logs from
stdout and stderr (C++ backend)
# WARNING/ERROR -> [Rank 0] NO export verbose log + FILTERED torch other
logs from stdout and stderr (C++ backend)
e.g. for verbose level, print all logs as usually; for info level, print
verbose export log, and filtered logs from torch C++ backend (removing
messages like this "The shape inference of
com.microsoft::SoftmaxCrossEntropyLossInternal/ATen/PythonOp type is
missing") . For higher level, only log the info on rank 0.
2. For ORT gradient graph build and session creation, also suppress the
message and filtered out the message when log level >=INFO.
3. log level > INFO, then only logs on rank 0 is logged, to have a
cleaner user experience
This is the log for a BLOOM model training after the change: only a
limited number of warnings remain.

2023-07-26 04:42:50 +00:00
from packaging import version
2023-06-27 11:19:36 +00:00
from onnxruntime . capi import _pybind_state as C
from onnxruntime . training import ortmodule
from . _fallback import _FallbackPolicy
from . _logger import LogLevel
ORTModule log clean up (#16795)
### ORTModule log clean up
ORTModule log levels: WARNING (the default) is for end users; INFO and
VERBOSE are for internal ORT training developers.
Few issues:
1. ONNX export will output lots of WARNING error message like "The shape
inference of
com.microsoft::SoftmaxCrossEntropyLossInternal/ATen/PythonOp type is
missing", which is useless for us or end users.

3. ORT also print some information like
""CleanUnusedInitializersAndNodeArgs] Removing
initializer","ReverseBFSWithStopGradient] Skip building gradient for",
which is also useless for us or end users most of the time.

5. Every rank outputs its own logs, which makes ORT developers and end users
feel there are too many logs; most of them are not useful until we need
to investigate a problem.
Few improvements for the issues:
1. For ONNX export logs, there are two kinds of logs: a. export verbose
log; b. other logs printed by torch C++ backend. So this PR make
following change:
# VERBOSE -> FULL export verbose log + FULL torch other logs from stdout
and stderr (C++ backend)
# INFO -> FULL export verbose log + FILTERED torch other logs from
stdout and stderr (C++ backend)
# WARNING/ERROR -> [Rank 0] NO export verbose log + FILTERED torch other
logs from stdout and stderr (C++ backend)
e.g. for verbose level, print all logs as usually; for info level, print
verbose export log, and filtered logs from torch C++ backend (removing
messages like this "The shape inference of
com.microsoft::SoftmaxCrossEntropyLossInternal/ATen/PythonOp type is
missing") . For higher level, only log the info on rank 0.
2. For ORT gradient graph build and session creation, also suppress the
message and filtered out the message when log level >=INFO.
3. log level > INFO, then only logs on rank 0 is logged, to have a
cleaner user experience
This is the log for a BLOOM model training after the change: only a
limited number of warnings remain.

2023-07-26 04:42:50 +00:00
from . _utils import get_runtime_pytorch_version , parse_os_env_skip_check_flags
2023-06-27 11:19:36 +00:00
class _SaveOnnxOptions:
    """Settings for saving ORTModule's intermediate ONNX models.

    Holds three values: whether saving is enabled, the file-name prefix, and
    the destination directory. The directory is taken from the environment
    variable named by ``_path_environment_key`` when it is set; otherwise the
    ``path`` argument is used, or the current working directory if ``path`` is
    empty.
    """

    # Name of the environment variable that overrides the output directory.
    _path_environment_key = " ORTMODULE_SAVE_ONNX_PATH "

    def __init__(self, save, name_prefix, path: str):
        resolved = self._extract_info(save, name_prefix, path)
        self._save, self._name_prefix, self._path = resolved

    def _extract_info(self, save, name_prefix, path: str):
        """Resolve the output directory and validate the settings when saving is on.

        Returns:
            The tuple ``(save, name_prefix, destination_directory)``.
        """
        # The working directory is only the fallback for an empty `path`.
        fallback_dir = os.getcwd() if len(path) == 0 else path
        # The environment variable, when present, wins over the fallback.
        target_dir = os.getenv(_SaveOnnxOptions._path_environment_key, fallback_dir)

        # Nothing is checked unless models will actually be written.
        if save:
            self._validate(save, name_prefix, target_dir)

        return save, name_prefix, target_dir

    def _validate(self, save, name_prefix, destination_path):
        """Check the destination directory and the name prefix.

        Raises:
            OSError: If ``destination_path`` is not writable.
            TypeError: If ``name_prefix`` is not a ``str``.
            ValueError: If ``name_prefix`` is an empty string.
        """
        writable = os.access(destination_path, os.W_OK)
        if not writable:
            message = (
                f" Directory {destination_path} is not writable. Please set the "
                f" {_SaveOnnxOptions._path_environment_key} environment variable to a writable path. "
            )
            raise OSError(message)

        if not isinstance(name_prefix, str):
            raise TypeError(f" Expected name prefix of type str, got {type(name_prefix)} . ")

        # Saving requires a non-empty prefix for the generated file names.
        if not name_prefix:
            raise ValueError(" onnx_prefix must be provided when save_onnx is set. ")

    @property
    def save(self):
        """Whether intermediate ONNX models should be saved."""
        return self._save

    @property
    def name_prefix(self):
        """Prefix for the saved ONNX model file names."""
        return self._name_prefix

    @property
    def path(self):
        """Resolved destination directory."""
        return self._path
class _LoggingOptions:
    """Holds the effective ORTModule log level.

    The level passed to the constructor is used unless the environment
    variable named by ``_log_level_environment_key`` is set, in which case the
    environment value supersedes it.
    """

    # Name of the environment variable that overrides the log level.
    _log_level_environment_key = " ORTMODULE_LOG_LEVEL "

    def __init__(self, log_level):
        self._log_level = self._extract_info(log_level)

    def _extract_info(self, log_level):
        """Validate ``log_level`` and apply the environment override, if any.

        The environment value is looked up by name in ``LogLevel``, so it has
        to match a member name.
        """
        self._validate(log_level)
        # The caller's level name is the fallback when the variable is unset.
        level_name = os.getenv(_LoggingOptions._log_level_environment_key, log_level.name)
        return LogLevel[level_name]

    def _validate(self, log_level):
        """Raise ``TypeError`` unless ``log_level`` is a ``LogLevel`` instance."""
        if isinstance(log_level, LogLevel):
            return
        raise TypeError(f" Expected log_level of type LogLevel, got {type(log_level)} . ")

    @property
    def log_level(self) -> LogLevel:
        """The resolved log level."""
        return self._log_level
class DebugOptions :
""" Configurable debugging options for ORTModule.
Args :
log_level ( : obj : ` LogLevel ` , optional ) : Configure ORTModule log level . Defaults to LogLevel . WARNING .
log_level can also be set by setting the environment variable " ORTMODULE_LOG_LEVEL " to one of
" VERBOSE " , " INFO " , " WARNING " , " ERROR " , " FATAL " . In case both are set , the environment variable
takes precedence .
save_onnx ( : obj : ` bool ` , optional ) : Configure ORTModule to save onnx models . Defaults to False .
The output directory of the onnx models by default is set to the current working directory .
To change the output directory , the environment variable " ORTMODULE_SAVE_ONNX_PATH " can be
set to the destination directory path .
onnx_prefix ( : obj : ` str ` , optional ) : Name prefix to the ORTModule ONNX models saved file names .
Must be provided if save_onnx is True
Raises :
OSError : If save_onnx is True and output directory is not writable .
TypeError : If save_onnx is True and name_prefix is not a valid string . Or if
log_level is not an instance of LogLevel .
ValueError : If save_onnx is True and name_prefix is an empty string .
"""
def __init__ ( self , log_level = LogLevel . WARNING , save_onnx = False , onnx_prefix = " " , save_path = " " , config = None ) :
self . log_level = log_level
self . save_onnx = save_onnx
self . onnx_prefix = onnx_prefix
self . _save_onnx_models = _SaveOnnxOptions ( self . save_onnx , self . onnx_prefix , save_path )
self . _logging = _LoggingOptions ( self . log_level )
@property
def save_onnx_models ( self ) :
""" Accessor for the ONNX saving configuration. """
return self . _save_onnx_models
@property
def logging ( self ) :
""" Accessor for the logging configuration. """
return self . _logging
ORTModule log clean up (#16795)
### ORTModule log clean up
ORTModule log level - WARNING(Default) is for end users; INFO and
VERBOSE is for internal ORT training developers.
Few issues:
1. ONNX export will output lots of WARNING error message like "The shape
inference of
com.microsoft::SoftmaxCrossEntropyLossInternal/ATen/PythonOp type is
missing", which is useless for us or end users.

3. ORT also print some information like
""CleanUnusedInitializersAndNodeArgs] Removing
initializer","ReverseBFSWithStopGradient] Skip building gradient for",
which is also useless for us or end users most of the time.

5. Different ranks output logs and making ORT developers or end users
feels there are too many logs but usually not useful until we need
investigate.
Few improvements for the issues:
1. For ONNX export logs, there are two kinds of logs: a. export verbose
log; b. other logs printed by torch C++ backend. So this PR make
following change:
# VERBOSE -> FULL export verbose log + FULL torch other logs from stdout
and stderr (C++ backend)
# INFO -> FULL export verbose log + FILTERED torch other logs from
stdout and stderr (C++ backend)
# WARNING/ERROR -> [Rank 0] NO export verbose log + FILTERED torch other
logs from stdout and stderr (C++ backend)
e.g. for verbose level, print all logs as usually; for info level, print
verbose export log, and filtered logs from torch C++ backend (removing
messages like this "The shape inference of
com.microsoft::SoftmaxCrossEntropyLossInternal/ATen/PythonOp type is
missing") . For higher level, only log the info on rank 0.
2. For ORT gradient graph build and session creation, also suppress the
message and filtered out the message when log level >=INFO.
3. log level > INFO, then only logs on rank 0 is logged, to have a
cleaner user experience
This is the log for a BLOOM model training after the change: there are
limited of warnings.

2023-07-26 04:42:50 +00:00
@property
def torch_exporter_filter ( self ) :
""" Accessor for the filter export logs configuration. """
2023-08-07 06:01:36 +00:00
torch_version = get_runtime_pytorch_version ( )
Memory optimization refactor and refinement (#17481)
### Memory optimization refactor and refinement
Currently memory optimizer runs graph transformations and print
recompute opportunities in INFO level, while ORT backend has many many
INFO level logs making users hard to find those information. So we are
looking for a Python binding API to retrieve the memory optimization
opportunities instead of depending on the MemoryOptimizer's default
logging.
Then we can print ORTModule feature statistics using this information.
Also, with such an API, we can create an ORT session created, where
allocation plan is done, the analysis will consider buffer reuse as
well. This can void giving some recomputation subgraphs that are reusing
other subgraphs' output buffers.
Check
https://github.com/microsoft/onnxruntime/blob/pengwa/add_devinfo_level/docs/Memory_Optimizer.md
for the new flow using `MemoryOptimizer`.
This pull requests made following refactoring:
1. Print the log in ORTModule Python script, along with ORTModule
feature enabling stats. This is implemented by exposing an API
`get_serialized_ortmodule_memory_stat` to retrieve the memory
optimization opportunities.
2. We are analyzing memory optimization opportunities considering ORT
memory planning. This is done by firstly creating the execution graph
without enabling MemoryOptimizer, then we call
`execution_agent.get_serialized_ortmodule_memory_stat` which internally
will consider the session memory allocation planner when analyzing
memory optimization opportunity. As a direct result, the memory
optimization opportunities can show those stashed activations that are
reusing other buffers.
3. Move recompute analysis logic from memory_optimizer.h/cc to
recompute_analysis.h/cc.
4. Abstract optimization strategies for their own implementation. This
will make introducing new strategies (for example compression and
decompression ) easier.
New logging matrix (INFO Level), in WARNING level, the details will NOT
show.
```
2023-09-13 13:25:09,249 orttraining.rank-0 [WARNING] -
***** ONNX Runtime Training (ORTModule) is accelerating your model *****
ORTModule is enabled with following features ON/OFF for [training] mode:
ATen Executor : ON : Dispatch ATen operators to ORT's ATen executor
Cast Propagation : ON : Level 1 enabled
Custom Function : ON : Support custom torch.autograd.Function export and execution
Memory Optimizer : ON : RecomputeConfig: Reshape+Where+BiasSoftmax+:1:-1,Cast+:1:-1, ProbeLevel: 1, available configs:
Config Freq Saving(B) Saving Symbolic(Bytes)
- Plan 1 : ON : Reshape+Where+BiasSoftmax+:1:-1 5 671,088,640 640.0*inputs_input_ids_dim0*inputs_input_ids_dim1**2
- Plan 2 : ON : Cast+:1:-1 6 402,587,648 inputs_input_ids_dim0*inputs_input_ids_dim1*(384.0*inputs_input_ids_dim1 - 64.0)
- Plan 3 : OFF : Reshape+Where+:1:-1 1 134,217,728 128.0*inputs_input_ids_dim0*inputs_input_ids_dim1**2
- Plan 4 : OFF : BiasSoftmax+:1:-1 1 134,086,656 128.0*inputs_input_ids_dim0*inputs_input_ids_dim1*(inputs_input_ids_dim1 - 1)
- Plan 5 : OFF : BiasGelu+:1:-1 6 125,808,640 inputs_input_ids_dim0*(122880.0*inputs_input_ids_dim1 - 20480.0)
- Plan 6 : OFF : FusedMatMul+:1:-1 6 125,808,640 inputs_input_ids_dim0*(122880.0*inputs_input_ids_dim1 - 20480.0)
- Plan 7 : OFF : FusedMatMul+Add+FusedMatMul+Add+Add+Add+:1:-1 5 26,214,400 25600.0*inputs_input_ids_dim0*inputs_input_ids_dim1
- Plan 8 : OFF : Add+:1:-1 1 5,237,760 5120.0*inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1)
- Plan 9 : OFF : Reshape+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Cast+:1:-1 1 4,096 4.0*inputs_input_ids_dim0*inputs_input_ids_dim1
- Plan 10 : OFF : Cast+:2:-1 1 2,048 2.0*inputs_input_ids_dim0*inputs_input_ids_dim1
Compute Optimizer : ON : Enable/Disable with env ORTMODULE_ENABLE_COMPUTE_OPTIMIZER=1/0
- FLOPReduction : ON : Reduce FLOPs by upstreaming shrinking-sized ops
Auto Fallback : ON : Fallback to PyTorch when encountering unsupported ops
TritonOp Enabled : OFF : ORT will switch to Triton for executing some ops to further accelerate training.
ZeRO Stage3 Support : OFF : Enable/Disable with env ORTMODULE_ENABLE_ZERO_STAGE3=1/0
Total ORT initialization overhead is 10.73s where export takes 8.39s.
Other overhead details: graph builder init takes 0.06s, runtime detection takes 0.01s, graph building takes 0.31s, session creation takes 1.96s
Versions: ONNX Runtime - 1.16.0+cu118, ONNX - 1.11.0
Note 1: use comma to enable multiple plans at the same time.
export ORTMODULE_MEMORY_OPT_CONFIG=<plan1 config>,<plan2 config>,...
Note 2: saving is calculated based on the 1st batch symbolic dim values:
inputs_input_ids_dim0=1,
inputs_input_ids_dim1=1024,
inputs_attention_mask_dim0=1,
inputs_attention_mask_dim1=1024,
inputs_labels_dim0=1,
inputs_labels_dim1=1024,
************************************************************************
```
If DEVINFO level is enabled, then more details about the memory
optimizations are printed.
```
MemoryInsight Summary - User config: BiasGelu+:1:-1,Cast+:2:-1
==========================================================================================================================================
|Freq | Memory Optimization Opportunities (Clustered by node-level activation patterns) |
|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ |
|3 |For each row options are mutually exclusive, only one of them can be enabled. |
| | |
| |>>Option 1 : Recompute subgraph FusedMatMul+Add+Reshape+ |
| | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=FusedMatMul+Add+Reshape+:1:-1 |
| | Stashed Activations: |
| | - ReuseFreq : Output 0(3), |
| | - Output 0 : [inputs_input_ids_dim0 x inputs_input_ids_dim1 x 32 x 240 x ], byte/elem: 2, 100% saved |
|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ |
|2 |For each row options are mutually exclusive, only one of them can be enabled. |
| | |
| |>>Option 1 : Recompute subgraph Reshape+ |
| | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=Reshape+:1:-1 |
| | Stashed Activations: |
| | - ReuseFreq : Output 0(2), |
| | - Output 0 : [ x 2560 x ], byte/elem: 2, 100% saved |
|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ |
|2 |For each row options are mutually exclusive, only one of them can be enabled. |
| | |
| |>>Option 1 : Recompute subgraph FusedMatMul+ |
| | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=FusedMatMul+:1:-1 |
| | Stashed Activations: |
| | - Output 0 : [inputs_input_ids_dim0 x inputs_input_ids_dim1 x 10240 x ], byte/elem: 2, 100% saved |
|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ |
|2 |For each row options are mutually exclusive, only one of them can be enabled. |
| | |
| |>>Option 1 : Recompute subgraph Cast+ |
| | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=Cast+:1:-1 |
| | Stashed Activations: |
| | - Output 0 : [inputs_input_ids_dim0 x 32 x inputs_input_ids_dim1 x inputs_input_ids_dim1 x ], byte/elem: 2, 100% saved |
|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ |
|2 |For each row options are mutually exclusive, only one of them can be enabled. |
| | |
| |>>Option 1 : Recompute subgraph Reshape+Where+BiasSoftmax+ |
| | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=Reshape+Where+BiasSoftmax+:1:-1 |
| | Stashed Activations: |
| | - Output 0 : [inputs_input_ids_dim0 x 32 x inputs_input_ids_dim1 x inputs_input_ids_dim1 x ], byte/elem: 4, 100% saved |
|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ |
|2 |For each row options are mutually exclusive, only one of them can be enabled. |
| | |
| |>>Option 1 : Recompute subgraph BiasGelu+ |
| | Status : Enabled, requested count=-1, actual applied count=2 |
| | Stashed Activations: |
| | - Output 0 : [inputs_input_ids_dim0 x inputs_input_ids_dim1 x 10240 x ], byte/elem: 2, 100% saved |
|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ |
|2 |For each row options are mutually exclusive, only one of them can be enabled. |
| | |
| |>>Option 1 : Recompute subgraph FusedMatMul+Add+FusedMatMul+Add+Add+Add+ |
| | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=FusedMatMul+Add+FusedMatMul+Add+Add+Add+:1:-1 |
| | Stashed Activations: |
| | - Output 0 : [inputs_input_ids_dim0 x inputs_input_ids_dim1 x 2560 x ], byte/elem: 2, 100% saved |
|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ |
|1 |For each row options are mutually exclusive, only one of them can be enabled. |
| | |
| |>>Option 1 : Recompute subgraph Reshape+Where+ |
| | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=Reshape+Where+:1:-1 |
| | Stashed Activations: |
| | - Output 0 : [inputs_input_ids_dim0 x 32 x inputs_input_ids_dim1 x inputs_input_ids_dim1 x ], byte/elem: 4, 100% saved |
|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ |
|1 |For each row options are mutually exclusive, only one of them can be enabled. |
| | |
| |>>Option 1 : Recompute subgraph FusedMatMul+ |
| | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=FusedMatMul+:1:-1 |
| | Stashed Activations: |
| | - Output 0 : [inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1) x 10240 x ], byte/elem: 2, 100% saved |
|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ |
|1 |For each row options are mutually exclusive, only one of them can be enabled. |
| | |
| |>>Option 1 : Recompute subgraph Cast+ |
| | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=Cast+:1:-1 |
| | Stashed Activations: |
| | - Output 0 : [inputs_input_ids_dim0 x 32 x inputs_input_ids_dim1 - 1 x inputs_input_ids_dim1 x ], byte/elem: 2, 100% saved |
|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ |
|1 |For each row options are mutually exclusive, only one of them can be enabled. |
| | |
| |>>Option 1 : Recompute subgraph Reshape+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Cast+ |
| | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=Reshape+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Cast+:1:-1 |
| | Stashed Activations: |
| | - Output 0 : [inputs_input_ids_dim0 x 1 x 1 x inputs_input_ids_dim1 x ], byte/elem: 4, 100% saved |
| | |
| |>>Option 2 : RecomputeWithCompromise subgraph Cast+ |
| | Status : Enabled, requested count=-1, actual applied count=1 |
| | Stashed Activations: |
| | - Output 0 : [inputs_input_ids_dim0 x 1 x 1 x inputs_input_ids_dim1 x ], byte/elem: 4, 50% saved |
|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ |
|1 |For each row options are mutually exclusive, only one of them can be enabled. |
| | |
| |>>Option 1 : Recompute subgraph BiasSoftmax+ |
| | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=BiasSoftmax+:1:-1 |
| | Stashed Activations: |
| | - Output 0 : [inputs_input_ids_dim0 x 32 x inputs_input_ids_dim1 - 1 x inputs_input_ids_dim1 x ], byte/elem: 4, 100% saved |
|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ |
|1 |For each row options are mutually exclusive, only one of them can be enabled. |
| | |
| |>>Option 1 : Recompute subgraph BiasGelu+ |
| | Status : Enabled, requested count=-1, actual applied count=1 |
| | Stashed Activations: |
| | - Output 0 : [inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1) x 10240 x ], byte/elem: 2, 100% saved |
|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ |
|1 |For each row options are mutually exclusive, only one of them can be enabled. |
| | |
| |>>Option 1 : Recompute subgraph Add+ |
| | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=Add+:1:-1 |
| | Stashed Activations: |
| | - Output 0 : [inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1) x 2560 x ], byte/elem: 2, 100% saved |
|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ |
==========================================================================================================================================
Note: use comma as a separator for enabling more than one subgraphs.
************************************************************************
```
### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
2023-11-23 03:39:00 +00:00
if self . log_level > LogLevel . DEVINFO :
2023-08-07 06:01:36 +00:00
if torch_version < version . parse ( " 2.0 " ) :
return [
# WARNING: The shape inference of com.microsoft::SoftmaxCrossEntropyLossInternal type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.
# WARNING: The shape inference of com.microsoft::PythonOp type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.
# WARNING: The shape inference of org.pytorch.aten::ATen type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.
# WARNING: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.
" type is missing, so it may result in wrong shape inference " ,
# Warning: Checker does not support models with experimental ops: ATen
" Checker does not support models with experimental ops: " ,
" Dropout is a training op and should not be exported in inference mode. " ,
# Warning: Shape inference does not support models with experimental operators: ATen
" Shape inference does not support models with experimental operators: " ,
# Warning: Unsupported operator Trilu. No schema registered for this operator.
# Warning: Unsupported operator ATen. No schema registered for this operator.
# Warning: Unsupported operator SoftmaxCrossEntropyLossInternal. No schema registered for this operator.
" No schema registered for this operator. " ,
]
ORTModule log clean up (#16795)
### ORTModule log clean up
ORTModule log level - WARNING(Default) is for end users; INFO and
VERBOSE is for internal ORT training developers.
Few issues:
1. ONNX export will output lots of WARNING error message like "The shape
inference of
com.microsoft::SoftmaxCrossEntropyLossInternal/ATen/PythonOp type is
missing", which is useless for us or end users.

3. ORT also print some information like
""CleanUnusedInitializersAndNodeArgs] Removing
initializer","ReverseBFSWithStopGradient] Skip building gradient for",
which is also useless for us or end users most of the time.

5. Different ranks output logs and making ORT developers or end users
feels there are too many logs but usually not useful until we need
investigate.
Few improvements for the issues:
1. For ONNX export logs, there are two kinds of logs: a. export verbose
log; b. other logs printed by torch C++ backend. So this PR make
following change:
# VERBOSE -> FULL export verbose log + FULL torch other logs from stdout
and stderr (C++ backend)
# INFO -> FULL export verbose log + FILTERED torch other logs from
stdout and stderr (C++ backend)
# WARNING/ERROR -> [Rank 0] NO export verbose log + FILTERED torch other
logs from stdout and stderr (C++ backend)
e.g. for verbose level, print all logs as usually; for info level, print
verbose export log, and filtered logs from torch C++ backend (removing
messages like this "The shape inference of
com.microsoft::SoftmaxCrossEntropyLossInternal/ATen/PythonOp type is
missing") . For higher level, only log the info on rank 0.
2. For ORT gradient graph build and session creation, also suppress the
message and filtered out the message when log level >=INFO.
3. log level > INFO, then only logs on rank 0 is logged, to have a
cleaner user experience
This is the log for a BLOOM model training after the change: there are
limited of warnings.

2023-07-26 04:42:50 +00:00
return [
2023-08-07 06:01:36 +00:00
# [W shape_type_inference.cpp:1974] Warning: The shape inference of com.microsoft::PythonOp type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. (function UpdateReliable)
ORTModule log clean up (#16795)
### ORTModule log clean up
ORTModule log level - WARNING(Default) is for end users; INFO and
VERBOSE is for internal ORT training developers.
Few issues:
1. ONNX export will output lots of WARNING error message like "The shape
inference of
com.microsoft::SoftmaxCrossEntropyLossInternal/ATen/PythonOp type is
missing", which is useless for us or end users.

3. ORT also print some information like
""CleanUnusedInitializersAndNodeArgs] Removing
initializer","ReverseBFSWithStopGradient] Skip building gradient for",
which is also useless for us or end users most of the time.

5. Different ranks output logs and making ORT developers or end users
feels there are too many logs but usually not useful until we need
investigate.
Few improvements for the issues:
1. For ONNX export logs, there are two kinds of logs: a. export verbose
log; b. other logs printed by torch C++ backend. So this PR make
following change:
# VERBOSE -> FULL export verbose log + FULL torch other logs from stdout
and stderr (C++ backend)
# INFO -> FULL export verbose log + FILTERED torch other logs from
stdout and stderr (C++ backend)
# WARNING/ERROR -> [Rank 0] NO export verbose log + FILTERED torch other
logs from stdout and stderr (C++ backend)
e.g. for verbose level, print all logs as usually; for info level, print
verbose export log, and filtered logs from torch C++ backend (removing
messages like this "The shape inference of
com.microsoft::SoftmaxCrossEntropyLossInternal/ATen/PythonOp type is
missing") . For higher level, only log the info on rank 0.
2. For ORT gradient graph build and session creation, also suppress the
message and filtered out the message when log level >=INFO.
3. log level > INFO, then only logs on rank 0 is logged, to have a
cleaner user experience
This is the log for a BLOOM model training after the change: there are
limited of warnings.

2023-07-26 04:42:50 +00:00
" type is missing, so it may result in wrong shape inference " ,
Allow defining customized PythonOp shape inferer (#17093)
### Allow defining customized PythonOp shape inferer
For `torch.autograd.Function`, we converted it to PythonOp in MSDomain,
there are two places to do shape inferencing for it:
1. in SymbolicShapeInfer, there is one.
2. in PythonOp op definition.
For common PythonOp, since we don't know the relation ship between
inputs and outputs, so we only infer the rank from output ranks, and
generate symbolic dimensions for each dim. While this will introduce
many meaningless symbolic dimensions, sometimes blocking our graph
transformers to do op fusion.
This PR provide a way to define custom shape inferencing for
`torch.autograd.Function` we defined, to propagate the original
dimensions across the PythonOp at the best efforts.
But the 2rd one is not covered yet, we could refine that later. Fixing
1st one is enough for ORTModule training/evaluation.
### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
2023-08-14 01:13:32 +00:00
# diagnostics [WARNING] - None
" [WARNING] - None " ,
ORTModule log clean up (#16795)
### ORTModule log clean up
ORTModule log level - WARNING(Default) is for end users; INFO and
VERBOSE is for internal ORT training developers.
Few issues:
1. ONNX export will output lots of WARNING error message like "The shape
inference of
com.microsoft::SoftmaxCrossEntropyLossInternal/ATen/PythonOp type is
missing", which is useless for us or end users.

3. ORT also print some information like
""CleanUnusedInitializersAndNodeArgs] Removing
initializer","ReverseBFSWithStopGradient] Skip building gradient for",
which is also useless for us or end users most of the time.

5. Different ranks output logs and making ORT developers or end users
feels there are too many logs but usually not useful until we need
investigate.
Few improvements for the issues:
1. For ONNX export logs, there are two kinds of logs: a. export verbose
log; b. other logs printed by torch C++ backend. So this PR make
following change:
# VERBOSE -> FULL export verbose log + FULL torch other logs from stdout
and stderr (C++ backend)
# INFO -> FULL export verbose log + FILTERED torch other logs from
stdout and stderr (C++ backend)
# WARNING/ERROR -> [Rank 0] NO export verbose log + FILTERED torch other
logs from stdout and stderr (C++ backend)
e.g. for verbose level, print all logs as usually; for info level, print
verbose export log, and filtered logs from torch C++ backend (removing
messages like this "The shape inference of
com.microsoft::SoftmaxCrossEntropyLossInternal/ATen/PythonOp type is
missing") . For higher level, only log the info on rank 0.
2. For ORT gradient graph build and session creation, also suppress the
message and filtered out the message when log level >=INFO.
3. log level > INFO, then only logs on rank 0 is logged, to have a
cleaner user experience
This is the log for a BLOOM model training after the change: there are
limited of warnings.

2023-07-26 04:42:50 +00:00
]
return None
@property
def onnxruntime_log_filter(self):
    """Return the filter configuration applied to onnxruntime logs.

    No filter is configured here, so the accessor always yields None.
    """
    return None
2023-06-27 11:19:36 +00:00
class _SkipCheck ( IntFlag ) :
""" Enumeration to specify which checks should be skipped, allowing faster execution """
SKIP_CHECK_DISABLED = 1
SKIP_CHECK_DEVICE = 2
SKIP_CHECK_BUILD_GRADIENT = 4
SKIP_CHECK_EXECUTION_AGENT = 8
def is_set ( self , check ) :
""" Check whether `check` is set on the `_SkipCheck instance
SKIP_CHECK_DISABLED implies the check will return False
"""
return not _SkipCheck . is_disabled ( self ) and check in self
def is_disabled ( self ) :
""" Check whether `_SkipCheck.SKIP_CHECK_DISABLED is set on the `_SkipCheck instance """
return _SkipCheck . SKIP_CHECK_DISABLED in self
2023-12-12 00:44:05 +00:00
class _MemoryOptimizationLevel ( IntFlag ) :
""" Enumeration to specify memory optimization level """
USER_SPECIFIED = 0 # Fully respect user-specified config
2024-03-06 02:06:25 +00:00
TRANSFORMER_LAYERWISE_RECOMPUTE = (
2024-07-11 05:35:08 +00:00
1 # Enable all recomputable subgraphs (excluding compromised recomputable graphs) per layer
2024-03-06 02:06:25 +00:00
)
TRANSFORMER_LAYERWISE_RECOMPUTE_WITH_COMPROMISE = 2 # Enable all recomputable subgraphs per layer
2023-12-12 00:44:05 +00:00
@staticmethod
def to_string ( memory_optimization_level ) :
if memory_optimization_level == _MemoryOptimizationLevel . USER_SPECIFIED :
return " USER_SPECIFIED "
if memory_optimization_level == _MemoryOptimizationLevel . TRANSFORMER_LAYERWISE_RECOMPUTE :
return " TRANSFORMER_LAYERWISE_RECOMPUTE "
2024-03-06 02:06:25 +00:00
if memory_optimization_level == _MemoryOptimizationLevel . TRANSFORMER_LAYERWISE_RECOMPUTE_WITH_COMPROMISE :
return " TRANSFORMER_LAYERWISE_RECOMPUTE_WITH_COMPROMISE "
2023-12-12 00:44:05 +00:00
return " "
2023-06-27 11:19:36 +00:00
class _RuntimeOptions :
""" Configurable runtime options for ORTModule. """
def __init__(self, logger: Logger):
    """Build the runtime options.

    Every option is first assigned its default value; afterwards
    `_override_from_env_vars` is called so that environment variables can
    replace those defaults.

    Args:
        logger: Logger stored on the instance as `_logger`.
    """
    self._logger = logger

    self.onnx_opset_version = ortmodule.ONNX_OPSET_VERSION
    self.conv_algo_search = " HEURISTIC "

    # --- Cast optimization -------------------------------------------------
    # Cast propagation strategy. Three strategies are currently available:
    # NONE, INSERT-AND-REDUCE and FLOOD-FILL. FLOOD_FILL is the default; it
    # expands FP16 computation regions in the graph using the opcodes allowed
    # for the given level.
    self.propagate_cast_ops_strategy = C.PropagateCastOpsStrategy.FLOOD_FILL

    # Cast operations are moved when propagate_cast_ops_level is non-negative:
    # - level 0: only the opcodes listed in propagate_cast_ops_allow are
    #   treated as "FP16 safe" when inserting/(re)moving casts so the
    #   operations run in reduced (16-bit) precision.
    # - level 1 or 2: in addition to propagate_cast_ops_allow, onnxruntime's
    #   predetermined list of safe opcodes is used.
    #   - Level 1 covers opcodes that perform no computation (Transpose,
    #     Split, Reshape, etc.) or whose computation is actually done in
    #     Float (GeLU, etc.).
    #   - Level 2 additionally covers opcodes that compute using contrib ops,
    #     Dropout, LayerNormalization, etc.
    self.propagate_cast_ops_level = 1

    # Opcodes considered safe to move before/after a cast when
    # propagate_cast_ops_level is zero.
    self.propagate_cast_ops_allow = []

    # Execution order is priority-based by default for both dynamic and static
    # shape input for now; this flag can be exposed to users if static shape
    # proves beneficial.
    self.use_static_shape = False

    # Run symbolic shape inference for dynamic shape inputs to improve
    # performance.
    self.run_symbolic_shape_infer = True

    # --- PyTorch custom autograd function support --------------------------
    from ._custom_autograd_function import custom_autograd_function_enabler

    self.enable_custom_autograd_function = custom_autograd_function_enabler.state

    self.use_external_gpu_allocator = True

    # WIP: caching in the gradient accumulation scenario.
    self.enable_grad_acc_optimization = False

    # Memory-aware gradient builder.
    self.use_memory_efficient_gradient = False

    # --- Compute optimization ----------------------------------------------
    self.enable_compute_optimizer = True
    self.enable_embedding_sparse_optimizer = True
    self.enable_label_sparse_optimizer = True
    self.label_sparsity_ratio = " "
    self.embed_sparsity_ratio = " "

    # --- Memory optimization -----------------------------------------------
    # 0: use `memory_optimizer_config_file_path`; 1: aggressive optimization,
    # enable all recomputable subgraphs.
    self.memory_optimization_level = _MemoryOptimizationLevel.USER_SPECIFIED
    # Advanced config; please refer to the onnxruntime docs for details.
    self.memory_optimizer_config_file_path = " "
    # The 1 is the op set level; the 0 indicates whether the layer boundary of
    # a Transformer-based model is considered when detecting recompute
    # subgraphs.
    self.recompute_probe_config = " 1:0 "

    # --- Dev tools ---------------------------------------------------------
    self.print_input_density = False
    self.print_memory_stat_by_step = False

    # --- Fallback ----------------------------------------------------------
    self.fallback_policy = ortmodule.ORTMODULE_FALLBACK_POLICY

    # --- Skip check --------------------------------------------------------
    # Marks logic that has already been executed and can therefore be skipped
    # for faster training. Enabled by default unless defined in the os env.
    default_skipped_checks = (
        _SkipCheck.SKIP_CHECK_DEVICE
        | _SkipCheck.SKIP_CHECK_BUILD_GRADIENT
        | _SkipCheck.SKIP_CHECK_EXECUTION_AGENT
    )
    self.skip_check = _SkipCheck(default_skipped_checks)

    # --- Triton support ----------------------------------------------------
    self.enable_triton = False
    self.enable_tuning = False
    self.max_tuning_duration_ms = 0
    self.tuning_results_path = " "

    # --- Exported model cache ----------------------------------------------
    self.ortmodule_cache_dir = " "

    # --- Experimental features ---------------------------------------------
    # Once enabled, cannot be disabled.
    self.enable_zero_stage3_support = False

    # Memory efficient grad management stays off by default until it is fully
    # validated.
    self.enable_mem_efficient_grad_management = False

    self.deepcopy_before_model_export = True

    # Finally, let values from the os env override the feature config above.
    self._override_from_env_vars()
def _override_from_env_vars ( self ) :
self . onnx_opset_version = int ( os . getenv ( " ORTMODULE_ONNX_OPSET_VERSION " , self . onnx_opset_version ) )
self . conv_algo_search = os . getenv ( " ORTMODULE_CONV_ALGO_SEARCH " , self . conv_algo_search )
if self . conv_algo_search not in [ " HEURISTIC " , " EXHAUSTIVE " ] :
self . _logger . warning ( " Invalid value of env CONV_ALGO_SEARCH. Must be HEURISTIC or EXHAUSTIVE. " )
self . conv_algo_search = " HEURISTIC "
# Configuration for compute optimization.
compute_optimizer_reset = False
if " ORTMODULE_ENABLE_COMPUTE_OPTIMIZER " in os . environ :
self . enable_compute_optimizer = int ( os . getenv ( " ORTMODULE_ENABLE_COMPUTE_OPTIMIZER " ) ) == 1
compute_optimizer_reset = True
2024-05-10 13:55:43 +00:00
if " ORTMODULE_ENABLE_LABEL_SPARSE_OPTIMIZER " in os . environ or compute_optimizer_reset :
if " ORTMODULE_ENABLE_LABEL_SPARSE_OPTIMIZER " in os . environ :
self . enable_label_sparse_optimizer = int ( os . getenv ( " ORTMODULE_ENABLE_LABEL_SPARSE_OPTIMIZER " ) ) == 1
self . enable_label_sparse_optimizer = self . enable_compute_optimizer and self . enable_label_sparse_optimizer
2023-06-27 11:19:36 +00:00
2024-05-10 13:55:43 +00:00
if " ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER " in os . environ or compute_optimizer_reset :
if " ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER " in os . environ :
self . enable_embedding_sparse_optimizer = (
int ( os . getenv ( " ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER " ) ) == 1
)
2023-06-27 11:19:36 +00:00
self . enable_embedding_sparse_optimizer = (
2024-05-10 13:55:43 +00:00
self . enable_compute_optimizer and self . enable_embedding_sparse_optimizer
2023-06-27 11:19:36 +00:00
)
# Configuration for memory optimization.
2023-12-12 00:44:05 +00:00
self . memory_optimization_level = int ( os . getenv ( " ORTMODULE_MEMORY_OPT_LEVEL " , self . memory_optimization_level ) )
2024-05-21 05:38:19 +00:00
self . memory_optimizer_config_file_path = os . getenv (
" ORTMODULE_MEMORY_OPT_CONFIG " , self . memory_optimizer_config_file_path
)
2024-03-06 02:06:25 +00:00
if self . memory_optimization_level in [
_MemoryOptimizationLevel . TRANSFORMER_LAYERWISE_RECOMPUTE ,
_MemoryOptimizationLevel . TRANSFORMER_LAYERWISE_RECOMPUTE_WITH_COMPROMISE ,
] :
2023-12-12 00:44:05 +00:00
# For transformer layer-wise recompute, we enable layer boundary when detecting subgraphs.
# Then all detected subgraphs will not cross different layers.
self . recompute_probe_config = " 1:1 "
2023-06-27 11:19:36 +00:00
# Configuration for dev tools.
if " ORTMODULE_PRINT_INPUT_DENSITY " in os . environ :
self . print_input_density = int ( os . getenv ( " ORTMODULE_PRINT_INPUT_DENSITY " ) ) == 1
if " ORTMODULE_PRINT_MEMORY_STATS " in os . environ :
Memory optimization refactor and refinement (#17481)
### Memory optimization refactor and refinement
Currently the memory optimizer runs graph transformations and prints
recompute opportunities at INFO level, while the ORT backend emits so many
INFO level logs that users can hardly find that information. So we are
looking for a Python binding API to retrieve the memory optimization
opportunities instead of depending on the MemoryOptimizer's default
logging.
Then we can print ORTModule feature statistics using this information.
Also, with such an API, we can run the analysis after an ORT session is
created, where the allocation plan is done, so the analysis will consider
buffer reuse as well. This can avoid suggesting recomputation subgraphs
that are reusing other subgraphs' output buffers.
Check
https://github.com/microsoft/onnxruntime/blob/pengwa/add_devinfo_level/docs/Memory_Optimizer.md
for the new flow using `MemoryOptimizer`.
This pull request makes the following refactorings:
1. Print the log in ORTModule Python script, along with ORTModule
feature enabling stats. This is implemented by exposing an API
`get_serialized_ortmodule_memory_stat` to retrieve the memory
optimization opportunities.
2. We are analyzing memory optimization opportunities considering ORT
memory planning. This is done by firstly creating the execution graph
without enabling MemoryOptimizer, then we call
`execution_agent.get_serialized_ortmodule_memory_stat` which internally
will consider the session memory allocation planner when analyzing
memory optimization opportunity. As a direct result, the memory
optimization opportunities can show those stashed activations that are
reusing other buffers.
3. Move recompute analysis logic from memory_optimizer.h/cc to
recompute_analysis.h/cc.
4. Abstract optimization strategies for their own implementation. This
will make introducing new strategies (for example compression and
decompression ) easier.
New logging matrix (INFO level); at WARNING level, the details will NOT
be shown.
```
2023-09-13 13:25:09,249 orttraining.rank-0 [WARNING] -
***** ONNX Runtime Training (ORTModule) is accelerating your model *****
ORTModule is enabled with following features ON/OFF for [training] mode:
ATen Executor : ON : Dispatch ATen operators to ORT's ATen executor
Cast Propagation : ON : Level 1 enabled
Custom Function : ON : Support custom torch.autograd.Function export and execution
Memory Optimizer : ON : RecomputeConfig: Reshape+Where+BiasSoftmax+:1:-1,Cast+:1:-1, ProbeLevel: 1, available configs:
Config Freq Saving(B) Saving Symbolic(Bytes)
- Plan 1 : ON : Reshape+Where+BiasSoftmax+:1:-1 5 671,088,640 640.0*inputs_input_ids_dim0*inputs_input_ids_dim1**2
- Plan 2 : ON : Cast+:1:-1 6 402,587,648 inputs_input_ids_dim0*inputs_input_ids_dim1*(384.0*inputs_input_ids_dim1 - 64.0)
- Plan 3 : OFF : Reshape+Where+:1:-1 1 134,217,728 128.0*inputs_input_ids_dim0*inputs_input_ids_dim1**2
- Plan 4 : OFF : BiasSoftmax+:1:-1 1 134,086,656 128.0*inputs_input_ids_dim0*inputs_input_ids_dim1*(inputs_input_ids_dim1 - 1)
- Plan 5 : OFF : BiasGelu+:1:-1 6 125,808,640 inputs_input_ids_dim0*(122880.0*inputs_input_ids_dim1 - 20480.0)
- Plan 6 : OFF : FusedMatMul+:1:-1 6 125,808,640 inputs_input_ids_dim0*(122880.0*inputs_input_ids_dim1 - 20480.0)
- Plan 7 : OFF : FusedMatMul+Add+FusedMatMul+Add+Add+Add+:1:-1 5 26,214,400 25600.0*inputs_input_ids_dim0*inputs_input_ids_dim1
- Plan 8 : OFF : Add+:1:-1 1 5,237,760 5120.0*inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1)
- Plan 9 : OFF : Reshape+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Cast+:1:-1 1 4,096 4.0*inputs_input_ids_dim0*inputs_input_ids_dim1
- Plan 10 : OFF : Cast+:2:-1 1 2,048 2.0*inputs_input_ids_dim0*inputs_input_ids_dim1
Compute Optimizer : ON : Enable/Disable with env ORTMODULE_ENABLE_COMPUTE_OPTIMIZER=1/0
- FLOPReduction : ON : Reduce FLOPs by upstreaming shrinking-sized ops
Auto Fallback : ON : Fallback to PyTorch when encountering unsupported ops
TritonOp Enabled : OFF : ORT will switch to Triton for executing some ops to further accelerate training.
ZeRO Stage3 Support : OFF : Enable/Disable with env ORTMODULE_ENABLE_ZERO_STAGE3=1/0
Total ORT initialization overhead is 10.73s where export takes 8.39s.
Other overhead details: graph builder init takes 0.06s, runtime detection takes 0.01s, graph building takes 0.31s, session creation takes 1.96s
Versions: ONNX Runtime - 1.16.0+cu118, ONNX - 1.11.0
Note 1: use comma to enable multiple plans at the same time.
export ORTMODULE_MEMORY_OPT_CONFIG=<plan1 config>,<plan2 config>,...
Note 2: saving is calculated based on the 1st batch symbolic dim values:
inputs_input_ids_dim0=1,
inputs_input_ids_dim1=1024,
inputs_attention_mask_dim0=1,
inputs_attention_mask_dim1=1024,
inputs_labels_dim0=1,
inputs_labels_dim1=1024,
************************************************************************
```
If the DEVINFO log level is enabled, more details about the memory
optimizations are printed:
```
MemoryInsight Summary - User config: BiasGelu+:1:-1,Cast+:2:-1
==========================================================================================================================================
|Freq | Memory Optimization Opportunities (Clustered by node-level activation patterns) |
|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ |
|3 |For each row options are mutually exclusive, only one of them can be enabled. |
| | |
| |>>Option 1 : Recompute subgraph FusedMatMul+Add+Reshape+ |
| | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=FusedMatMul+Add+Reshape+:1:-1 |
| | Stashed Activations: |
| | - ReuseFreq : Output 0(3), |
| | - Output 0 : [inputs_input_ids_dim0 x inputs_input_ids_dim1 x 32 x 240 x ], byte/elem: 2, 100% saved |
|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ |
|2 |For each row options are mutually exclusive, only one of them can be enabled. |
| | |
| |>>Option 1 : Recompute subgraph Reshape+ |
| | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=Reshape+:1:-1 |
| | Stashed Activations: |
| | - ReuseFreq : Output 0(2), |
| | - Output 0 : [ x 2560 x ], byte/elem: 2, 100% saved |
|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ |
|2 |For each row options are mutually exclusive, only one of them can be enabled. |
| | |
| |>>Option 1 : Recompute subgraph FusedMatMul+ |
| | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=FusedMatMul+:1:-1 |
| | Stashed Activations: |
| | - Output 0 : [inputs_input_ids_dim0 x inputs_input_ids_dim1 x 10240 x ], byte/elem: 2, 100% saved |
|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ |
|2 |For each row options are mutually exclusive, only one of them can be enabled. |
| | |
| |>>Option 1 : Recompute subgraph Cast+ |
| | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=Cast+:1:-1 |
| | Stashed Activations: |
| | - Output 0 : [inputs_input_ids_dim0 x 32 x inputs_input_ids_dim1 x inputs_input_ids_dim1 x ], byte/elem: 2, 100% saved |
|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ |
|2 |For each row options are mutually exclusive, only one of them can be enabled. |
| | |
| |>>Option 1 : Recompute subgraph Reshape+Where+BiasSoftmax+ |
| | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=Reshape+Where+BiasSoftmax+:1:-1 |
| | Stashed Activations: |
| | - Output 0 : [inputs_input_ids_dim0 x 32 x inputs_input_ids_dim1 x inputs_input_ids_dim1 x ], byte/elem: 4, 100% saved |
|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ |
|2 |For each row options are mutually exclusive, only one of them can be enabled. |
| | |
| |>>Option 1 : Recompute subgraph BiasGelu+ |
| | Status : Enabled, requested count=-1, actual applied count=2 |
| | Stashed Activations: |
| | - Output 0 : [inputs_input_ids_dim0 x inputs_input_ids_dim1 x 10240 x ], byte/elem: 2, 100% saved |
|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ |
|2 |For each row options are mutually exclusive, only one of them can be enabled. |
| | |
| |>>Option 1 : Recompute subgraph FusedMatMul+Add+FusedMatMul+Add+Add+Add+ |
| | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=FusedMatMul+Add+FusedMatMul+Add+Add+Add+:1:-1 |
| | Stashed Activations: |
| | - Output 0 : [inputs_input_ids_dim0 x inputs_input_ids_dim1 x 2560 x ], byte/elem: 2, 100% saved |
|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ |
|1 |For each row options are mutually exclusive, only one of them can be enabled. |
| | |
| |>>Option 1 : Recompute subgraph Reshape+Where+ |
| | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=Reshape+Where+:1:-1 |
| | Stashed Activations: |
| | - Output 0 : [inputs_input_ids_dim0 x 32 x inputs_input_ids_dim1 x inputs_input_ids_dim1 x ], byte/elem: 4, 100% saved |
|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ |
|1 |For each row options are mutually exclusive, only one of them can be enabled. |
| | |
| |>>Option 1 : Recompute subgraph FusedMatMul+ |
| | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=FusedMatMul+:1:-1 |
| | Stashed Activations: |
| | - Output 0 : [inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1) x 10240 x ], byte/elem: 2, 100% saved |
|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ |
|1 |For each row options are mutually exclusive, only one of them can be enabled. |
| | |
| |>>Option 1 : Recompute subgraph Cast+ |
| | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=Cast+:1:-1 |
| | Stashed Activations: |
| | - Output 0 : [inputs_input_ids_dim0 x 32 x inputs_input_ids_dim1 - 1 x inputs_input_ids_dim1 x ], byte/elem: 2, 100% saved |
|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ |
|1 |For each row options are mutually exclusive, only one of them can be enabled. |
| | |
| |>>Option 1 : Recompute subgraph Reshape+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Cast+ |
| | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=Reshape+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Cast+:1:-1 |
| | Stashed Activations: |
| | - Output 0 : [inputs_input_ids_dim0 x 1 x 1 x inputs_input_ids_dim1 x ], byte/elem: 4, 100% saved |
| | |
| |>>Option 2 : RecomputeWithCompromise subgraph Cast+ |
| | Status : Enabled, requested count=-1, actual applied count=1 |
| | Stashed Activations: |
| | - Output 0 : [inputs_input_ids_dim0 x 1 x 1 x inputs_input_ids_dim1 x ], byte/elem: 4, 50% saved |
|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ |
|1 |For each row options are mutually exclusive, only one of them can be enabled. |
| | |
| |>>Option 1 : Recompute subgraph BiasSoftmax+ |
| | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=BiasSoftmax+:1:-1 |
| | Stashed Activations: |
| | - Output 0 : [inputs_input_ids_dim0 x 32 x inputs_input_ids_dim1 - 1 x inputs_input_ids_dim1 x ], byte/elem: 4, 100% saved |
|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ |
|1 |For each row options are mutually exclusive, only one of them can be enabled. |
| | |
| |>>Option 1 : Recompute subgraph BiasGelu+ |
| | Status : Enabled, requested count=-1, actual applied count=1 |
| | Stashed Activations: |
| | - Output 0 : [inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1) x 10240 x ], byte/elem: 2, 100% saved |
|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ |
|1 |For each row options are mutually exclusive, only one of them can be enabled. |
| | |
| |>>Option 1 : Recompute subgraph Add+ |
| | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=Add+:1:-1 |
| | Stashed Activations: |
| | - Output 0 : [inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1) x 2560 x ], byte/elem: 2, 100% saved |
|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ |
==========================================================================================================================================
Note: use comma as a separator for enabling more than one subgraphs.
************************************************************************
```
### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
2023-11-23 03:39:00 +00:00
self . print_memory_stat_by_step = int ( os . getenv ( " ORTMODULE_PRINT_MEMORY_STATS " ) ) == 1
2023-06-27 11:19:36 +00:00
# Configuration for fallback.
if " ORTMODULE_FALLBACK_POLICY " in os . environ :
self . fallback_policy = os . getenv ( " ORTMODULE_FALLBACK_POLICY " )
if isinstance ( self . fallback_policy , str ) :
self . fallback_policy = _FallbackPolicy [ self . fallback_policy ]
# Configuration for skip check.
if " ORTMODULE_SKIPCHECK_POLICY " in os . environ :
self . skip_check = reduce (
lambda x , y : x | y ,
[ _SkipCheck [ name ] for name in parse_os_env_skip_check_flags ( " ORTMODULE_SKIPCHECK_POLICY " ) ] ,
)
2023-07-13 10:17:58 +00:00
# Configuration for Triton.
# Enable Triton op executor if Triton is installed, backend has support and environment variable is set.
if (
" ORTMODULE_USE_TRITON " in os . environ
and int ( os . getenv ( " ORTMODULE_USE_TRITON " ) ) == 1
and C . is_triton_enabled ( )
) :
try :
import triton # noqa: F401
except ImportError :
2024-02-01 23:25:33 +00:00
self . _logger . warning (
" triton library missing. Please install triton with `pip install triton`. Triton feature will be off. "
)
2023-07-13 10:17:58 +00:00
else :
self . enable_triton = True
if " ORTMODULE_ENABLE_TUNING " in os . environ and int ( os . getenv ( " ORTMODULE_ENABLE_TUNING " ) ) == 1 :
self . enable_tuning = True
if " ORTMODULE_MAX_TUNING_DURATION_MS " in os . environ :
max_tuning_duration_ms = int ( os . getenv ( " ORTMODULE_MAX_TUNING_DURATION_MS " ) )
if max_tuning_duration_ms > 0 :
self . max_tuning_duration_ms = max_tuning_duration_ms
if " ORTMODULE_TUNING_RESULTS_PATH " in os . environ :
self . tuning_results_path = os . getenv ( " ORTMODULE_TUNING_RESULTS_PATH " )
2023-07-27 16:00:43 +00:00
# Cache exported model
if " ORTMODULE_CACHE_DIR " in os . environ :
2023-11-02 16:46:11 +00:00
self . _logger . warning ( " ORTModule optimization for caching exported model is ON. " )
2023-07-27 16:00:43 +00:00
self . ortmodule_cache_dir = os . getenv ( " ORTMODULE_CACHE_DIR " )
2023-08-24 16:15:22 +00:00
# Experimental features.
if " ORTMODULE_ENABLE_ZERO_STAGE3 " in os . environ and int ( os . getenv ( " ORTMODULE_ENABLE_ZERO_STAGE3 " ) ) == 1 :
self . enable_zero_stage3_support = True
2023-12-05 20:41:17 +00:00
2024-01-16 00:57:37 +00:00
if " ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT " in os . environ :
enable_grad_mgmt = int ( os . getenv ( " ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT " ) )
self . enable_mem_efficient_grad_management = enable_grad_mgmt == 1 and self . enable_custom_autograd_function
if not self . enable_custom_autograd_function and enable_grad_mgmt == 1 :
self . _logger . warning (
" ORTModule optimization for memory efficient gradient management cannot be enabled "
" because PyTorch custom autograd function support is disabled. "
)
2023-12-05 20:41:17 +00:00
if " ORTMODULE_DEEPCOPY_BEFORE_MODEL_EXPORT " in os . environ :
self . deepcopy_before_model_export = int ( os . getenv ( " ORTMODULE_DEEPCOPY_BEFORE_MODEL_EXPORT " ) ) == 1
2024-01-11 04:50:55 +00:00
def memory_optimizer_is_enabled(self) -> bool:
    """Report whether the memory optimizer is active for the configured level.

    Returns:
        True when the level is USER_SPECIFIED and a non-empty memory
        optimizer config file path is set, or when the level is one of the
        transformer layer-wise recompute levels. False for any other level.
    """
    level = self.memory_optimization_level

    # A user-specified level only counts as enabled once a config path is provided.
    if level == _MemoryOptimizationLevel.USER_SPECIFIED:
        return len(self.memory_optimizer_config_file_path) != 0

    # Layer-wise recompute levels are enabled without further configuration.
    return level in (
        _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE,
        _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE_WITH_COMPROMISE,
    )