mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-03 23:49:44 +00:00
Skip module clone for preparing large model export (#18663)
### Skip module clone for preparing large model export For LLAMA2 13B, when running with Lora, DeepSpeed stage2 on 8 GPUs . It failed during preparing outputs which will be used for torch.onnx.export. The reason, we deep copy all the params including both big sizes of frozen weights, + a little bit of Lora trainable weight. This PR will firstly check whether the GPU memmory is enough for a cloned module, if not, skip the copy. Copying the module is to guarantee the fw path run may change the weight, while this case should be rare. But for now, Not-Able-To-Run is worse than Runnable-with-A-little-bit-different-initial-weight, especially for large models.
This commit is contained in:
parent
9aa7284351
commit
4bfa84487c
4 changed files with 76 additions and 6 deletions
|
|
@ -278,6 +278,17 @@ data sparsity based performance optimizations.
|
|||
export ORTMODULE_USE_EFFICIENT_ATTENTION=1
|
||||
```
|
||||
|
||||
#### ORTMODULE_DEEPCOPY_BEFORE_MODEL_EXPORT
|
||||
|
||||
- **Feature Area**: *ORTMODULE/Optimizations*
|
||||
- **Description**: By default, this is enabled. This env var can be used for enabling or disabling the module deep copy when preparing output data which will be used by ONNX export.
|
||||
A classical usage of disabling the deep copy: when the deep copy before module export bring the memory peak, then we should disable it and have a try.
|
||||
|
||||
```bash
|
||||
export ORTMODULE_DEEPCOPY_BEFORE_MODEL_EXPORT=1 # Enable
|
||||
export ORTMODULE_DEEPCOPY_BEFORE_MODEL_EXPORT=0 # Disable
|
||||
```
|
||||
|
||||
### 2.2 Memory Optimization
|
||||
|
||||
Q: *Want to run a bigger batch size?*
|
||||
|
|
|
|||
|
|
@ -327,12 +327,30 @@ class GraphExecutionManager(GraphExecutionInterface):
|
|||
|
||||
# Setup dynamic axes for onnx model
|
||||
self._input_info = _io.parse_inputs_for_onnx_export(self._module_parameters, None, input_schema, inputs, kwargs)
|
||||
need_deep_copy = self._runtime_options.deepcopy_before_model_export and _io.can_module_be_deep_cloned(
|
||||
self._original_module, self._device
|
||||
)
|
||||
if not need_deep_copy:
|
||||
if self._runtime_options.deepcopy_before_model_export:
|
||||
self._logger.warning(
|
||||
"Since the user requested not to deep copy this model, "
|
||||
"the initial weights may not be preserved and could change slightly during the forward run. "
|
||||
"This could cause a minor difference between the ORTModule and the PyTorch run for the "
|
||||
"first iteration. The computation will proceed as normal, but this should be noted."
|
||||
)
|
||||
else:
|
||||
self._logger.warning(
|
||||
"Due to the limited GPU memory execution manager does not create a deep copy of this model. "
|
||||
"Therefore, the initial weights might be slightly altered during the forward run. "
|
||||
"This could result in a minor discrepancy between the ORTModule and the PyTorch run for the "
|
||||
"first iteration. The computation will continue as usual, but this should be noted."
|
||||
)
|
||||
(
|
||||
output_names,
|
||||
output_dynamic_axes,
|
||||
self._module_output_schema,
|
||||
) = _io.parse_outputs_for_onnx_export_and_extract_schema(
|
||||
self._original_module, inputs, kwargs, self._logger, self._device
|
||||
self._original_module, inputs, kwargs, self._logger, self._device, need_deep_copy
|
||||
)
|
||||
self._input_info.dynamic_axes.update(output_dynamic_axes)
|
||||
|
||||
|
|
|
|||
|
|
@ -543,25 +543,61 @@ def parse_inputs_for_onnx_export(
|
|||
)
|
||||
|
||||
|
||||
def calculate_total_parameter_size_in_bytes(module: torch.nn.Module) -> int:
|
||||
"""Calculate the total parameter size in bytes"""
|
||||
total_size = 0
|
||||
for p in module.parameters():
|
||||
total_size += p.numel() * p.element_size()
|
||||
return total_size
|
||||
|
||||
|
||||
def can_module_be_deep_cloned(module: torch.nn.Module, device: Optional[torch.device]) -> bool:
|
||||
"""Check if the module can be cloned
|
||||
|
||||
If the 2 times total module parameter size >= device memory, the module cannot be cloned.
|
||||
> Initially there is one set of parameters;
|
||||
> parse_outputs_for_onnx_export_and_extract_schema want to clone the full module including the frozen weight;
|
||||
> PyTorch ONNX exporter will clone the trainable parameters;
|
||||
|
||||
So as long as the module can be cloned in parse_outputs_for_onnx_export_and_extract_schema, it is safe
|
||||
to export the model without OOM. Here we return whether can clone the module in
|
||||
parse_outputs_for_onnx_export_and_extract_schema.
|
||||
|
||||
Args:
|
||||
module: The module to be cloned.
|
||||
device: The device to be used for cloning.
|
||||
"""
|
||||
|
||||
if device is None or device.type != "cuda":
|
||||
return True
|
||||
|
||||
total_size = calculate_total_parameter_size_in_bytes(module)
|
||||
return total_size * 2 < torch.cuda.get_device_properties(device).total_memory * 0.90 # give a 10% buffer
|
||||
|
||||
|
||||
def parse_outputs_for_onnx_export_and_extract_schema(
|
||||
module,
|
||||
args: Sequence[ORTModelInputOutputType],
|
||||
kwargs: Mapping[str, ORTModelInputOutputType],
|
||||
logger: Logger,
|
||||
device: Optional[torch.device],
|
||||
clone_module: bool,
|
||||
):
|
||||
# Perform a forward call to grab outputs
|
||||
output_names = None
|
||||
output_dynamic_axes = None
|
||||
is_deepcopy = False
|
||||
deep_copied = False
|
||||
logger.info("Running model forward to infer output schema and dynamic axes...")
|
||||
with torch.no_grad():
|
||||
# Deepcopy inputs, since input values may change after model run.
|
||||
sample_args_copy, sample_kwargs_copy = deepcopy_model_input(*args, **kwargs)
|
||||
try:
|
||||
# Deepcopy model, in case model is stateful and changes after model run.
|
||||
model_copy = copy.deepcopy(module)
|
||||
is_deepcopy = True
|
||||
if clone_module:
|
||||
# Deepcopy model, in case model is stateful and changes after model run.
|
||||
model_copy = copy.deepcopy(module)
|
||||
deep_copied = True
|
||||
else:
|
||||
model_copy = module
|
||||
except Exception:
|
||||
model_copy = module
|
||||
logger.warning(
|
||||
|
|
@ -576,7 +612,7 @@ def parse_outputs_for_onnx_export_and_extract_schema(
|
|||
output_names, output_dynamic_axes = _parse_outputs_and_extract_names_and_dynamic_axes(sample_outputs)
|
||||
|
||||
output_schema = _extract_schema(sample_outputs, device)
|
||||
if is_deepcopy:
|
||||
if deep_copied:
|
||||
del model_copy
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
|
|
|
|||
|
|
@ -286,6 +286,8 @@ class _RuntimeOptions:
|
|||
# Experimental features.
|
||||
self.enable_zero_stage3_support = False # Once enabled, cannot be disabled.
|
||||
|
||||
self.deepcopy_before_model_export = True
|
||||
|
||||
# Override the feature config if it exists in os env.
|
||||
self._override_from_env_vars()
|
||||
|
||||
|
|
@ -367,3 +369,6 @@ class _RuntimeOptions:
|
|||
# Experimental features.
|
||||
if "ORTMODULE_ENABLE_ZERO_STAGE3" in os.environ and int(os.getenv("ORTMODULE_ENABLE_ZERO_STAGE3")) == 1:
|
||||
self.enable_zero_stage3_support = True
|
||||
|
||||
if "ORTMODULE_DEEPCOPY_BEFORE_MODEL_EXPORT" in os.environ:
|
||||
self.deepcopy_before_model_export = int(os.getenv("ORTMODULE_DEEPCOPY_BEFORE_MODEL_EXPORT")) == 1
|
||||
|
|
|
|||
Loading…
Reference in a new issue