diff --git a/docs/ORTModule_Training_Guidelines.md b/docs/ORTModule_Training_Guidelines.md index d3ec61e867..a3cceb441a 100644 --- a/docs/ORTModule_Training_Guidelines.md +++ b/docs/ORTModule_Training_Guidelines.md @@ -278,6 +278,17 @@ data sparsity based performance optimizations. export ORTMODULE_USE_EFFICIENT_ATTENTION=1 ``` +#### ORTMODULE_DEEPCOPY_BEFORE_MODEL_EXPORT + +- **Feature Area**: *ORTMODULE/Optimizations* +- **Description**: By default, this is enabled. This env var can be used for enabling or disabling the module deep copy when preparing output data which will be used by ONNX export. +A classical usage of disabling the deep copy: when the deep copy before module export bring the memory peak, then we should disable it and have a try. + + ```bash + export ORTMODULE_DEEPCOPY_BEFORE_MODEL_EXPORT=1 # Enable + export ORTMODULE_DEEPCOPY_BEFORE_MODEL_EXPORT=0 # Disable + ``` + ### 2.2 Memory Optimization Q: *Want to run a bigger batch size?* diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index 5696bfead7..dd6d5a568c 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -327,12 +327,30 @@ class GraphExecutionManager(GraphExecutionInterface): # Setup dynamic axes for onnx model self._input_info = _io.parse_inputs_for_onnx_export(self._module_parameters, None, input_schema, inputs, kwargs) + need_deep_copy = self._runtime_options.deepcopy_before_model_export and _io.can_module_be_deep_cloned( + self._original_module, self._device + ) + if not need_deep_copy: + if self._runtime_options.deepcopy_before_model_export: + self._logger.warning( + "Since the user requested not to deep copy this model, " + "the initial weights may not be preserved and could change slightly during the forward run. " + "This could cause a minor difference between the ORTModule and the PyTorch run for the " + "first iteration. The computation will proceed as normal, but this should be noted." + ) + else: + self._logger.warning( + "Due to the limited GPU memory execution manager does not create a deep copy of this model. " + "Therefore, the initial weights might be slightly altered during the forward run. " + "This could result in a minor discrepancy between the ORTModule and the PyTorch run for the " + "first iteration. The computation will continue as usual, but this should be noted." + ) ( output_names, output_dynamic_axes, self._module_output_schema, ) = _io.parse_outputs_for_onnx_export_and_extract_schema( - self._original_module, inputs, kwargs, self._logger, self._device + self._original_module, inputs, kwargs, self._logger, self._device, need_deep_copy ) self._input_info.dynamic_axes.update(output_dynamic_axes) diff --git a/orttraining/orttraining/python/training/ortmodule/_io.py b/orttraining/orttraining/python/training/ortmodule/_io.py index f5fbd5093f..7534cc46a2 100644 --- a/orttraining/orttraining/python/training/ortmodule/_io.py +++ b/orttraining/orttraining/python/training/ortmodule/_io.py @@ -543,25 +543,61 @@ def parse_inputs_for_onnx_export( ) +def calculate_total_parameter_size_in_bytes(module: torch.nn.Module) -> int: + """Calculate the total parameter size in bytes""" + total_size = 0 + for p in module.parameters(): + total_size += p.numel() * p.element_size() + return total_size + + +def can_module_be_deep_cloned(module: torch.nn.Module, device: Optional[torch.device]) -> bool: + """Check if the module can be cloned + + If the 2 times total module parameter size >= device memory, the module cannot be cloned. + > Initially there is one set of parameters; + > parse_outputs_for_onnx_export_and_extract_schema want to clone the full module including the frozen weight; + > PyTorch ONNX exporter will clone the trainable parameters; + + So as long as the module can be cloned in parse_outputs_for_onnx_export_and_extract_schema, it is safe + to export the model without OOM. Here we return whether can clone the module in + parse_outputs_for_onnx_export_and_extract_schema. + + Args: + module: The module to be cloned. + device: The device to be used for cloning. + """ + + if device is None or device.type != "cuda": + return True + + total_size = calculate_total_parameter_size_in_bytes(module) + return total_size * 2 < torch.cuda.get_device_properties(device).total_memory * 0.90 # give a 10% buffer + + def parse_outputs_for_onnx_export_and_extract_schema( module, args: Sequence[ORTModelInputOutputType], kwargs: Mapping[str, ORTModelInputOutputType], logger: Logger, device: Optional[torch.device], + clone_module: bool, ): # Perform a forward call to grab outputs output_names = None output_dynamic_axes = None - is_deepcopy = False + deep_copied = False logger.info("Running model forward to infer output schema and dynamic axes...") with torch.no_grad(): # Deepcopy inputs, since input values may change after model run. sample_args_copy, sample_kwargs_copy = deepcopy_model_input(*args, **kwargs) try: - # Deepcopy model, in case model is stateful and changes after model run. - model_copy = copy.deepcopy(module) - is_deepcopy = True + if clone_module: + # Deepcopy model, in case model is stateful and changes after model run. + model_copy = copy.deepcopy(module) + deep_copied = True + else: + model_copy = module except Exception: model_copy = module logger.warning( @@ -576,7 +612,7 @@ def parse_outputs_for_onnx_export_and_extract_schema( output_names, output_dynamic_axes = _parse_outputs_and_extract_names_and_dynamic_axes(sample_outputs) output_schema = _extract_schema(sample_outputs, device) - if is_deepcopy: + if deep_copied: del model_copy gc.collect() if torch.cuda.is_available(): diff --git a/orttraining/orttraining/python/training/ortmodule/options.py b/orttraining/orttraining/python/training/ortmodule/options.py index 77022f86d3..ffa3f4afa7 100644 --- a/orttraining/orttraining/python/training/ortmodule/options.py +++ b/orttraining/orttraining/python/training/ortmodule/options.py @@ -286,6 +286,8 @@ class _RuntimeOptions: # Experimental features. self.enable_zero_stage3_support = False # Once enabled, cannot be disabled. + self.deepcopy_before_model_export = True + # Override the feature config if it exists in os env. self._override_from_env_vars() @@ -367,3 +369,6 @@ class _RuntimeOptions: # Experimental features. if "ORTMODULE_ENABLE_ZERO_STAGE3" in os.environ and int(os.getenv("ORTMODULE_ENABLE_ZERO_STAGE3")) == 1: self.enable_zero_stage3_support = True + + if "ORTMODULE_DEEPCOPY_BEFORE_MODEL_EXPORT" in os.environ: + self.deepcopy_before_model_export = int(os.getenv("ORTMODULE_DEEPCOPY_BEFORE_MODEL_EXPORT")) == 1