mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
Prompt layer-wise recompute when applicable (#20126)
### Prompt layer-wise when applicable Give explicit prompts in export failures to users to enable layer-wise memory optimization if we found the checkpoint function is used. - Using checkpoint function is a strong indicator that the model is too large to fit in GPU memory. - If we don't override the checkpoint function here, mostly ONNX export will be failed. 1. For old version PyTorch, when handling gradient checkpoint feature, we just throw an exception. 2. For new version PyTorch, an export failure happens. - But both failures did not give users explicitly "HOW" to mitigate. This PR did that. ``  ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
parent
14d7872ce9
commit
280b2634c5
6 changed files with 342 additions and 64 deletions
|
|
@ -30,10 +30,10 @@ Integrate models using `ORTModule`.
|
|||
```
|
||||
|
||||
There are two modes to enable the memory optimizations:
|
||||
- Aggressively Recompute All Within Each Transformer Layer, enabled by `export ORTMODULE_MEMORY_OPT_LEVEL=1`. This will recompute all detected subgraphs within each Transformer Attention+MLP layer. It is easy to enable, but be noted this recompute plan may NOT be the best one. In this mode, `ORTMODULE_MEMORY_OPT_CONFIG` env values passed by users are not respected.
|
||||
- User Specified Subgraph Recompute, enabled by `export ORTMODULE_MEMORY_OPT_LEVEL=0` and `export ORTMODULE_MEMORY_OPT_CONFIG=<plan1 config>,<plan2 config>,...`. This is an advanced usage, that allows users to find the most suitable graphs to recompute, at the cost of overhead to look for the best plans.
|
||||
- Transformer layerwise recompute, e.g. aggressively recompute all supported nodes within each transformer layer (usually including attention and mlp sublayers), enabled by `export ORTMODULE_MEMORY_OPT_LEVEL=1`. In this mode, `ORTMODULE_MEMORY_OPT_CONFIG` env values passed by users are not respected.
|
||||
- Manual selected subgraph recompute, enabled by `export ORTMODULE_MEMORY_OPT_LEVEL=0` and `export ORTMODULE_MEMORY_OPT_CONFIG=<plan1 config>,<plan2 config>,...`. This is an advanced usage, that allows users to find the most suitable graphs to recompute, at the cost of overhead to look for the best plans.
|
||||
|
||||
### Mode 1 - Simple Usage (Aggressively Recompute All Within Each Transformer Layer)
|
||||
### Mode 1 - Simple Usage (Transformer Layerwise Recompute)
|
||||
|
||||
|
||||
1. Set memory optimization level to be TRANSFORMER_LAYERWISE_RECOMPUTE, by `export ORTMODULE_MEMORY_OPT_LEVEL=1`
|
||||
|
|
|
|||
|
|
@ -3,6 +3,8 @@
|
|||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
import contextlib
|
||||
import inspect
|
||||
import os
|
||||
import sys
|
||||
import warnings
|
||||
|
|
@ -21,7 +23,10 @@ if not is_ortmodule_available():
|
|||
raise ImportError("ORTModule is not supported on this platform.")
|
||||
|
||||
|
||||
def _defined_from_envvar(name, default_value, warn=True):
|
||||
def _defined_from_envvar(name: str, default_value: any, warn: bool = True):
|
||||
"""Check given name exists in the environment variable and return the value using the default_value's
|
||||
type if it exists.
|
||||
"""
|
||||
new_value = os.getenv(name, None)
|
||||
if new_value is None:
|
||||
return default_value
|
||||
|
|
@ -34,6 +39,86 @@ def _defined_from_envvar(name, default_value, warn=True):
|
|||
return new_value
|
||||
|
||||
|
||||
def _override_gradient_checkpoint(original_checkpoint):
|
||||
"""
|
||||
Best effort to override `torch.utils.checkpoint` and `deepspeed.checkpointing.checkpoint` during ONNX export.
|
||||
|
||||
Despite importing `torch.utils.checkpoint` or `deepspeed.checkpointing.checkpoint` in `__init__.py`,
|
||||
users might import it first, causing our override to not take effect. We still attempt to override
|
||||
it to work in most cases.
|
||||
|
||||
We replace the checkpoint function with our implementation, without condition checks.
|
||||
The actual check is in the overridden function, verifying if:
|
||||
1) `checkpoint` is called during ORTModule model export,
|
||||
2) Gradient checkpoint autograd function is disallowed (ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT),
|
||||
3) Memory optimization level is not specified by the user (ORTMODULE_MEMORY_OPT_LEVEL).
|
||||
If true, we reset memory optimization to layer-wise recompute.
|
||||
|
||||
"""
|
||||
|
||||
# Note: The `torch.utils.checkpoint` checkpoint function signature looks like below:
|
||||
# `checkpoint(function, *args,
|
||||
# use_reentrant = None,
|
||||
# context_fn = noop_context_fn,
|
||||
# determinism_check = _DEFAULT_DETERMINISM_MODE,
|
||||
# debug = False,
|
||||
# **kwargs).`
|
||||
# The few keyword arguments are not used in the recompute module forward function, but by the
|
||||
# checkpoint function itself, so we need to filter them out otherwise module forward function
|
||||
# would complain about unexpected keyword arguments.
|
||||
all_input_parameters = inspect.signature(original_checkpoint).parameters.values()
|
||||
outside_kwarg_params = []
|
||||
for input_parameter in all_input_parameters:
|
||||
if (
|
||||
input_parameter.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
|
||||
or input_parameter.kind == inspect.Parameter.KEYWORD_ONLY
|
||||
or input_parameter.kind == inspect.Parameter.VAR_KEYWORD
|
||||
):
|
||||
outside_kwarg_params.append(input_parameter.name)
|
||||
|
||||
def _checkpoint(
|
||||
function,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
# Conditions to activate layer-wise memory optimization automatically:
|
||||
# 1. Checkpoint is called during ORTModule model export context.
|
||||
# 2. Gradient checkpoint autograd function export is disallowed.
|
||||
# 3. Memory optimization level is layer-wise recompute.
|
||||
if (
|
||||
ORTMODULE_ONNX_EXPORT_CONTEXT[0] is True
|
||||
and _defined_from_envvar("ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT", 0) != 1
|
||||
and _defined_from_envvar("ORTMODULE_MEMORY_OPT_LEVEL", 0) == 1
|
||||
):
|
||||
for name in outside_kwarg_params:
|
||||
if name in kwargs:
|
||||
# Pop out the keyword argument to avoid passing it to the module run function
|
||||
kwargs.pop(name)
|
||||
print(
|
||||
"Layer-wise memory optimization is enabled upon detecting "
|
||||
"gradient checkpointing autograd function usage during model execution."
|
||||
)
|
||||
return function(*args, **kwargs)
|
||||
return original_checkpoint(
|
||||
function,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return _checkpoint
|
||||
|
||||
|
||||
with contextlib.suppress(Exception):
|
||||
from torch.utils.checkpoint import checkpoint as original_torch_checkpoint
|
||||
|
||||
torch.utils.checkpoint.checkpoint = _override_gradient_checkpoint(original_torch_checkpoint)
|
||||
|
||||
import deepspeed
|
||||
|
||||
original_deepspeed_checkpoint = deepspeed.checkpointing.checkpoint
|
||||
deepspeed.checkpointing.checkpoint = _override_gradient_checkpoint(original_deepspeed_checkpoint)
|
||||
|
||||
|
||||
################################################################################
|
||||
# All global constant goes here, before ORTModule is imported ##################
|
||||
# NOTE: To *change* values in runtime, import onnxruntime.training.ortmodule and
|
||||
|
|
@ -55,7 +140,24 @@ ORTMODULE_IS_DETERMINISTIC = torch.are_deterministic_algorithms_enabled()
|
|||
ONNXRUNTIME_CUDA_VERSION = ort_info.cuda_version if hasattr(ort_info, "cuda_version") else None
|
||||
ONNXRUNTIME_ROCM_VERSION = ort_info.rocm_version if hasattr(ort_info, "rocm_version") else None
|
||||
|
||||
# Verify minimum PyTorch version is installed before proceding to ONNX Runtime initialization
|
||||
# The first value indicates whether the code is in ONNX export context.
|
||||
# The export context here include the full export process, including prepare export input/output information,
|
||||
# and export model.
|
||||
ORTMODULE_ONNX_EXPORT_CONTEXT = [False]
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def export_context():
|
||||
"""Context manager for model export."""
|
||||
try:
|
||||
ORTMODULE_ONNX_EXPORT_CONTEXT[0] = True
|
||||
|
||||
yield
|
||||
finally:
|
||||
ORTMODULE_ONNX_EXPORT_CONTEXT[0] = False
|
||||
|
||||
|
||||
# Verify minimum PyTorch version is installed before proceeding to ONNX Runtime initialization
|
||||
try:
|
||||
import torch
|
||||
|
||||
|
|
@ -70,6 +172,8 @@ try:
|
|||
f" but version {torch.__version__} was found instead."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
except ORTModuleFallbackException as e:
|
||||
# Initialization fallback is handled at ORTModule.__init__
|
||||
_FALLBACK_INIT_EXCEPTION = e
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ from onnxruntime.training.utils import pytorch_scalar_type_to_pytorch_dtype, pyt
|
|||
|
||||
from ._custom_op_symbolic_registry import wrap_custom_export_function
|
||||
from ._fallback import ORTModuleONNXModelException, wrap_exception
|
||||
from ._logger import LogColor
|
||||
from ._utils import get_fully_qualified_class_name, get_runtime_pytorch_version
|
||||
|
||||
|
||||
|
|
@ -112,35 +113,6 @@ def register_custom_function_schema_supplementary(kclass: torch.autograd.Functio
|
|||
register_input_alias_function(kclass_name, kclass.alias_input)
|
||||
|
||||
|
||||
"""
|
||||
Defines a list of names of torch.autograd.Function, for checkpoint activation purposes.
|
||||
|
||||
Note:
|
||||
If CheckpointFunction is exported as PythonOp, the checkpoint-ed computation
|
||||
(applied on every N transformer layer) may be computed by PyTorch, not ORT.
|
||||
This situation should be especially noted for large language models such as GPT-2.
|
||||
|
||||
As alternatives to using checkpoint activation:
|
||||
1. Users could leverage HierarchalORTModule to wrap the model, which only wrap exportable
|
||||
sub-nn.Module's as ORTModule. In this way, ideally, ORT could cover most of the model computation,
|
||||
other than dispatching them to PyTorch.
|
||||
2. Users could disable the check by setting the environment variable ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT=1.
|
||||
This may imply that the exported model is not fully running on ORT, users should be aware of the potential
|
||||
performance impact.
|
||||
3. Users could also leverage ORT's memory optimization feature to achieve a similar effect as checkpointing
|
||||
activations. Turn off PyTorch's checkpointing activation, then refer to env var ORTMODULE_MEMORY_OPT_CONFIG
|
||||
to enable ORT's recomputation feature.
|
||||
|
||||
"""
|
||||
_UNSUPPORTED_CKPT_FUNC_NAMES = frozenset(
|
||||
[
|
||||
# Full qualified name.
|
||||
"torch.utils.checkpoint.CheckpointFunction",
|
||||
"deepspeed.checkpointing.CheckpointFunction",
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def _get_training_mode() -> bool:
|
||||
# TODO move to public API once the exporter team exposes that
|
||||
training_mode = None
|
||||
|
|
@ -192,15 +164,6 @@ def _export_pt_1_10(g, n, *args, **kwargs):
|
|||
|
||||
# Fall back to common exporter if not handled by high priority exporter.
|
||||
|
||||
# Check if the checkpointing activation is allowed.
|
||||
is_ckpt_activation_allowed = ortmodule._defined_from_envvar("ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT", 0) == 1
|
||||
if is_ckpt_activation_allowed is False and func_full_qual_name in _UNSUPPORTED_CKPT_FUNC_NAMES:
|
||||
raise Exception(
|
||||
f"The torch.autograd.Function {func_full_qual_name} should not be exported to ONNX. "
|
||||
"Please replace ORTModule with HierarchalORTModule to only"
|
||||
"wrap exportable sub-nn.Module's as ORTModule."
|
||||
)
|
||||
|
||||
cconv = n.cconv()
|
||||
|
||||
input_tensor_types = []
|
||||
|
|
@ -394,6 +357,70 @@ def post_process_enabling_autograd_function(exported_model: ModelProto) -> Model
|
|||
return exported_model
|
||||
|
||||
|
||||
@register_high_priority_handler("torch.utils.checkpoint.CheckpointFunction")
|
||||
@register_high_priority_handler("deepspeed.checkpointing.CheckpointFunction")
|
||||
def _gradient_checkpointing_export(g, n, *args, **kwargs):
|
||||
"""
|
||||
Register specialized exporter for torch.autograd.Function(s) used for checkpoint activation purposes.
|
||||
|
||||
Note:
|
||||
If CheckpointFunction is exported as PythonOp, the checkpoint-ed computation
|
||||
(applied on every N transformer layer) may be computed by PyTorch, not ORT.
|
||||
This situation should be especially noted for large language models such as GPT-2.
|
||||
|
||||
As alternatives to using checkpoint activation:
|
||||
1. Users could leverage HierarchalORTModule to wrap the model, which only wrap exportable
|
||||
sub-nn.Module's as ORTModule. In this way, ideally, ORT could cover most of the model computation,
|
||||
other than dispatching them to PyTorch.
|
||||
2. Users could disable the check by setting the environment variable ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT=1.
|
||||
This may imply that the exported model is not fully running on ORT, users should be aware of the potential
|
||||
performance impact.
|
||||
3. Users could also leverage ORT's memory optimization feature to achieve a similar effect as checkpointing
|
||||
activations. Turn off PyTorch's checkpointing activation, then refer to env var ORTMODULE_MEMORY_OPT_LEVEL
|
||||
to enable ORT's recomputation feature.
|
||||
|
||||
"""
|
||||
# Check if the checkpointing activation is allowed.
|
||||
is_ckpt_activation_allowed = ortmodule._defined_from_envvar("ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT", 0) == 1
|
||||
if is_ckpt_activation_allowed is False:
|
||||
is_layerwise_recompute_enabled = ortmodule._defined_from_envvar("ORTMODULE_MEMORY_OPT_LEVEL", 0) == 1
|
||||
if not is_layerwise_recompute_enabled:
|
||||
raise Exception(
|
||||
f"{LogColor.RED}"
|
||||
"Model uses gradient checkpointing (via {func_full_qual_name}), "
|
||||
"which is not supported for export. \n"
|
||||
"Consider these alternatives:\n"
|
||||
"1) Enable ORTModule's gradient checkpointing for similar or better "
|
||||
"memory efficiency with `export ORTMODULE_MEMORY_OPT_LEVEL=1`.\n"
|
||||
"2) Allow gradient checkpointing export by setting the environment "
|
||||
"variable `ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT=1`, though subsequent "
|
||||
"execution may fail."
|
||||
"3) Replace ORTModule with HierarchalORTModule to wrap exportable "
|
||||
"sub-nn.Module's as ORTModule.\n"
|
||||
f"{LogColor.ENDC}"
|
||||
)
|
||||
|
||||
# Hitting this branch means the user has enabled layerwise recompute, but _override_gradient_checkpoint didn't
|
||||
# catch the checkpointing function. This is usually because model code is importing torch.utils.checkpoint
|
||||
# earlier than ORTModule. We should tolerantly allow this case to happen.
|
||||
raise Exception(
|
||||
f"{LogColor.RED}"
|
||||
"Model uses gradient checkpointing (via {func_full_qual_name}), which is not "
|
||||
"supported for export. \n"
|
||||
"Consider these alternatives:\n"
|
||||
"1) `ORTMODULE_MEMORY_OPT_LEVEL=1` is set but checkpoint functions in the model "
|
||||
"are not overridden during onnxruntime.training.ortmodule import, consider importing "
|
||||
"onnxruntime.training.ortmodule earlier before any model code loaded.\n"
|
||||
"2) To allow gradient checkpointing export, set `ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT=1`. "
|
||||
"Subsequent execution may fail.\n"
|
||||
"3) Replace ORTModule with HierarchalORTModule to wrap exportable sub-nn.Module's as "
|
||||
"ORTModule.\n"
|
||||
f"{LogColor.ENDC}"
|
||||
)
|
||||
else:
|
||||
return None # Let the common exporter handle the checkpointing function
|
||||
|
||||
|
||||
@register_high_priority_handler("bitsandbytes.autograd._functions.MatMul4Bit")
|
||||
def _matmul4bit_export(g, n, *args, **kwargs):
|
||||
cconv = n.cconv()
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ from onnxruntime.capi import _pybind_state as C
|
|||
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference
|
||||
from onnxruntime.training.utils import ORTModelInputOutputSchemaType, PTable, onnx_dtype_to_pytorch_dtype
|
||||
|
||||
from . import _are_deterministic_algorithms_enabled, _io, _logger, _onnx_models, _utils
|
||||
from . import _are_deterministic_algorithms_enabled, _io, _logger, _onnx_models, _utils, export_context
|
||||
from ._fallback import (
|
||||
ORTModuleDeviceException,
|
||||
ORTModuleONNXModelException,
|
||||
|
|
@ -33,6 +33,7 @@ from ._fallback import (
|
|||
from ._gradient_accumulation_manager import GradientAccumulationManager
|
||||
from ._graph_execution_interface import GraphExecutionInterface
|
||||
from ._io import _FlattenedModule, _InputInfo
|
||||
from ._logger import LogColor
|
||||
from ._runtime_inspector import RuntimeInspector
|
||||
from ._utils import check_function_has_param, get_rank
|
||||
from .options import DebugOptions, LogLevel, _MemoryOptimizationLevel, _RuntimeOptions
|
||||
|
|
@ -308,7 +309,7 @@ class GraphExecutionManager(GraphExecutionInterface):
|
|||
|
||||
from onnxruntime.training.utils.hooks._subscriber_manager import no_increase_global_step
|
||||
|
||||
with no_increase_global_step():
|
||||
with export_context(), no_increase_global_step():
|
||||
self._onnx_models.exported_model = self._get_exported_model(schema, *inputs, **kwargs)
|
||||
if self._debug_options.save_onnx_models.save:
|
||||
self._onnx_models.save_exported_model(
|
||||
|
|
@ -441,12 +442,49 @@ class GraphExecutionManager(GraphExecutionInterface):
|
|||
**self._export_extra_kwargs,
|
||||
)
|
||||
except Exception as e:
|
||||
message = _utils.get_exception_as_string(e)
|
||||
|
||||
# Special handling when Huggingface transformers gradient checkpoint usage pattern found.
|
||||
# For new versions of PyTorch 2, tracing torch.utils.checkpoint.checkpoint will be failed like this:
|
||||
# File "microsoft/phi-2/b10c3eba545ad279e7208ee3a5d644566f001670/modeling_phi.py", line 919, in forward
|
||||
# layer_outputs = self._gradient_checkpointing_func(
|
||||
# File "/site-packages/torch/_compile.py", line 24, in inner
|
||||
# return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
|
||||
# File "/site-packages/torch/_dynamo/eval_frame.py", line 470, in _fn
|
||||
# raise RuntimeError(
|
||||
# RuntimeError: Detected that you are using FX to torch.jit.trace a dynamo-optimized function. This is not supported at the moment.
|
||||
if (
|
||||
"_gradient_checkpointing_func" in message
|
||||
and "Detected that you are using FX to torch.jit.trace a dynamo-optimized function" in message
|
||||
):
|
||||
is_ckpt_activation_allowed = int(os.getenv("ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT", "0")) == 1
|
||||
notes = (
|
||||
" Your model is running with gradient checkpointing, yet the PyTorch exporter\n"
|
||||
" failed during tracing the graph. Try to enable ORTModule's\n"
|
||||
" gradient checkpointing (a.k.a. Transformer layerwise subgraph recompute)\n"
|
||||
" using `export ORTMODULE_MEMORY_OPT_LEVEL=1` for similar or even better memory efficiency.\n"
|
||||
)
|
||||
if is_ckpt_activation_allowed:
|
||||
# If the user allows the gradient checkpointing export, we should inform the user to disable it,
|
||||
# to make layerwise recompute work.
|
||||
notes += (
|
||||
" We also notice your setting `export ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT=1`,\n"
|
||||
" which enables gradient checkpointing torch.autograd.Functions(s) to export.\n"
|
||||
" To enable ORTModule's layerwise recompute, it needs to be turned OFF by\n"
|
||||
" `export ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT=0`.\n"
|
||||
)
|
||||
|
||||
self._logger.error(
|
||||
f"{LogColor.RED}\n"
|
||||
"******************************** IMPORTANT NOTE *******************************\n"
|
||||
f"{notes}"
|
||||
"*******************************************************************************\n"
|
||||
f"{LogColor.ENDC}\n"
|
||||
)
|
||||
|
||||
raise wrap_exception( # noqa: B904
|
||||
ORTModuleONNXModelException,
|
||||
RuntimeError(
|
||||
f"There was an error while exporting the PyTorch model to ONNX: "
|
||||
f"\n\n{_utils.get_exception_as_string(e)}"
|
||||
),
|
||||
RuntimeError(f"There was an error while exporting the PyTorch model to ONNX: \n\n{message}"),
|
||||
)
|
||||
exported_model = onnx.load_model_from_string(f.getvalue())
|
||||
|
||||
|
|
@ -773,17 +811,21 @@ class GraphExecutionManager(GraphExecutionInterface):
|
|||
else:
|
||||
opt_config_to_display = self._runtime_options.memory_optimizer_config
|
||||
|
||||
mem_infos = ""
|
||||
if self._runtime_options.memory_optimizer_is_enabled():
|
||||
mem_infos += (
|
||||
f"Memory Optimization Level: [{_MemoryOptimizationLevel.to_string(self._runtime_options.memory_optimization_level)}], "
|
||||
f"Optimization Config: [{opt_config_to_display}]"
|
||||
)
|
||||
else:
|
||||
mem_infos = "Enable with env ORTMODULE_MEMORY_OPT_LEVEL=1/2 or ORTMODULE_MEMORY_OPT_CONFIG=<plan1 config>,<plan2 config>,..."
|
||||
|
||||
mem_row = _add_record(
|
||||
tbl,
|
||||
[
|
||||
"Memory Optimizer",
|
||||
len(self._runtime_options.memory_optimizer_config) > 0,
|
||||
(
|
||||
f"Memory Optimization Level: [{_MemoryOptimizationLevel.to_string(self._runtime_options.memory_optimization_level)}], "
|
||||
f"Optimization Config: [{opt_config_to_display}]"
|
||||
if len(self._runtime_options.memory_optimizer_config) > 0
|
||||
else "Enable with env ORTMODULE_MEMORY_OPT_LEVEL=1/2 or ORTMODULE_MEMORY_OPT_CONFIG=<plan1 config>,<plan2 config>,..."
|
||||
),
|
||||
self._runtime_options.memory_optimizer_is_enabled(),
|
||||
mem_infos,
|
||||
],
|
||||
)
|
||||
|
||||
|
|
@ -794,7 +836,7 @@ class GraphExecutionManager(GraphExecutionInterface):
|
|||
)
|
||||
if mem_tbl is not None:
|
||||
mem_row.append_annotation_table(mem_tbl)
|
||||
notes.extend(mem_notes)
|
||||
notes.extend([f"[{mem_row._columns[0]}] {n}" for n in mem_notes])
|
||||
|
||||
compute_opt_row = _add_record(
|
||||
tbl,
|
||||
|
|
@ -819,13 +861,21 @@ class GraphExecutionManager(GraphExecutionInterface):
|
|||
if len(self._runtime_options.label_sparsity_ratio) > 0:
|
||||
_add_record(
|
||||
compute_opt_annotation_tbl,
|
||||
[" - Label Sparsity Opt", True, f"Input density: {self._runtime_options.label_sparsity_ratio}"],
|
||||
[
|
||||
" - Label Sparsity",
|
||||
True,
|
||||
f"[AUTO ENABLED] Input density: {self._runtime_options.label_sparsity_ratio}",
|
||||
],
|
||||
)
|
||||
|
||||
if len(self._runtime_options.embed_sparsity_ratio) > 0:
|
||||
_add_record(
|
||||
compute_opt_annotation_tbl,
|
||||
[" - Embed Sparsity Opt", True, f"Input density: {self._runtime_options.embed_sparsity_ratio}"],
|
||||
[
|
||||
" - Embed Sparsity",
|
||||
True,
|
||||
f"[AUTO ENABLED] Input density: {self._runtime_options.embed_sparsity_ratio}",
|
||||
],
|
||||
)
|
||||
|
||||
compute_opt_row.append_annotation_table(compute_opt_annotation_tbl)
|
||||
|
|
|
|||
|
|
@ -730,16 +730,16 @@ class MemoryObserver:
|
|||
notes = []
|
||||
if details:
|
||||
notes.append(
|
||||
"[Memory Optimizer] Use ORTMODULE_MEMORY_OPT_LEVEL=1/2 to enable all recomputable subgraphs per transformer layer."
|
||||
"Use ORTMODULE_MEMORY_OPT_LEVEL=1 or 2 to enable all recomputable subgraphs per transformer layer."
|
||||
)
|
||||
saving_recommendation = (
|
||||
"Or use comma as a delimiter to selectively enable multiple memory optimization plans:\n"
|
||||
)
|
||||
saving_recommendation = "[Memory Optimizer] Or use comma as a delimiter to selectively enable multiple memory optimization plans:\n"
|
||||
saving_recommendation += " export ORTMODULE_MEMORY_OPT_CONFIG=<plan1 config>,<plan2 config>,..."
|
||||
|
||||
notes.append(saving_recommendation)
|
||||
|
||||
saving_recommendation = (
|
||||
"[Memory Optimizer] Memory saving is calculated based on the 1st batch symbolic dim values:\n"
|
||||
)
|
||||
saving_recommendation = "Memory saving is calculated based on the 1st batch symbolic dim values:\n"
|
||||
for dim_param, dim_value in self.symbolic_dim_name_to_value_map.items():
|
||||
saving_recommendation += f" {dim_param}={dim_value},"
|
||||
notes.append(saving_recommendation)
|
||||
|
|
|
|||
|
|
@ -6546,6 +6546,9 @@ def test_bert_memory_inspection(caplog):
|
|||
torch.cuda.synchronize()
|
||||
if original_val is not None:
|
||||
os.environ["ORTMODULE_PRINT_MEMORY_STATS"] = original_val
|
||||
else:
|
||||
if "ORTMODULE_PRINT_MEMORY_STATS" in os.environ:
|
||||
del os.environ["ORTMODULE_PRINT_MEMORY_STATS"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("softmax_compute_type", [torch.float16, torch.float32])
|
||||
|
|
@ -6599,3 +6602,97 @@ def test_overridden_softmax_export(softmax_compute_type):
|
|||
assert to_attr.name == "to"
|
||||
to_value = to_attr.i
|
||||
assert to_value == pytorch_type_to_onnx_dtype(softmax_compute_type), "Cast to attribute is not as expected"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("memory_optimization_level", [None, 0, 1, 2])
|
||||
@pytest.mark.parametrize("allow_gradient_checkpoint_export", [None, 0, 1])
|
||||
@pytest.mark.parametrize("fx", ["torch", "deepspeed"])
|
||||
def test_enable_layerwise_recompute(memory_optimization_level, allow_gradient_checkpoint_export, fx, caplog):
|
||||
"""Expected behaviors:
|
||||
memory_optimization_level=0|None, allow_gradient_checkpoint_export=0|None => layerwise recompute is disabled
|
||||
memory_optimization_level=1, allow_gradient_checkpoint_export=0|None => layerwise recompute is enabled
|
||||
memory_optimization_level=2, allow_gradient_checkpoint_export=0|None => layerwise recompute is disabled
|
||||
memory_optimization_level=0|None, allow_gradient_checkpoint_export=1 => layerwise recompute is disabled
|
||||
memory_optimization_level=1, allow_gradient_checkpoint_export=1 => layerwise recompute is disabled
|
||||
memory_optimization_level=2, allow_gradient_checkpoint_export=1 => layerwise recompute is disabled
|
||||
"""
|
||||
if fx == "deepspeed":
|
||||
try:
|
||||
import deepspeed
|
||||
|
||||
checkpoint = deepspeed.checkpointing.checkpoint
|
||||
except ImportError:
|
||||
# skip if deepspeed is not installed (in amd CI)
|
||||
return
|
||||
|
||||
elif fx == "torch":
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
else:
|
||||
raise ValueError(f"unsupported fx value: {fx}. only torch and deepspeed are supported.")
|
||||
|
||||
original_val = os.environ.get("ORTMODULE_MEMORY_OPT_LEVEL", None)
|
||||
original_val_allow_gradient_checkpoint_export = os.environ.get("ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT", None)
|
||||
|
||||
if memory_optimization_level is not None:
|
||||
os.environ["ORTMODULE_MEMORY_OPT_LEVEL"] = str(memory_optimization_level)
|
||||
else:
|
||||
if original_val is not None:
|
||||
del os.environ["ORTMODULE_MEMORY_OPT_LEVEL"]
|
||||
|
||||
if allow_gradient_checkpoint_export:
|
||||
os.environ["ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT"] = str(allow_gradient_checkpoint_export)
|
||||
else:
|
||||
if original_val_allow_gradient_checkpoint_export is not None:
|
||||
del os.environ["ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT"]
|
||||
|
||||
class SampleModel(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.layer1 = nn.Linear(10, 10)
|
||||
self.layer2 = nn.Linear(10, 10)
|
||||
|
||||
def forward(self, x):
|
||||
# Checkpointing the first layer
|
||||
x = checkpoint(self.layer1, x)
|
||||
x = nn.ReLU()(x)
|
||||
# The second layer is not checkpointed
|
||||
x = self.layer2(x)
|
||||
return x
|
||||
|
||||
model = SampleModel().cuda()
|
||||
input = torch.randn(1, 10).cuda()
|
||||
model = ORTModule(model, DebugOptions(log_level=LogLevel.INFO))
|
||||
|
||||
# Forward pass
|
||||
|
||||
# Tolerant export failure.
|
||||
import contextlib
|
||||
|
||||
with contextlib.suppress(Exception):
|
||||
_ = model(input)
|
||||
|
||||
layerwise_recompute_info_records = [
|
||||
record.message for record in caplog.records if "Layer-wise memory optimization is enabled" in record.message
|
||||
]
|
||||
|
||||
if memory_optimization_level != 1:
|
||||
assert len(layerwise_recompute_info_records) == 0
|
||||
else:
|
||||
if allow_gradient_checkpoint_export is None or allow_gradient_checkpoint_export == 0:
|
||||
assert len(layerwise_recompute_info_records) > 0
|
||||
else:
|
||||
assert len(layerwise_recompute_info_records) == 0
|
||||
|
||||
# Make sure environment variable is restored to its original value after the run is completed.
|
||||
torch.cuda.synchronize()
|
||||
if original_val is not None:
|
||||
os.environ["ORTMODULE_MEMORY_OPT_LEVEL"] = original_val
|
||||
else:
|
||||
if "ORTMODULE_MEMORY_OPT_LEVEL" in os.environ:
|
||||
del os.environ["ORTMODULE_MEMORY_OPT_LEVEL"]
|
||||
|
||||
if original_val_allow_gradient_checkpoint_export is not None:
|
||||
os.environ["ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT"] = original_val_allow_gradient_checkpoint_export
|
||||
else:
|
||||
if "ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT" in os.environ:
|
||||
del os.environ["ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT"]
|
||||
|
|
|
|||
Loading…
Reference in a new issue