diff --git a/docs/Memory_Optimizer.md b/docs/Memory_Optimizer.md
index eaa48c9da0..d08ba7b8f8 100644
--- a/docs/Memory_Optimizer.md
+++ b/docs/Memory_Optimizer.md
@@ -30,10 +30,10 @@ Integrate models using `ORTModule`.
 ```
 
 There are two modes to enable the memory optimizations:
-- Aggressively Recompute All Within Each Transformer Layer, enabled by `export ORTMODULE_MEMORY_OPT_LEVEL=1`. This will recompute all detected subgraphs within each Transformer Attention+MLP layer. It is easy to enable, but be noted this recompute plan may NOT be the best one. In this mode, `ORTMODULE_MEMORY_OPT_CONFIG` env values passed by users are not respected.
-- User Specified Subgraph Recompute, enabled by `export ORTMODULE_MEMORY_OPT_LEVEL=0` and `export ORTMODULE_MEMORY_OPT_CONFIG=<plan1 config>,<plan2 config>,...`. This is an advanced usage, that allows users to find the most suitable graphs to recompute, at the cost of overhead to look for the best plans.
+- Transformer layerwise recompute, i.e. aggressively recompute all supported nodes within each transformer layer (usually including the attention and MLP sublayers), enabled by `export ORTMODULE_MEMORY_OPT_LEVEL=1`. In this mode, `ORTMODULE_MEMORY_OPT_CONFIG` env values passed by users are not respected.
+- Manually selected subgraph recompute, enabled by `export ORTMODULE_MEMORY_OPT_LEVEL=0` and `export ORTMODULE_MEMORY_OPT_CONFIG=<plan1 config>,<plan2 config>,...`. This is an advanced usage that allows users to find the most suitable graphs to recompute, at the cost of the overhead of searching for the best plans.
 
-### Mode 1 - Simple Usage (Aggressively Recompute All Within Each Transformer Layer)
+### Mode 1 - Simple Usage (Transformer Layerwise Recompute)
 
 1. Set memory optimization level to be TRANSFORMER_LAYERWISE_RECOMPUTE, by `export ORTMODULE_MEMORY_OPT_LEVEL=1`
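For quick reference, here is a minimal sketch of how the two modes are selected purely through environment variables before the model is wrapped (illustrative only; it assumes a CUDA device, an installed `onnxruntime-training` build, and a toy model in place of a real transformer):

```python
import os

# Mode 1: transformer layerwise recompute. Must be set before the first forward pass,
# because the recompute plan is decided when ORTModule exports and optimizes the graph.
os.environ["ORTMODULE_MEMORY_OPT_LEVEL"] = "1"
# Mode 2 (alternative): keep the level at 0 and list specific plans instead, e.g.
# os.environ["ORTMODULE_MEMORY_OPT_LEVEL"] = "0"
# os.environ["ORTMODULE_MEMORY_OPT_CONFIG"] = "<plan1 config>,<plan2 config>"

import torch
from onnxruntime.training.ortmodule import ORTModule

# Toy stand-in for a transformer block; any nn.Module is wrapped the same way.
model = torch.nn.Sequential(
    torch.nn.Linear(16, 16),
    torch.nn.ReLU(),
    torch.nn.Linear(16, 1),
).cuda()
model = ORTModule(model)

x = torch.randn(4, 16, device="cuda")
model(x).sum().backward()  # the recompute plan is applied to the training graph built here
```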
diff --git a/orttraining/orttraining/python/training/ortmodule/__init__.py b/orttraining/orttraining/python/training/ortmodule/__init__.py
index 4a03465cf2..20e3493395 100644
--- a/orttraining/orttraining/python/training/ortmodule/__init__.py
+++ b/orttraining/orttraining/python/training/ortmodule/__init__.py
@@ -3,6 +3,8 @@
 # Licensed under the MIT License.
 # --------------------------------------------------------------------------
 
+import contextlib
+import inspect
 import os
 import sys
 import warnings
@@ -21,7 +23,10 @@ if not is_ortmodule_available():
     raise ImportError("ORTModule is not supported on this platform.")
 
 
-def _defined_from_envvar(name, default_value, warn=True):
+def _defined_from_envvar(name: str, default_value: any, warn: bool = True):
+    """Check whether the given name exists as an environment variable and, if so, return its value
+    converted to the type of default_value.
+    """
     new_value = os.getenv(name, None)
     if new_value is None:
         return default_value
@@ -34,6 +39,86 @@
     return new_value
 
 
+def _override_gradient_checkpoint(original_checkpoint):
+    """
+    Best effort to override `torch.utils.checkpoint.checkpoint` and `deepspeed.checkpointing.checkpoint` during ONNX export.
+
+    Despite importing `torch.utils.checkpoint` or `deepspeed.checkpointing.checkpoint` in `__init__.py`,
+    users might import it first, causing our override to not take effect. We still attempt the override
+    so that it works in most cases.
+
+    We replace the checkpoint function with our implementation, without condition checks.
+    The actual check is in the overridden function, verifying whether:
+    1) `checkpoint` is called during ORTModule model export,
+    2) Gradient checkpoint autograd function export is disallowed (ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT),
+    3) Memory optimization level is set to layer-wise recompute (ORTMODULE_MEMORY_OPT_LEVEL=1).
+    If all conditions hold, the checkpointing call is bypassed so that layer-wise recompute takes effect.
+
+    """
+
+    # Note: The `torch.utils.checkpoint` checkpoint function signature looks like below:
+    #     `checkpoint(function, *args,
+    #                 use_reentrant = None,
+    #                 context_fn = noop_context_fn,
+    #                 determinism_check = _DEFAULT_DETERMINISM_MODE,
+    #                 debug = False,
+    #                 **kwargs).`
+    # The few keyword arguments are not used in the recompute module forward function, but by the
+    # checkpoint function itself, so we need to filter them out; otherwise the module forward function
+    # would complain about unexpected keyword arguments.
+    all_input_parameters = inspect.signature(original_checkpoint).parameters.values()
+    outside_kwarg_params = []
+    for input_parameter in all_input_parameters:
+        if (
+            input_parameter.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
+            or input_parameter.kind == inspect.Parameter.KEYWORD_ONLY
+            or input_parameter.kind == inspect.Parameter.VAR_KEYWORD
+        ):
+            outside_kwarg_params.append(input_parameter.name)
+
+    def _checkpoint(
+        function,
+        *args,
+        **kwargs,
+    ):
+        # Conditions to activate layer-wise memory optimization automatically:
+        # 1. Checkpoint is called during ORTModule model export context.
+        # 2. Gradient checkpoint autograd function export is disallowed.
+        # 3. Memory optimization level is layer-wise recompute.
+        if (
+            ORTMODULE_ONNX_EXPORT_CONTEXT[0] is True
+            and _defined_from_envvar("ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT", 0) != 1
+            and _defined_from_envvar("ORTMODULE_MEMORY_OPT_LEVEL", 0) == 1
+        ):
+            for name in outside_kwarg_params:
+                if name in kwargs:
+                    # Pop out the keyword argument to avoid passing it to the module run function
+                    kwargs.pop(name)
+            print(
+                "Layer-wise memory optimization is enabled upon detecting "
+                "gradient checkpointing autograd function usage during model execution."
+            )
+            return function(*args, **kwargs)
+        return original_checkpoint(
+            function,
+            *args,
+            **kwargs,
+        )
+
+    return _checkpoint
+
+
+with contextlib.suppress(Exception):
+    from torch.utils.checkpoint import checkpoint as original_torch_checkpoint
+
+    torch.utils.checkpoint.checkpoint = _override_gradient_checkpoint(original_torch_checkpoint)
+
+    import deepspeed
+
+    original_deepspeed_checkpoint = deepspeed.checkpointing.checkpoint
+    deepspeed.checkpointing.checkpoint = _override_gradient_checkpoint(original_deepspeed_checkpoint)
+
+
 ################################################################################
 # All global constant goes here, before ORTModule is imported ##################
 # NOTE: To *change* values in runtime, import onnxruntime.training.ortmodule and
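The effect of the override can be seen in isolation with a small, self-contained sketch. This is not the ORTModule code above, just a simplified stand-in that uses a module-level flag in place of `ORTMODULE_ONNX_EXPORT_CONTEXT` and skips the environment-variable checks:

```python
import torch
from torch.utils.checkpoint import checkpoint as original_checkpoint

EXPORT_CONTEXT = [False]  # stand-in for the one-element flag list used above


def wrap_checkpoint(original):
    def _checkpoint(function, *args, **kwargs):
        if EXPORT_CONTEXT[0]:
            # During export, skip torch's checkpointing machinery and call the module directly;
            # ORT's layerwise recompute is expected to provide the memory savings instead.
            kwargs.pop("use_reentrant", None)
            return function(*args, **kwargs)
        return original(function, *args, **kwargs)

    return _checkpoint


ckpt = wrap_checkpoint(original_checkpoint)
layer = torch.nn.Linear(8, 8)
x = torch.randn(2, 8, requires_grad=True)

out = ckpt(layer, x, use_reentrant=False)  # flag off: real torch checkpointing runs
EXPORT_CONTEXT[0] = True
out = ckpt(layer, x, use_reentrant=False)  # flag on: plain call, keyword filtered out
```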
@@ -55,7 +140,24 @@ ORTMODULE_IS_DETERMINISTIC = torch.are_deterministic_algorithms_enabled()
 ONNXRUNTIME_CUDA_VERSION = ort_info.cuda_version if hasattr(ort_info, "cuda_version") else None
 ONNXRUNTIME_ROCM_VERSION = ort_info.rocm_version if hasattr(ort_info, "rocm_version") else None
 
-# Verify minimum PyTorch version is installed before proceding to ONNX Runtime initialization
+# The first value indicates whether the code is in ONNX export context.
+# The export context here covers the full export process, including preparing the export
+# input/output information and exporting the model.
+ORTMODULE_ONNX_EXPORT_CONTEXT = [False]
+
+
+@contextlib.contextmanager
+def export_context():
+    """Context manager for model export."""
+    try:
+        ORTMODULE_ONNX_EXPORT_CONTEXT[0] = True
+
+        yield
+    finally:
+        ORTMODULE_ONNX_EXPORT_CONTEXT[0] = False
+
+
+# Verify minimum PyTorch version is installed before proceeding to ONNX Runtime initialization
 try:
     import torch
 
@@ -70,6 +172,8 @@
                 f" but version {torch.__version__} was found instead."
             ),
         )
+
+
 except ORTModuleFallbackException as e:
     # Initialization fallback is handled at ORTModule.__init__
     _FALLBACK_INIT_EXCEPTION = e
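Two small details of this block are worth calling out: the flag is a one-element list rather than a bare bool so that other modules which import it observe in-place mutations, and the `try`/`finally` guarantees the flag is cleared even when export fails. A standalone sketch of both behaviors (illustrative, not the ORTModule code itself):

```python
import contextlib

EXPORT_FLAG = [False]  # mutated in place, so every importer of this list sees updates


@contextlib.contextmanager
def export_context():
    try:
        EXPORT_FLAG[0] = True
        yield
    finally:
        EXPORT_FLAG[0] = False  # runs even if the body raises


try:
    with export_context():
        assert EXPORT_FLAG[0] is True
        raise RuntimeError("simulated export failure")
except RuntimeError:
    pass

assert EXPORT_FLAG[0] is False  # flag was reset despite the failure
```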
" - "Please replace ORTModule with HierarchalORTModule to only" - "wrap exportable sub-nn.Module's as ORTModule." - ) - cconv = n.cconv() input_tensor_types = [] @@ -394,6 +357,70 @@ def post_process_enabling_autograd_function(exported_model: ModelProto) -> Model return exported_model +@register_high_priority_handler("torch.utils.checkpoint.CheckpointFunction") +@register_high_priority_handler("deepspeed.checkpointing.CheckpointFunction") +def _gradient_checkpointing_export(g, n, *args, **kwargs): + """ + Register specialized exporter for torch.autograd.Function(s) used for checkpoint activation purposes. + + Note: + If CheckpointFunction is exported as PythonOp, the checkpoint-ed computation + (applied on every N transformer layer) may be computed by PyTorch, not ORT. + This situation should be especially noted for large language models such as GPT-2. + + As alternatives to using checkpoint activation: + 1. Users could leverage HierarchalORTModule to wrap the model, which only wrap exportable + sub-nn.Module's as ORTModule. In this way, ideally, ORT could cover most of the model computation, + other than dispatching them to PyTorch. + 2. Users could disable the check by setting the environment variable ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT=1. + This may imply that the exported model is not fully running on ORT, users should be aware of the potential + performance impact. + 3. Users could also leverage ORT's memory optimization feature to achieve a similar effect as checkpointing + activations. Turn off PyTorch's checkpointing activation, then refer to env var ORTMODULE_MEMORY_OPT_LEVEL + to enable ORT's recomputation feature. + + """ + # Check if the checkpointing activation is allowed. + is_ckpt_activation_allowed = ortmodule._defined_from_envvar("ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT", 0) == 1 + if is_ckpt_activation_allowed is False: + is_layerwise_recompute_enabled = ortmodule._defined_from_envvar("ORTMODULE_MEMORY_OPT_LEVEL", 0) == 1 + if not is_layerwise_recompute_enabled: + raise Exception( + f"{LogColor.RED}" + "Model uses gradient checkpointing (via {func_full_qual_name}), " + "which is not supported for export. \n" + "Consider these alternatives:\n" + "1) Enable ORTModule's gradient checkpointing for similar or better " + "memory efficiency with `export ORTMODULE_MEMORY_OPT_LEVEL=1`.\n" + "2) Allow gradient checkpointing export by setting the environment " + "variable `ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT=1`, though subsequent " + "execution may fail." + "3) Replace ORTModule with HierarchalORTModule to wrap exportable " + "sub-nn.Module's as ORTModule.\n" + f"{LogColor.ENDC}" + ) + + # Hitting this branch means the user has enabled layerwise recompute, but _override_gradient_checkpoint didn't + # catch the checkpointing function. This is usually because model code is importing torch.utils.checkpoint + # earlier than ORTModule. We should tolerantly allow this case to happen. + raise Exception( + f"{LogColor.RED}" + "Model uses gradient checkpointing (via {func_full_qual_name}), which is not " + "supported for export. \n" + "Consider these alternatives:\n" + "1) `ORTMODULE_MEMORY_OPT_LEVEL=1` is set but checkpoint functions in the model " + "are not overridden during onnxruntime.training.ortmodule import, consider importing " + "onnxruntime.training.ortmodule earlier before any model code loaded.\n" + "2) To allow gradient checkpointing export, set `ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT=1`. 
" + "Subsequent execution may fail.\n" + "3) Replace ORTModule with HierarchalORTModule to wrap exportable sub-nn.Module's as " + "ORTModule.\n" + f"{LogColor.ENDC}" + ) + else: + return None # Let the common exporter handle the checkpointing function + + @register_high_priority_handler("bitsandbytes.autograd._functions.MatMul4Bit") def _matmul4bit_export(g, n, *args, **kwargs): cconv = n.cconv() diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index 5123594bff..b272c629b1 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -21,7 +21,7 @@ from onnxruntime.capi import _pybind_state as C from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference from onnxruntime.training.utils import ORTModelInputOutputSchemaType, PTable, onnx_dtype_to_pytorch_dtype -from . import _are_deterministic_algorithms_enabled, _io, _logger, _onnx_models, _utils +from . import _are_deterministic_algorithms_enabled, _io, _logger, _onnx_models, _utils, export_context from ._fallback import ( ORTModuleDeviceException, ORTModuleONNXModelException, @@ -33,6 +33,7 @@ from ._fallback import ( from ._gradient_accumulation_manager import GradientAccumulationManager from ._graph_execution_interface import GraphExecutionInterface from ._io import _FlattenedModule, _InputInfo +from ._logger import LogColor from ._runtime_inspector import RuntimeInspector from ._utils import check_function_has_param, get_rank from .options import DebugOptions, LogLevel, _MemoryOptimizationLevel, _RuntimeOptions @@ -308,7 +309,7 @@ class GraphExecutionManager(GraphExecutionInterface): from onnxruntime.training.utils.hooks._subscriber_manager import no_increase_global_step - with no_increase_global_step(): + with export_context(), no_increase_global_step(): self._onnx_models.exported_model = self._get_exported_model(schema, *inputs, **kwargs) if self._debug_options.save_onnx_models.save: self._onnx_models.save_exported_model( @@ -441,12 +442,49 @@ class GraphExecutionManager(GraphExecutionInterface): **self._export_extra_kwargs, ) except Exception as e: + message = _utils.get_exception_as_string(e) + + # Special handling when Huggingface transformers gradient checkpoint usage pattern found. + # For new versions of PyTorch 2, tracing torch.utils.checkpoint.checkpoint will be failed like this: + # File "microsoft/phi-2/b10c3eba545ad279e7208ee3a5d644566f001670/modeling_phi.py", line 919, in forward + # layer_outputs = self._gradient_checkpointing_func( + # File "/site-packages/torch/_compile.py", line 24, in inner + # return torch._dynamo.disable(fn, recursive)(*args, **kwargs) + # File "/site-packages/torch/_dynamo/eval_frame.py", line 470, in _fn + # raise RuntimeError( + # RuntimeError: Detected that you are using FX to torch.jit.trace a dynamo-optimized function. This is not supported at the moment. + if ( + "_gradient_checkpointing_func" in message + and "Detected that you are using FX to torch.jit.trace a dynamo-optimized function" in message + ): + is_ckpt_activation_allowed = int(os.getenv("ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT", "0")) == 1 + notes = ( + " Your model is running with gradient checkpointing, yet the PyTorch exporter\n" + " failed during tracing the graph. Try to enable ORTModule's\n" + " gradient checkpointing (a.k.a. 
diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py
index 5123594bff..b272c629b1 100755
--- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py
+++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py
@@ -21,7 +21,7 @@ from onnxruntime.capi import _pybind_state as C
 from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference
 from onnxruntime.training.utils import ORTModelInputOutputSchemaType, PTable, onnx_dtype_to_pytorch_dtype
 
-from . import _are_deterministic_algorithms_enabled, _io, _logger, _onnx_models, _utils
+from . import _are_deterministic_algorithms_enabled, _io, _logger, _onnx_models, _utils, export_context
 from ._fallback import (
     ORTModuleDeviceException,
     ORTModuleONNXModelException,
@@ -33,6 +33,7 @@ from ._fallback import (
 from ._gradient_accumulation_manager import GradientAccumulationManager
 from ._graph_execution_interface import GraphExecutionInterface
 from ._io import _FlattenedModule, _InputInfo
+from ._logger import LogColor
 from ._runtime_inspector import RuntimeInspector
 from ._utils import check_function_has_param, get_rank
 from .options import DebugOptions, LogLevel, _MemoryOptimizationLevel, _RuntimeOptions
@@ -308,7 +309,7 @@ class GraphExecutionManager(GraphExecutionInterface):
 
         from onnxruntime.training.utils.hooks._subscriber_manager import no_increase_global_step
 
-        with no_increase_global_step():
+        with export_context(), no_increase_global_step():
             self._onnx_models.exported_model = self._get_exported_model(schema, *inputs, **kwargs)
             if self._debug_options.save_onnx_models.save:
                 self._onnx_models.save_exported_model(
@@ -441,12 +442,49 @@ class GraphExecutionManager(GraphExecutionInterface):
                 **self._export_extra_kwargs,
             )
         except Exception as e:
+            message = _utils.get_exception_as_string(e)
+
+            # Special handling when the Huggingface transformers gradient checkpointing usage pattern is found.
+            # For new versions of PyTorch 2, tracing torch.utils.checkpoint.checkpoint will fail like this:
+            #   File "microsoft/phi-2/b10c3eba545ad279e7208ee3a5d644566f001670/modeling_phi.py", line 919, in forward
+            #     layer_outputs = self._gradient_checkpointing_func(
+            #   File "/site-packages/torch/_compile.py", line 24, in inner
+            #     return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
+            #   File "/site-packages/torch/_dynamo/eval_frame.py", line 470, in _fn
+            #     raise RuntimeError(
+            #   RuntimeError: Detected that you are using FX to torch.jit.trace a dynamo-optimized function. This is not supported at the moment.
+            if (
+                "_gradient_checkpointing_func" in message
+                and "Detected that you are using FX to torch.jit.trace a dynamo-optimized function" in message
+            ):
+                is_ckpt_activation_allowed = int(os.getenv("ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT", "0")) == 1
+                notes = (
+                    " Your model is running with gradient checkpointing, yet the PyTorch exporter\n"
+                    " failed while tracing the graph. Try to enable ORTModule's\n"
+                    " gradient checkpointing (a.k.a. Transformer layerwise subgraph recompute)\n"
+                    " using `export ORTMODULE_MEMORY_OPT_LEVEL=1` for similar or even better memory efficiency.\n"
+                )
+                if is_ckpt_activation_allowed:
+                    # If the user allows the gradient checkpointing export, we should inform the user to disable it,
+                    # to make layerwise recompute work.
+                    notes += (
+                        " We also noticed you set `export ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT=1`,\n"
+                        " which allows gradient checkpointing torch.autograd.Function(s) to be exported.\n"
+                        " To enable ORTModule's layerwise recompute, it needs to be turned OFF by\n"
+                        " `export ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT=0`.\n"
+                    )
+
+                self._logger.error(
+                    f"{LogColor.RED}\n"
+                    "******************************** IMPORTANT NOTE *******************************\n"
+                    f"{notes}"
+                    "*******************************************************************************\n"
+                    f"{LogColor.ENDC}\n"
+                )
+
             raise wrap_exception(  # noqa: B904
                 ORTModuleONNXModelException,
-                RuntimeError(
-                    f"There was an error while exporting the PyTorch model to ONNX: "
-                    f"\n\n{_utils.get_exception_as_string(e)}"
-                ),
+                RuntimeError(f"There was an error while exporting the PyTorch model to ONNX: \n\n{message}"),
             )
         exported_model = onnx.load_model_from_string(f.getvalue())
 
@@ -773,17 +811,21 @@
         else:
             opt_config_to_display = self._runtime_options.memory_optimizer_config
 
+        mem_infos = ""
+        if self._runtime_options.memory_optimizer_is_enabled():
+            mem_infos += (
+                f"Memory Optimization Level: [{_MemoryOptimizationLevel.to_string(self._runtime_options.memory_optimization_level)}], "
+                f"Optimization Config: [{opt_config_to_display}]"
+            )
+        else:
+            mem_infos = "Enable with env ORTMODULE_MEMORY_OPT_LEVEL=1/2 or ORTMODULE_MEMORY_OPT_CONFIG=<plan1 config>,<plan2 config>,..."
+
         mem_row = _add_record(
             tbl,
             [
                 "Memory Optimizer",
-                len(self._runtime_options.memory_optimizer_config) > 0,
-                (
-                    f"Memory Optimization Level: [{_MemoryOptimizationLevel.to_string(self._runtime_options.memory_optimization_level)}], "
-                    f"Optimization Config: [{opt_config_to_display}]"
-                    if len(self._runtime_options.memory_optimizer_config) > 0
-                    else "Enable with env ORTMODULE_MEMORY_OPT_LEVEL=1/2 or ORTMODULE_MEMORY_OPT_CONFIG=<plan1 config>,<plan2 config>,..."
-                ),
+                self._runtime_options.memory_optimizer_is_enabled(),
+                mem_infos,
             ],
         )
 
@@ -794,7 +836,7 @@
         )
         if mem_tbl is not None:
             mem_row.append_annotation_table(mem_tbl)
-        notes.extend(mem_notes)
+        notes.extend([f"[{mem_row._columns[0]}] {n}" for n in mem_notes])
 
         compute_opt_row = _add_record(
             tbl,
@@ -819,13 +861,21 @@
         if len(self._runtime_options.label_sparsity_ratio) > 0:
             _add_record(
                 compute_opt_annotation_tbl,
-                [" - Label Sparsity Opt", True, f"Input density: {self._runtime_options.label_sparsity_ratio}"],
+                [
+                    " - Label Sparsity",
+                    True,
+                    f"[AUTO ENABLED] Input density: {self._runtime_options.label_sparsity_ratio}",
+                ],
             )
 
         if len(self._runtime_options.embed_sparsity_ratio) > 0:
             _add_record(
                 compute_opt_annotation_tbl,
-                [" - Embed Sparsity Opt", True, f"Input density: {self._runtime_options.embed_sparsity_ratio}"],
+                [
+                    " - Embed Sparsity",
+                    True,
+                    f"[AUTO ENABLED] Input density: {self._runtime_options.embed_sparsity_ratio}",
+                ],
            )
 
         compute_opt_row.append_annotation_table(compute_opt_annotation_tbl)
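One consequence of these table changes is that the "[Memory Optimizer]" prefix is no longer hard-coded in each note string (see the `_runtime_inspector.py` hunk below); it is instead derived from the feature-table row name when the notes are collected. A tiny illustration of the resulting formatting (the note values are made up for the example):

```python
# Mirrors `[f"[{mem_row._columns[0]}] {n}" for n in mem_notes]` above.
row_name = "Memory Optimizer"  # mem_row._columns[0] in the real table
mem_notes = [
    "Use ORTMODULE_MEMORY_OPT_LEVEL=1 or 2 to enable all recomputable subgraphs per transformer layer.",
    "Memory saving is calculated based on the 1st batch symbolic dim values:\n  batch=8,",
]
for note in (f"[{row_name}] {n}" for n in mem_notes):
    print(note)
```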
diff --git a/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py b/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py
index 5c86070430..bc36a176bf 100644
--- a/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py
+++ b/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py
@@ -730,16 +730,16 @@ class MemoryObserver:
         notes = []
         if details:
             notes.append(
-                "[Memory Optimizer] Use ORTMODULE_MEMORY_OPT_LEVEL=1/2 to enable all recomputable subgraphs per transformer layer."
+                "Use ORTMODULE_MEMORY_OPT_LEVEL=1 or 2 to enable all recomputable subgraphs per transformer layer."
+            )
+            saving_recommendation = (
+                "Or use comma as a delimiter to selectively enable multiple memory optimization plans:\n"
             )
-            saving_recommendation = "[Memory Optimizer] Or use comma as a delimiter to selectively enable multiple memory optimization plans:\n"
             saving_recommendation += "  export ORTMODULE_MEMORY_OPT_CONFIG=<plan1 config>,<plan2 config>,..."
 
             notes.append(saving_recommendation)
 
-            saving_recommendation = (
-                "[Memory Optimizer] Memory saving is calculated based on the 1st batch symbolic dim values:\n"
-            )
+            saving_recommendation = "Memory saving is calculated based on the 1st batch symbolic dim values:\n"
             for dim_param, dim_value in self.symbolic_dim_name_to_value_map.items():
                 saving_recommendation += f"  {dim_param}={dim_value},"
             notes.append(saving_recommendation)
diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
index 5078058995..e231579887 100644
--- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
+++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
@@ -6546,6 +6546,9 @@ def test_bert_memory_inspection(caplog):
     torch.cuda.synchronize()
     if original_val is not None:
         os.environ["ORTMODULE_PRINT_MEMORY_STATS"] = original_val
+    else:
+        if "ORTMODULE_PRINT_MEMORY_STATS" in os.environ:
+            del os.environ["ORTMODULE_PRINT_MEMORY_STATS"]
 
 
 @pytest.mark.parametrize("softmax_compute_type", [torch.float16, torch.float32])
@@ -6599,3 +6602,97 @@
     assert to_attr.name == "to"
     to_value = to_attr.i
     assert to_value == pytorch_type_to_onnx_dtype(softmax_compute_type), "Cast to attribute is not as expected"
+
+
+@pytest.mark.parametrize("memory_optimization_level", [None, 0, 1, 2])
+@pytest.mark.parametrize("allow_gradient_checkpoint_export", [None, 0, 1])
+@pytest.mark.parametrize("fx", ["torch", "deepspeed"])
+def test_enable_layerwise_recompute(memory_optimization_level, allow_gradient_checkpoint_export, fx, caplog):
+    """Expected behaviors:
+    memory_optimization_level=0|None, allow_gradient_checkpoint_export=0|None => layerwise recompute is disabled
+    memory_optimization_level=1, allow_gradient_checkpoint_export=0|None => layerwise recompute is enabled
+    memory_optimization_level=2, allow_gradient_checkpoint_export=0|None => layerwise recompute is disabled
+    memory_optimization_level=0|None, allow_gradient_checkpoint_export=1 => layerwise recompute is disabled
+    memory_optimization_level=1, allow_gradient_checkpoint_export=1 => layerwise recompute is disabled
+    memory_optimization_level=2, allow_gradient_checkpoint_export=1 => layerwise recompute is disabled
+    """
+    if fx == "deepspeed":
+        try:
+            import deepspeed
+
+            checkpoint = deepspeed.checkpointing.checkpoint
+        except ImportError:
+            # skip if deepspeed is not installed (in AMD CI)
+            return
+
+    elif fx == "torch":
+        from torch.utils.checkpoint import checkpoint
+    else:
+        raise ValueError(f"Unsupported fx value: {fx}. Only torch and deepspeed are supported.")
+
+    original_val = os.environ.get("ORTMODULE_MEMORY_OPT_LEVEL", None)
+    original_val_allow_gradient_checkpoint_export = os.environ.get("ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT", None)
+
+    if memory_optimization_level is not None:
+        os.environ["ORTMODULE_MEMORY_OPT_LEVEL"] = str(memory_optimization_level)
+    else:
+        if original_val is not None:
+            del os.environ["ORTMODULE_MEMORY_OPT_LEVEL"]
+
+    if allow_gradient_checkpoint_export:
+        os.environ["ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT"] = str(allow_gradient_checkpoint_export)
+    else:
+        if original_val_allow_gradient_checkpoint_export is not None:
+            del os.environ["ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT"]
+
+    class SampleModel(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.layer1 = nn.Linear(10, 10)
+            self.layer2 = nn.Linear(10, 10)
+
+        def forward(self, x):
+            # Checkpointing the first layer
+            x = checkpoint(self.layer1, x)
+            x = nn.ReLU()(x)
+            # The second layer is not checkpointed
+            x = self.layer2(x)
+            return x
+
+    model = SampleModel().cuda()
+    input = torch.randn(1, 10).cuda()
+    model = ORTModule(model, DebugOptions(log_level=LogLevel.INFO))
+
+    # Forward pass
+
+    # Tolerate export failure.
+    import contextlib
+
+    with contextlib.suppress(Exception):
+        _ = model(input)
+
+    layerwise_recompute_info_records = [
+        record.message for record in caplog.records if "Layer-wise memory optimization is enabled" in record.message
+    ]
+
+    if memory_optimization_level != 1:
+        assert len(layerwise_recompute_info_records) == 0
+    else:
+        if allow_gradient_checkpoint_export is None or allow_gradient_checkpoint_export == 0:
+            assert len(layerwise_recompute_info_records) > 0
+        else:
+            assert len(layerwise_recompute_info_records) == 0
+
+    # Make sure environment variable is restored to its original value after the run is completed.
+    torch.cuda.synchronize()
+    if original_val is not None:
+        os.environ["ORTMODULE_MEMORY_OPT_LEVEL"] = original_val
+    else:
+        if "ORTMODULE_MEMORY_OPT_LEVEL" in os.environ:
+            del os.environ["ORTMODULE_MEMORY_OPT_LEVEL"]
+
+    if original_val_allow_gradient_checkpoint_export is not None:
+        os.environ["ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT"] = original_val_allow_gradient_checkpoint_export
+    else:
+        if "ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT" in os.environ:
+            del os.environ["ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT"]
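Putting the pieces together from a user's perspective, the intended workflow looks roughly like the sketch below (illustrative only; it assumes a CUDA device and an `onnxruntime-training` install, and uses a toy module in place of a real transformer). The ordering caveat from the exporter error message applies: import `onnxruntime.training.ortmodule` before the model code so the checkpoint override is in place.

```python
import os

# 1) Opt in to layerwise recompute before anything else runs.
os.environ["ORTMODULE_MEMORY_OPT_LEVEL"] = "1"

# 2) Import ORTModule *before* the model code so torch/deepspeed checkpoint overrides are installed.
from onnxruntime.training.ortmodule import ORTModule

import torch
import torch.utils.checkpoint


class Block(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.inner = torch.nn.Linear(32, 32)

    def forward(self, x):
        # Model code that checkpoints a sub-module. During ORTModule export this call is bypassed
        # and ORT's layerwise recompute provides the memory savings instead.
        return torch.utils.checkpoint.checkpoint(self.inner, x, use_reentrant=False)


model = ORTModule(Block().cuda())
out = model(torch.randn(2, 32, device="cuda"))
out.sum().backward()
```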