Skip module clone for preparing large model export (#18663)

### Skip module clone for preparing large model export

For LLAMA2 13B running with LoRA and DeepSpeed stage 2 on 8 GPUs, the run
failed while preparing the outputs that are used for
torch.onnx.export. The reason is that we deep copy all the parameters,
including both the large frozen weights and the small set of LoRA
trainable weights.

This PR first checks whether the GPU memory is enough for a
cloned module; if it is not, the copy is skipped.

The module is copied to guard against the forward run changing the
weights, although this case should be rare. For now, Not-Able-To-Run is
worse than Runnable-with-A-little-bit-different-initial-weight,
especially for large models.
This commit is contained in:
pengwa 2023-12-06 04:41:17 +08:00 committed by GitHub
parent 9aa7284351
commit 4bfa84487c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 76 additions and 6 deletions

View file

@ -278,6 +278,17 @@ data sparsity based performance optimizations.
export ORTMODULE_USE_EFFICIENT_ATTENTION=1
```
#### ORTMODULE_DEEPCOPY_BEFORE_MODEL_EXPORT
- **Feature Area**: *ORTMODULE/Optimizations*
- **Description**: By default, this is enabled. This env var can be used to enable or disable the module deep copy when preparing the output data that will be used by ONNX export.
A typical reason to disable the deep copy: when the deep copy before module export causes the memory peak, disable it and try again.
```bash
export ORTMODULE_DEEPCOPY_BEFORE_MODEL_EXPORT=1 # Enable
export ORTMODULE_DEEPCOPY_BEFORE_MODEL_EXPORT=0 # Disable
```
### 2.2 Memory Optimization
Q: *Want to run a bigger batch size?*

View file

@ -327,12 +327,30 @@ class GraphExecutionManager(GraphExecutionInterface):
# Setup dynamic axes for onnx model
self._input_info = _io.parse_inputs_for_onnx_export(self._module_parameters, None, input_schema, inputs, kwargs)
need_deep_copy = self._runtime_options.deepcopy_before_model_export and _io.can_module_be_deep_cloned(
self._original_module, self._device
)
if not need_deep_copy:
    # need_deep_copy is False for one of two reasons; report the one that applies:
    #   * the runtime option is off  -> the user asked for no deep copy;
    #   * the runtime option is on   -> the memory check rejected the copy.
    if not self._runtime_options.deepcopy_before_model_export:
        self._logger.warning(
            "Since the user requested not to deep copy this model, "
            "the initial weights may not be preserved and could change slightly during the forward run. "
            "This could cause a minor difference between the ORTModule and the PyTorch run for the "
            "first iteration. The computation will proceed as normal, but this should be noted."
        )
    else:
        self._logger.warning(
            "Due to the limited GPU memory execution manager does not create a deep copy of this model. "
            "Therefore, the initial weights might be slightly altered during the forward run. "
            "This could result in a minor discrepancy between the ORTModule and the PyTorch run for the "
            "first iteration. The computation will continue as usual, but this should be noted."
        )
(
output_names,
output_dynamic_axes,
self._module_output_schema,
) = _io.parse_outputs_for_onnx_export_and_extract_schema(
self._original_module, inputs, kwargs, self._logger, self._device
self._original_module, inputs, kwargs, self._logger, self._device, need_deep_copy
)
self._input_info.dynamic_axes.update(output_dynamic_axes)

View file

@ -543,25 +543,61 @@ def parse_inputs_for_onnx_export(
)
def calculate_total_parameter_size_in_bytes(module: torch.nn.Module) -> int:
    """Return the combined size, in bytes, of all parameters of ``module``.

    Every parameter yielded by ``module.parameters()`` contributes its element
    count multiplied by the byte width of a single element.
    """
    return sum(param.numel() * param.element_size() for param in module.parameters())
def can_module_be_deep_cloned(module: torch.nn.Module, device: Optional[torch.device]) -> bool:
    """Tell whether a deep copy of ``module`` is expected to fit in device memory.

    The check only applies to CUDA devices; for ``None`` or any non-CUDA device
    the answer is always ``True``.

    For a CUDA device, the module is considered clonable when twice its total
    parameter size (the existing parameters plus the would-be copy) is strictly
    below 90% of the device's total memory, i.e. a 10% buffer is kept.

    Rationale carried over from the original author's note:
      > Initially there is one set of parameters;
      > parse_outputs_for_onnx_export_and_extract_schema wants to clone the full
        module including the frozen weights;
      > the PyTorch ONNX exporter will clone the trainable parameters;
    so if the module can be cloned in
    parse_outputs_for_onnx_export_and_extract_schema, exporting the model is
    expected to be safe from OOM as well.

    Args:
        module: The module to be cloned.
        device: The device to be used for cloning.

    Returns:
        True if the clone is expected to fit (or the device is not CUDA),
        False otherwise.
    """
    is_cuda_device = device is not None and device.type == "cuda"
    if not is_cuda_device:
        return True

    parameter_bytes = calculate_total_parameter_size_in_bytes(module)
    # Keep a 10% buffer of the device's total memory.
    memory_budget = torch.cuda.get_device_properties(device).total_memory * 0.90
    return parameter_bytes * 2 < memory_budget
def parse_outputs_for_onnx_export_and_extract_schema(
module,
args: Sequence[ORTModelInputOutputType],
kwargs: Mapping[str, ORTModelInputOutputType],
logger: Logger,
device: Optional[torch.device],
clone_module: bool,
):
# Perform a forward call to grab outputs
output_names = None
output_dynamic_axes = None
is_deepcopy = False
deep_copied = False
logger.info("Running model forward to infer output schema and dynamic axes...")
with torch.no_grad():
# Deepcopy inputs, since input values may change after model run.
sample_args_copy, sample_kwargs_copy = deepcopy_model_input(*args, **kwargs)
try:
# Deepcopy model, in case model is stateful and changes after model run.
model_copy = copy.deepcopy(module)
is_deepcopy = True
if clone_module:
# Deepcopy model, in case model is stateful and changes after model run.
model_copy = copy.deepcopy(module)
deep_copied = True
else:
model_copy = module
except Exception:
model_copy = module
logger.warning(
@ -576,7 +612,7 @@ def parse_outputs_for_onnx_export_and_extract_schema(
output_names, output_dynamic_axes = _parse_outputs_and_extract_names_and_dynamic_axes(sample_outputs)
output_schema = _extract_schema(sample_outputs, device)
if is_deepcopy:
if deep_copied:
del model_copy
gc.collect()
if torch.cuda.is_available():

View file

@ -286,6 +286,8 @@ class _RuntimeOptions:
# Experimental features.
self.enable_zero_stage3_support = False # Once enabled, cannot be disabled.
self.deepcopy_before_model_export = True
# Override the feature config if it exists in os env.
self._override_from_env_vars()
@ -367,3 +369,6 @@ class _RuntimeOptions:
# Experimental features.
if "ORTMODULE_ENABLE_ZERO_STAGE3" in os.environ and int(os.getenv("ORTMODULE_ENABLE_ZERO_STAGE3")) == 1:
self.enable_zero_stage3_support = True
if "ORTMODULE_DEEPCOPY_BEFORE_MODEL_EXPORT" in os.environ:
self.deepcopy_before_model_export = int(os.getenv("ORTMODULE_DEEPCOPY_BEFORE_MODEL_EXPORT")) == 1