Introduce ZeROOffloadSubscriber for ORTModule (#17006)

### Introduce ZeROOffloadSubscriber for ORTModule

As part of the work to integrate ORTModule with DeepSpeed stage 3, this PR
mainly focuses on moving the original PyTorch-based (hook-driven) parameter
partition/offload implementation to an ORTModule-compatible implementation.

Changes include:
1. Refactor `SubscriberBase`/`SubscriberManager` to support
pre-forward/post-forward hooks.
2. Implement the new `ZeROOffloadSubscriber` by reusing DeepSpeed's hook
functions as much as possible. All hook functions are defined in
`DeepSpeedZeRoOffload._register_hooks_recursively` and
`DeepSpeedZeRoOffload.setup_zero_stage3_hooks`, and fortunately their
closures are not complex: every hook only references the owning
`DeepSpeedZeRoOffload` instance. We can therefore create new hook functions
with `FunctionType` by binding the owning `DeepSpeedZeRoOffload` instance,
and then call the newly created functions in the subscriber's
`pre_forward_module_apply_impl` and `post_forward_module_apply_impl`
interfaces.
3. Monkey patch `DeepSpeedZeRoOffload.setup_zero_stage3_hooks` to
register the `ZeROOffloadSubscriber` for the model, so that we don't need
to change any code in the DeepSpeed repo (at least so far).
4. Fix the ATen embedding custom symbolic exporter function by
tolerating a weight size of (0) (as changed by DeepSpeed ZeRO stage 3).

Unit tests will be added once stage 3 is fully supported.

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
pengwa 2023-08-25 00:15:22 +08:00 committed by GitHub
parent fca81cc5d5
commit d90afc697b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 1033 additions and 315 deletions

View file

@ -83,12 +83,13 @@ Arguments:
inspector node affects memory peak causing the original recipe run to fail with OOM.
- `bucket_size`: the size of the bucket to split the statistic calculation.
### 2.2 Use `_InspectActivation` to collect intermediate tensors in a `nn.Module` forward()
### 2.2 Use `inspect_activation` to collect intermediate tensors in a `nn.Module` forward()
A limitation of `GlobalSubscriberManager` is that only the forward output tensors of `nn.Module`s are dumped. If you want to
dump the intermediate tensors inside a `nn.Module`'s forward function, refer to the following example:
```diff
+ from onnxruntime.training.utils import inspect_activation
class BloomForCausalLM(BloomPreTrainedModel):
def __init__(self, config: BloomConfig):
...
@ -98,10 +99,10 @@ class BloomForCausalLM(BloomPreTrainedModel):
transformer_outputs = self.transformer(...)
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)
+ lm_logits = _InspectActivation.apply("lm_logits", None, GlobalSubscriberManager.get_run_context(), lm_logits)
+ lm_logits = inspect_activation("lm_logits", lm_logits)
# Shift so that tokens < n predict n
shift_logits = lm_logits[..., :-1, :].contiguous()
+ shift_logits = _InspectActivation.apply("shift_logits", None, GlobalSubscriberManager.get_run_context(), shift_logits)
+ shift_logits = inspect_activation("shift_logits", shift_logits)
shift_labels = labels[..., 1:].contiguous()
batch_size, seq_length, vocab_size = shift_logits.shape
# Flatten the tokens
@ -113,7 +114,7 @@ class BloomForCausalLM(BloomPreTrainedModel):
return loss
```
Be noted, make sure the activation name (as the first argument of `_InspectActivation.apply`) is unique, otherwise
Note: make sure the activation name (the first argument of `inspect_activation`) is unique; otherwise
the stat file using that activation name will be overwritten by the last write. The dumped data are stored in the `output_dir`.

View file

@ -6,16 +6,17 @@
import sys
from typing import Callable, ClassVar, Dict, Optional
import onnx
import torch
import torch.utils.checkpoint
from onnx import ModelProto
from packaging import version
from torch.onnx import symbolic_helper
from onnxruntime.capi._pybind_state import register_miscellaneous_const_input, register_torch_autograd_function
from onnxruntime.training import ortmodule
from onnxruntime.training.utils import pytorch_dtype_to_onnx
from ._custom_op_symbolic_registry import pytorch_type_to_onnx, wrap_custom_export_function
from ._custom_op_symbolic_registry import wrap_custom_export_function
from ._fallback import ORTModuleONNXModelException, wrap_exception
from ._utils import get_fully_qualified_class_name, get_runtime_pytorch_version
@ -168,7 +169,7 @@ def _export_pt_1_10(g, n, *args, **kwargs):
if call_type == "d":
# Got a tensor variable.
tensor_args.append(arg)
scalar_type = pytorch_type_to_onnx(arg.type().scalarType())
scalar_type = pytorch_dtype_to_onnx(arg.type().scalarType())
input_tensor_types.append(scalar_type)
input_tensor_ranks.append(arg.type().dim())
continue
@ -247,7 +248,7 @@ def _export_pt_1_10(g, n, *args, **kwargs):
output_tensor_ranks = []
for arg in n.outputs():
# Type of tensor's elements.
scalar_type = pytorch_type_to_onnx(arg.type().scalarType())
scalar_type = pytorch_dtype_to_onnx(arg.type().scalarType())
output_tensor_types.append(scalar_type)
output_tensor_ranks.append(arg.type().dim())
@ -306,16 +307,14 @@ def _export_pt_1_10(g, n, *args, **kwargs):
_export = wrap_custom_export_function(_export_pt_1_10)
def _post_process_after_export(
exported_model: onnx.ModelProto, enable_custom_autograd_function: bool
) -> onnx.ModelProto:
def _post_process_after_export(exported_model: ModelProto, enable_custom_autograd_function: bool) -> ModelProto:
"""Post process the exported model."""
if enable_custom_autograd_function:
exported_model = _post_process_enabling_autograd_function(exported_model)
return exported_model
def _post_process_enabling_autograd_function(exported_model: onnx.ModelProto) -> onnx.ModelProto:
def _post_process_enabling_autograd_function(exported_model: ModelProto) -> ModelProto:
# Loop all PythonOp, append "_ctx" as the first output.
index = 0
for node in exported_model.graph.node:

View file

@ -232,7 +232,7 @@ def call_python_forward_function(
for arg_index, (grad_flag, tensor_flag, arg) in enumerate(zip(requires_grad_flags, tensor_type_flags, args)):
if tensor_flag:
# Assume it's a DLPack tensor# and convert it to PyTorch tensor.
# Assume it's a DLPack tensor and convert it to PyTorch tensor.
# Note1:
# If it's first-time kernel invocation, input_indices_to_save_in_ctx is None, we do the
# copy for all tensor. Otherwise, we only copy the tensors whose indices are in

View file

@ -12,37 +12,10 @@ from packaging.version import Version
from torch.onnx import register_custom_op_symbolic
from torch.onnx.symbolic_helper import _get_tensor_dim_size, _get_tensor_sizes, parse_args
from onnxruntime.training.utils import pytorch_dtype_to_onnx
from ._utils import get_runtime_pytorch_version
# Mapping from pytorch scalar type to onnx scalar type.
_CAST_PYTORCH_TO_ONNX = {
"Byte": torch.onnx.TensorProtoDataType.UINT8,
"Char": torch.onnx.TensorProtoDataType.INT8,
"Double": torch.onnx.TensorProtoDataType.DOUBLE,
"Float": torch.onnx.TensorProtoDataType.FLOAT,
"Half": torch.onnx.TensorProtoDataType.FLOAT16,
"Int": torch.onnx.TensorProtoDataType.INT32,
"Long": torch.onnx.TensorProtoDataType.INT64,
"Short": torch.onnx.TensorProtoDataType.INT16,
"Bool": torch.onnx.TensorProtoDataType.BOOL,
"ComplexFloat": torch.onnx.TensorProtoDataType.COMPLEX64,
"ComplexDouble": torch.onnx.TensorProtoDataType.COMPLEX128,
"BFloat16": torch.onnx.TensorProtoDataType.BFLOAT16,
# Not yet defined in torch.
# "Float8E4M3FN": torch.onnx.TensorProtoDataType.FLOAT8E4M3FN,
# "Float8E4M3FNUZ": torch.onnx.TensorProtoDataType.FLOAT8E4M3FNUZ,
# "Float8E5M2": torch.onnx.TensorProtoDataType.FLOAT8E5M2,
# "Float8E5M2FNUZ": torch.onnx.TensorProtoDataType.FLOAT8E5M2FNUZ,
"Undefined": torch.onnx.TensorProtoDataType.UNDEFINED,
}
def pytorch_type_to_onnx(scalar_type: str) -> torch.onnx.TensorProtoDataType:
try:
return torch.onnx.JitScalarType.from_name(scalar_type).onnx_type()
except AttributeError:
return _CAST_PYTORCH_TO_ONNX[scalar_type]
def wrap_custom_export_function(original_func: Callable) -> Callable:
"""This function is to wrap the custom export function to make sure it can be used by different versions of PyTorch.
@ -172,7 +145,7 @@ def cross_entropy_loss(g, node, logits, target, weight, reduction, ignore_index,
weight_casted,
ignore_index,
reduction_s=reduction,
output_type_i=pytorch_type_to_onnx(output_type.scalarType()),
output_type_i=pytorch_dtype_to_onnx(output_type.scalarType()),
outputs=2,
)
output.setType(output_type)
@ -199,10 +172,16 @@ def embedding(g, weight, indices, padding_idx, scale_grad_by_freq, sparse):
output = g.op(
"org.pytorch.aten::ATen", weight, indices, padding_idx, scale_grad_by_freq, sparse, operator_s="embedding"
)
indices_shape = _get_tensor_sizes(indices)
if indices_shape is not None and hasattr(weight.type(), "with_sizes"):
output_type = weight.type().with_sizes([*indices_shape, _get_tensor_dim_size(weight, 1)])
output.setType(output_type)
try:
# Tolerant to the case when sizes of indices are not available or not usable (for example
# when DeepSpeed stage3 enabled, all weights size is (0), this will fail.)
indices_shape = _get_tensor_sizes(indices)
if indices_shape is not None and hasattr(weight.type(), "with_sizes"):
output_type = weight.type().with_sizes([*indices_shape, _get_tensor_dim_size(weight, 1)])
output.setType(output_type)
except IndexError:
output.setType(weight.type())
return output

View file

@ -20,6 +20,7 @@ import onnxruntime
from onnxruntime.capi import _pybind_state as C
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference
from onnxruntime.training.utils import ORTModelInputOutputSchemaType
from onnxruntime.training.utils.hooks import configure_ort_compatible_zero_stage3
from . import _are_deterministic_algorithms_enabled, _io, _logger, _onnx_models, _utils
from ._custom_autograd_function_exporter import _post_process_after_export
@ -140,6 +141,10 @@ class GraphExecutionManager(GraphExecutionInterface):
register_triton_op_executor()
if self._runtime_options.enable_zero_stage3_support:
# Cannot toggle feature enabling/disabling after the first time enabled.
configure_ort_compatible_zero_stage3()
def _get_torch_gpu_allocator_function_addresses(self):
if self._runtime_options.use_external_gpu_allocator and torch.cuda.is_available():
# CPP extension to get torch GPU allocator's alloc and free function addresses
@ -281,7 +286,11 @@ class GraphExecutionManager(GraphExecutionInterface):
# All required models have already been exported previously
return False
self._set_device_from_module(inputs, kwargs)
self._onnx_models.exported_model = self._get_exported_model(schema, *inputs, **kwargs)
from onnxruntime.training.utils.hooks._subscriber_manager import no_increase_global_step
with no_increase_global_step():
self._onnx_models.exported_model = self._get_exported_model(schema, *inputs, **kwargs)
if self._debug_options.save_onnx_models.save:
self._onnx_models.save_exported_model(
self._debug_options.save_onnx_models.path,
@ -311,7 +320,6 @@ class GraphExecutionManager(GraphExecutionInterface):
# WARNING/ERROR -> [Rank 0] NO export verbose log + FILTERED torch other logs from stdout and stderr (C++ backend)
# Be noted: rank 0 log only is controlled by logger configured in _logger.py
torch_exporter_verbose_log = self._debug_options.logging.log_level <= LogLevel.INFO
self._logger.info("Exporting the PyTorch model to ONNX...")
# Setup dynamic axes for onnx model
self._input_info = _io.parse_inputs_for_onnx_export(self._module_parameters, None, input_schema, inputs, kwargs)
@ -327,6 +335,8 @@ class GraphExecutionManager(GraphExecutionInterface):
# FlattenedModule needs _InputInfo to expand user input from *args to *args + **kwargs
self._flattened_module._input_info = self._input_info
self._logger.info("Exporting the PyTorch model to ONNX...")
# Leverage cached model if available
cache_dir = self._runtime_options.ortmodule_cache_dir
if cache_dir:
@ -659,6 +669,14 @@ class GraphExecutionManager(GraphExecutionInterface):
)
)
feature_map.append(
(
"ZeRO Stage3 Support",
self._runtime_options.enable_zero_stage3_support,
"Enable/Disable with env ORTMODULE_ENABLE_ZERO_STAGE3=1/0",
)
)
mode = "training" if self._export_mode == torch.onnx.TrainingMode.TRAINING else "inference"
mode = f"{_logger.LogColor.UNDERLINE}{mode}{_logger.LogColor.ENDC}"

View file

@ -539,6 +539,7 @@ def parse_outputs_for_onnx_export_and_extract_schema(
output_names = None
output_dynamic_axes = None
is_deepcopy = False
logger.info("Running model forward to infer output schema and dynamic axes...")
with torch.no_grad():
# Deepcopy inputs, since input values may change after model run.
sample_args_copy, sample_kwargs_copy = deepcopy_model_input(*args, **kwargs)

View file

@ -288,6 +288,9 @@ class _RuntimeOptions:
# Cache exported model
self.ortmodule_cache_dir = ""
# Experimental features.
self.enable_zero_stage3_support = False # Once enabled, cannot be disabled.
# Override the feature config if it exists in os env.
self._override_from_env_vars()
@ -365,3 +368,7 @@ class _RuntimeOptions:
if "ORTMODULE_CACHE_DIR" in os.environ:
self._logger.info("ORTModule cache optimization is ON.")
self.ortmodule_cache_dir = os.getenv("ORTMODULE_CACHE_DIR")
# Experimental features.
if "ORTMODULE_ENABLE_ZERO_STAGE3" in os.environ and int(os.getenv("ORTMODULE_ENABLE_ZERO_STAGE3")) == 1:
self.enable_zero_stage3_support = True

View file

@ -9,6 +9,7 @@ from onnxruntime.training.utils.torch_io_helper import (
extract_data_and_schema,
unflatten_data_using_schema,
)
from onnxruntime.training.utils.torch_type_map import pytorch_dtype_to_onnx
__all__ = [
"PrimitiveType",
@ -16,4 +17,5 @@ __all__ = [
"ORTModelInputOutputSchemaType",
"extract_data_and_schema",
"unflatten_data_using_schema",
"pytorch_dtype_to_onnx",
]

View file

@ -3,14 +3,35 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import torch
__all__ = [
"StatisticsSubscriber",
"GlobalSubscriberManager",
"_InspectActivation",
"inspect_activation",
"ZeROOffloadSubscriber",
"configure_ort_compatible_zero_stage3",
]
from ._statistics_subscriber import StatisticsSubscriber
from ._subscriber_manager import SubscriberManager, _InspectActivation
from ._statistics_subscriber import StatisticsSubscriber, _InspectActivation
from ._subscriber_manager import SubscriberManager
from ._zero_offload_subscriber import ZeROOffloadSubscriber, configure_ort_compatible_zero_stage3
# Define a global uninitialized subscriber manager for usage where it is needed by different Python files.
GlobalSubscriberManager = SubscriberManager()
def inspect_activation(activation_name: str, tensor: torch.Tensor) -> torch.Tensor:
    """Run `tensor` through `_InspectActivation` using the first registered `StatisticsSubscriber`.

    Args:
        activation_name: Name forwarded to `_InspectActivation.apply` as the activation name.
        tensor: The activation tensor to inspect.

    Returns:
        Whatever `_InspectActivation.apply` returns for `tensor`.

    Raises:
        RuntimeError: If `GlobalSubscriberManager._subscribers` contains no `StatisticsSubscriber`.
    """
    # Only the first StatisticsSubscriber (in iteration order) is used.
    stats_subscriber = next(
        (
            candidate
            for candidate in GlobalSubscriberManager._subscribers
            if isinstance(candidate, StatisticsSubscriber)
        ),
        None,
    )
    if stats_subscriber is None:
        raise RuntimeError("StatisticsSubscriber is not registered, cannot inspect activation.")

    return _InspectActivation.apply(
        activation_name,
        None,  # No owning module index for a user-placed inspection point.
        GlobalSubscriberManager.get_run_context(),
        tensor,
        stats_subscriber.module_post_forward_impl,
        stats_subscriber.module_pre_backward_impl,
    )

View file

@ -7,11 +7,90 @@ import os
import shutil
import warnings
from pathlib import Path
from typing import Union
from typing import List, Optional, Tuple, Union
import onnx
import torch
from ._subscriber_base import SubscriberBase
from ._subscriber_base import RuntimeStates, SubscriberBase
class _InspectActivation(torch.autograd.Function):
    """Pass-through autograd function that reports a tensor to caller-supplied callbacks.

    The forward pass hands a detached copy of the activation to `module_post_forward`, and the
    backward pass hands a detached copy of the incoming gradient to `module_pre_backward`.
    The activation/gradient itself is returned detached and otherwise unchanged.

    The function is called by two kinds of callers:
        1. Subscriber code, for output tensors of a registered nn.Module.
        2. Users who want to inspect the activation tensor at any place of model definition code.
    """

    @staticmethod
    def forward(
        ctx,
        activation_name: str,
        module_idx: Optional[int],
        run_ctx: "RuntimeStates",
        input_tensor: torch.Tensor,
        module_post_forward,
        module_pre_backward,
    ):
        """
        Args:
            ctx: context object to store intermediate information.
            activation_name: the name of the activation tensor.
            module_idx: index of the owning module, used to look up its depth in
                `run_ctx.global_states.module_index_to_depth`. When it is None, no lookup is
                done and the depth is reported as -1.
            run_ctx: runtime context; its `global_states.execution_step` is recorded as the current step.
            input_tensor: the activation tensor.
            module_post_forward: callable invoked here as
                `module_post_forward(tensor_copy, depth, activation_name, step)`.
            module_pre_backward: callable stored on `ctx` and invoked in `backward` as
                `module_pre_backward(grad_copy, depth, activation_name, step)`.

        Returns:
            `input_tensor.detach()` for a tensor input; any other value (including None) is
            returned as-is.

        Make sure there is a same number of `tensor` type inputs and outputs.
        This is enforced by ORT's PythonOp's schema check.
        """
        depth = -1
        if module_idx is not None:
            depth = run_ctx.global_states.module_index_to_depth[module_idx]

        # One guard drives both the copy handed to the callback and the returned value, so that a
        # non-tensor value (tolerated when copying) no longer fails with AttributeError on `.detach()`.
        is_tensor = isinstance(input_tensor, torch.Tensor)
        input_tensor_copied = input_tensor.detach().clone() if is_tensor else input_tensor

        ctx.current_step = run_ctx.global_states.execution_step
        ctx.name = activation_name
        ctx.id = module_idx
        ctx.depth = depth
        ctx.module_pre_backward = module_pre_backward

        module_post_forward(input_tensor_copied, depth, activation_name, ctx.current_step)

        return input_tensor.detach() if is_tensor else input_tensor

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        """Report the incoming gradient to `ctx.module_pre_backward`, then pass it through.

        Returns:
            A 6-tuple with one entry per `forward` input; only the slot of `input_tensor`
            (the fourth) carries a gradient, all other slots are None.
        """
        is_tensor = isinstance(grad_output, torch.Tensor)
        val = grad_output.detach().clone() if is_tensor else grad_output

        ctx.module_pre_backward(val, ctx.depth, ctx.name, ctx.current_step)

        return (
            None,
            None,
            None,
            grad_output.detach() if is_tensor else grad_output,
            None,
            None,
        )

    @staticmethod
    def infer_shape(
        node: "onnx.NodeProto",
        tensor_input_shapes: List[Optional[List[Union[int, str]]]],
        tensor_input_dtypes: List[torch.onnx.TensorProtoDataType],
    ) -> Tuple[List[Optional[List[Union[int, str]]]], List[torch.onnx.TensorProtoDataType]]:
        """Return the given input shapes and dtypes unchanged as the output shapes and dtypes."""
        return tensor_input_shapes, tensor_input_dtypes
class StatisticsSubscriber(SubscriberBase):
@ -68,6 +147,15 @@ class StatisticsSubscriber(SubscriberBase):
"Set override_output_dir=True for StatisticsSubscriber if this is the intention."
)
def post_forward_tensor_apply_impl(
    self, run_rtx: RuntimeStates, module: torch.nn.Module, tensor_index: int, tensor: torch.Tensor
) -> torch.Tensor:
    """Route one forward output tensor of `module` through `_InspectActivation`.

    The activation name is built from the module's class name, the module index looked up in
    `run_rtx.global_states.module_to_module_index`, and `tensor_index`.

    Returns:
        The result of `_InspectActivation.apply` for `tensor`.
    """
    module_index = run_rtx.global_states.module_to_module_index[module]
    activation_name = f"{module.__class__.__name__}_{module_index}_{tensor_index}th_output"

    # This subscriber's own callbacks are the ones invoked in forward and backward.
    forward_callback = self.module_post_forward_impl
    backward_callback = self.module_pre_backward_impl
    return _InspectActivation.apply(
        activation_name, module_index, run_rtx, tensor, forward_callback, backward_callback
    )
def module_post_forward_impl(self, activation: torch.Tensor, depth: int, name: str, step: int):
output_file_path = os.path.join(f"{self._output_dir}", f"step_{step}")
return self._summarize_activations(activation, depth, name, output_file_path, True)

View file

@ -5,34 +5,54 @@
import sys
from typing import Union
from typing import Optional, Tuple
import torch
from onnxruntime.training.utils import ORTModelInputOutputType
class RuntimeStates:
    """
    A data struct holding states for runtime context.

    It currently carries a single group of global states (see `_GlobalStates`): the execution
    step counter plus two module lookup tables, all starting out empty/zero.
    """

    class _GlobalStates:
        """Container for the global (cross-step) states."""

        def __init__(self):
            # Current execution step, e.g. how many forward/backward path is called. Starts at 0.
            self.execution_step = 0

            # Module index -> depth of that module (the indentation level of the module).
            self.module_index_to_depth = {}

            # Module -> module index.
            self.module_to_module_index = {}

    def __init__(self):
        states = RuntimeStates._GlobalStates()
        self.global_states = states
class SubscriberBase:
"""
Base class for all module hook subscribers.
Currently, the hook here only means post-forward hook and pre-backward hook.
A module hook subscriber is a class that implements the `module_post_forward_impl` and `module_pre_backward_impl`
function.
> The post_forward hook is called after the activation is generated in the forward path.
> The pre_backward hook is called before the activation gradient is computed.
A module hook subscriber is a class that allow define custom actions to be executed during the nn.Module's hooks.
Two types of APIs can be used to define custom actions:
1. Module level interfaces:
pre_forward_module_apply - called inside the nn.Module's pre-forward hook.
post_forward_module_apply - called inside the nn.Module's post-forward hook.
post_forward_outmost_module_apply - called inside the nn.Module's post-forward hook, but only for the outmost module.
2. Tensor level interfaces:
pre_forward_tensor_apply - called inside the nn.Module's pre-forward hook, for each input tensor.
post_forward_tensor_apply - called inside the nn.Module's post-forward hook, for each output tensor.
The post_forward path:
Module_A generates activation tensor_a --> Post forward hook (calling subscribers' forward one by one) -->
Module_B generates activation tensor_b --> ...
The pre_backward path:
Module_B backward run, tensor_b's gradient is computed as tensor_b_grad -->
Pre-backward hook (calling subscribers' backward one by one) -->
Module_A backward run, tensor_a's gradient is computed as tensor_a_grad
Be noted: the "Pre"/"Post" is described from the perspective of Module_A.
For ORT runs, tensor's flows are important, that's the reason we have tensor input as function input,
and tensor output as function output for all the APIs.
With this, the overall flow can be traced as a data flow graph (DAG).
"""
def __init__(self, start_step: Union[None, int], end_step: Union[None, int]):
def __init__(self, start_step: Optional[int], end_step: Optional[int]):
"""
Steps in [start_step, end_step) will run the subscriber's actions, and other steps will skip.
If start_step is None, 0 is given; if end_step is None, sys.maxsize is given.
@ -40,34 +60,152 @@ class SubscriberBase:
self._start_step: int = start_step if start_step is not None else 0
self._end_step: int = end_step if end_step is not None else sys.maxsize
def module_post_forward(self, activation: torch.Tensor, depth: int, name: str, step: int):
"""
This function will be run after the torch Module forward is completed.
def pre_forward_module_apply(
self,
run_rtx: RuntimeStates,
module: torch.nn.Module,
args: ORTModelInputOutputType,
kwargs: ORTModelInputOutputType,
) -> Tuple[ORTModelInputOutputType, ORTModelInputOutputType]:
"""This function is called inside the nn.Module's pre-forward hook.
Args:
activation: Tensor to be inspected.
depth: The indent level of the torch Module generating `activation`.
name: The unique name for the `activation`.
step: Current execution step.
"""
if self._start_step <= step < self._end_step:
self.module_post_forward_impl(activation, depth, name, step)
run_rtx (RuntimeStates): The runtime states of SubscriberManager.
module (torch.nn.Module): The module that is being executed.
args (ORTModelInputOutputType): The positional arguments that are passed to the module's pre-forward hook.
kwargs (ORTModelInputOutputType): The keyword arguments that are passed to the module's pre-forward hook.
Returns:
Tuple[ORTModelInputOutputType, ORTModelInputOutputType]: Updated args and kwargs.
def module_pre_backward(self, activation: torch.Tensor, depth: int, name: str, step: int):
"""
This function will be run before the torch Module backward run.
if self._need_skip_step(run_rtx.global_states.execution_step):
return args, kwargs
updated_args, updated_kwargs = self.pre_forward_module_apply_impl(run_rtx, module, args, kwargs)
return updated_args, updated_kwargs
def pre_forward_module_apply_impl(
    self,
    run_rtx: RuntimeStates,
    module: torch.nn.Module,
    args: ORTModelInputOutputType,
    kwargs: ORTModelInputOutputType,
) -> Tuple[ORTModelInputOutputType, ORTModelInputOutputType]:
    """Default module-level pre-forward action: hand `args` and `kwargs` back unchanged.

    Subclasses override this to implement a custom action.
    """
    return args, kwargs
def pre_forward_tensor_apply(
self, run_rtx: RuntimeStates, module: torch.nn.Module, tensor_index: int, tensor: torch.Tensor
) -> torch.Tensor:
"""This function is called inside the nn.Module's pre-forward hook.
Args:
activation: Tensor to be inspected.
depth: The indent level of the torch Module generating `activation`.
name: The unique name for the `activation`.
step: Current execution step.
run_rtx (RuntimeStates): The runtime states of SubscriberManager.
module (torch.nn.Module): The module that is being executed.
tensor_index (int): The index of the tensor in the input tensor list.
tensor (torch.Tensor): The tensor is one of module's forward inputs.
"""
if self._start_step <= step < self._end_step:
self.module_pre_backward_impl(activation, depth, name, step)
if self._need_skip_step(run_rtx.global_states.execution_step):
return tensor
def module_post_forward_impl(self, activation: torch.Tensor, depth: int, name: str, step: int):
raise NotImplementedError()
return self.pre_forward_tensor_apply_impl(run_rtx, module, tensor_index, tensor)
def module_pre_backward_impl(self, activation: torch.Tensor, depth: int, name: str, step: int):
raise NotImplementedError()
def pre_forward_tensor_apply_impl(
    self, run_rtx: RuntimeStates, module: torch.nn.Module, tensor_index: int, tensor: torch.Tensor
) -> torch.Tensor:
    """Default tensor-level pre-forward action: hand `tensor` back unchanged.

    Subclasses override this to implement a custom action.
    """
    return tensor
def post_forward_module_apply(
    self,
    run_rtx: RuntimeStates,
    module: torch.nn.Module,
    args: ORTModelInputOutputType,
    outputs: ORTModelInputOutputType,
) -> Tuple[ORTModelInputOutputType, ORTModelInputOutputType]:
    """Entry point called inside the nn.Module's post-forward hook.

    Args:
        run_rtx (RuntimeStates): The runtime states of SubscriberManager.
        module (torch.nn.Module): The module that is being executed.
        args (ORTModelInputOutputType): The input arguments that are passed to the module's post-forward
            hook as input.
        outputs (ORTModelInputOutputType): The output arguments that are passed to the module's post-forward
            hook as input.

    Returns:
        Tuple[ORTModelInputOutputType, ORTModelInputOutputType]: Updated inputs and outputs.
    """
    current_step = run_rtx.global_states.execution_step
    if not self._need_skip_step(current_step):
        return self.post_forward_module_apply_impl(run_rtx, module, args, outputs)

    # Skipped step: pass inputs and outputs through untouched.
    return args, outputs
def post_forward_module_apply_impl(
    self,
    run_rtx: RuntimeStates,
    module: torch.nn.Module,
    args: ORTModelInputOutputType,
    outputs: ORTModelInputOutputType,
) -> Tuple[ORTModelInputOutputType, ORTModelInputOutputType]:
    """Default module-level post-forward action: hand `args` and `outputs` back unchanged.

    Subclasses override this to implement a custom action.
    """
    return args, outputs
def post_forward_tensor_apply(
    self, run_rtx: RuntimeStates, module: torch.nn.Module, tensor_index: int, tensor: torch.Tensor
) -> torch.Tensor:
    """Entry point called inside the nn.Module's post-forward hook, for one output tensor.

    Args:
        run_rtx (RuntimeStates): The runtime states of SubscriberManager.
        module (torch.nn.Module): The module that is being executed.
        tensor_index (int): The index of the tensor in the output tensor list.
        tensor (torch.Tensor): One of the module's forward outputs.

    Returns:
        torch.Tensor: Updated tensor.
    """
    current_step = run_rtx.global_states.execution_step
    if not self._need_skip_step(current_step):
        return self.post_forward_tensor_apply_impl(run_rtx, module, tensor_index, tensor)

    # Skipped step: pass the tensor through untouched.
    return tensor
def post_forward_tensor_apply_impl(
    self, run_rtx: RuntimeStates, module: torch.nn.Module, tensor_index: int, tensor: torch.Tensor
) -> torch.Tensor:
    """Default tensor-level post-forward action: hand `tensor` back unchanged.

    Subclasses override this to implement a custom action.
    """
    return tensor
def post_forward_outmost_module_apply(
    self,
    run_rtx: RuntimeStates,
    module: torch.nn.Module,
    args: ORTModelInputOutputType,
    outputs: ORTModelInputOutputType,
) -> Tuple[ORTModelInputOutputType, ORTModelInputOutputType]:
    """Entry point called inside the outmost nn.Module's post-forward hook.

    Args:
        run_rtx (RuntimeStates): The runtime states of SubscriberManager.
        module (torch.nn.Module): The module that is being executed.
        args (ORTModelInputOutputType): The input arguments that are passed to the module's post-forward
            hook as input.
        outputs (ORTModelInputOutputType): The output arguments that are passed to the module's post-forward
            hook as input.

    Returns:
        Tuple[ORTModelInputOutputType, ORTModelInputOutputType]: Updated inputs and outputs.
    """
    current_step = run_rtx.global_states.execution_step
    if not self._need_skip_step(current_step):
        return self.post_forward_outmost_module_apply_impl(run_rtx, module, args, outputs)

    # Skipped step: pass inputs and outputs through untouched.
    return args, outputs
def post_forward_outmost_module_apply_impl(
    self,
    run_rtx: RuntimeStates,
    module: torch.nn.Module,
    args: ORTModelInputOutputType,
    outputs: ORTModelInputOutputType,
) -> Tuple[ORTModelInputOutputType, ORTModelInputOutputType]:
    """Default outmost-module post-forward action: hand `args` and `outputs` back unchanged.

    Subclasses override this to implement a custom action.
    """
    return args, outputs
def _need_skip_step(self, current_step: int) -> bool:
    """Return True when `current_step` lies outside the half-open range [_start_step, _end_step)."""
    return not (self._start_step <= current_step < self._end_step)

View file

@ -3,114 +3,31 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from collections import abc
from typing import Callable, List, Optional, Tuple, Union
import inspect
from contextlib import contextmanager
from typing import List, Optional, Set, Tuple, Union
import onnx
import torch
from onnxruntime.training.ortmodule import ORTModule
from onnxruntime.training.utils import extract_data_and_schema, unflatten_data_using_schema
from ._subscriber_base import SubscriberBase
from ._subscriber_base import RuntimeStates, SubscriberBase
ORT_NO_INCREASE_GLOBAL_STEP = [False]
class _RuntimeStates:
@contextmanager
def no_increase_global_step():
"""During the ONNX model export phase, a forward run is triggered, but we don't want to increase the global step, so that
the first iteration run will still start with 0, aligned with PyTorch's first iteration run.
"""
A data struct holding states for runtime context. Tho kinds of states are included:
> Global states that are one-time collected during model hook registration. A global execution step is
also initialized to reflect how many steps have been executed, it will get updated after each step
completes its forward path.
> Intra-execution step states, initialized and cleaned up intended only for current execution step.
Usually, it carries intermediate information during the model execution.
"""
class _GlobalStates:
def __init__(self):
# Used to track current execution step, e.g. how many forward/backward path is called.
self.execution_step = 0
# Used to store the depth of each module, which indicate the indentation level of the module.
self.module_index_to_depth = {}
# Used to store the unique id of each sequential activation.
self.module_to_module_index = {}
self.subscribers = set()
class _ExecutionStepStates:
def __init__(self):
# Used to store the activation tensor names, if already handled, then skipped.
# Need to clear after each step.
self.observed_activation_names = {}
def __init__(self):
self.global_states = _RuntimeStates._GlobalStates()
self.reset_step_states()
def reset_step_states(self):
self.execution_step_states = _RuntimeStates._ExecutionStepStates()
class _InspectActivation(torch.autograd.Function):
"""
This class is used to run the subscriber's forward and backward functions.
The function will be called by two kinds of callers:
1. SubscriberManager calls it for each registered nn.Module.
2. Users who want to inspect the activation tensor at any place of model definition code.
"""
@staticmethod
def forward(
ctx, activation_name: str, module_idx: Optional[int], run_ctx: _RuntimeStates, input_tensor: torch.Tensor
):
"""
Args:
ctx: context object to store intermediate information.
activation_name: the name of the activation tensor.
module_idx:
For call case 1 - the unique id of the module that the activation belongs to, it is detected by the
SubscriberManager automatically.
For call case 2 - e.g, _InspectActivation is called by users (NOT by SubscriberManager), module_idx can
be None.
run_ctx: runtime context.
For call case 2 - need retrieve the runtime state from GlobalSubscriberManager.
input_tensor: the activation tensor.
Make sure there is a same number of `tensor` type inputs and outputs.
This is enforced by ORT's PythonOp's schema check.
"""
depth = -1
if module_idx is not None:
depth = run_ctx.global_states.module_index_to_depth[module_idx]
input_tensor_copied = None
if input_tensor is None or not isinstance(input_tensor, torch.Tensor):
input_tensor_copied = input_tensor
else:
input_tensor_copied = input_tensor.detach().clone()
ctx.current_step = run_ctx.global_states.execution_step
ctx.name = activation_name
ctx.id = module_idx
ctx.depth = depth
ctx.subscribers = run_ctx.global_states.subscribers
# Run subscribers sequentially.
for subscriber in run_ctx.global_states.subscribers:
subscriber.module_post_forward(input_tensor_copied, depth, activation_name, ctx.current_step)
return input_tensor.detach() if input_tensor is not None else None
@staticmethod
def backward(ctx, grad_output: torch.Tensor):
    """Hand the incoming gradient to every subscriber recorded in forward and pass it through.

    Args:
        ctx: the context object populated by `forward` (subscribers, depth, name, current_step).
        grad_output: gradient of the activation; may be None.

    Returns:
        A 4-tuple matching the inputs of `forward` after `ctx`: None for the three non-tensor
        arguments, followed by the (detached) gradient. A non-tensor gradient is returned as-is.
    """
    is_tensor = isinstance(grad_output, torch.Tensor)

    # Subscribers receive a detached clone, never the gradient that is propagated further.
    val = grad_output.detach().clone() if is_tensor else grad_output

    for subscriber in ctx.subscribers:
        subscriber.module_pre_backward(val, ctx.depth, ctx.name, ctx.current_step)

    # Use the same tensor check as for the copy above, so a non-tensor value is passed through
    # instead of failing on `.detach()`.
    return None, None, None, grad_output.detach() if is_tensor else grad_output
try:
ORT_NO_INCREASE_GLOBAL_STEP[0] = True
yield
finally:
ORT_NO_INCREASE_GLOBAL_STEP[0] = False
@staticmethod
def infer_shape(
@ -122,8 +39,7 @@ class _InspectActivation(torch.autograd.Function):
class _IncrementStep(torch.autograd.Function):
"""
This class is used to manage the global execution step, e.g.
"""This class is used to manage the global execution step, e.g.
global step increment by one, once a full forward path is completed and the state clear.
This autograd Function is registered as a post-forward hook to the root module. So once the root
@ -132,45 +48,24 @@ class _IncrementStep(torch.autograd.Function):
"""
@staticmethod
def forward(ctx, run_ctx: _RuntimeStates, input_tensor: torch.Tensor):
"""
Make sure there is a same number of `tensor` inputs and outputs.
def forward(ctx, run_ctx: RuntimeStates, *input_tensor_list: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
"""Make sure there is the same number of `tensor` inputs and outputs.
This is enforced by ORT's PythonOp's schema check.
"""
ctx.current_step = run_ctx.global_states.execution_step
ctx.run_ctx = run_ctx
# We cannot do the step incremental here. Imagine the outside-most module has multiple outputs,
# we need to increase the step only at the very last output handling.
# We avoid the complexity to probe the last output handling, and instead, we assume once
# the very first backward of the outside-most module is called, then the forward pass MUST be completed.
if ctx.current_step >= 0:
print(f"{'='*6} Completed forward pass for STEP {ctx.current_step} {'='*6}")
# Be noted: it is not safe to register _IncrementStep only for one of the outputs of the outside-most module,
# because we are not sure which output branch is executed earlier, for example.
# OuterMostModuleOutputs
# / \
# OuterMostModuleOutputs_0_0th_output OuterMostModuleOutputs_0_1th_output
# | |
# PythonOp(_InspectActivation) PythonOp(_InspectActivation)
# | |
# PythonOp(_IncrementStep) graph output
# |
# graph output
# The PythonOp(_InspectActivation) (who relies on global step) after 1th output is possible
# to run before or after PythonOp(_IncrementStep), so increasing the step is not safe.
if ORT_NO_INCREASE_GLOBAL_STEP[0] is False:
ctx.run_ctx.global_states.execution_step += 1
return input_tensor.detach() if isinstance(input_tensor, torch.Tensor) else input_tensor
return tuple(t.detach().requires_grad_(t.requires_grad) for t in input_tensor_list)
@staticmethod
def backward(ctx, grad_output: torch.Tensor):
# In case there are multiple backward calls for multiple outputs of the outside-most module.
if ctx.current_step == ctx.run_ctx.global_states.execution_step:
if ctx.current_step >= 0:
print(f"{'='*6} Completed forward pass for STEP {ctx.current_step} {'='*6}")
ctx.run_ctx.global_states.execution_step += 1
ctx.run_ctx.reset_step_states()
return None, grad_output.detach() if isinstance(grad_output, torch.Tensor) else grad_output
def backward(ctx, *grad_output: Tuple[Optional[torch.Tensor], ...]) -> Tuple[Optional[torch.Tensor], ...]:
return (None, *tuple(g for g in grad_output))
@staticmethod
def infer_shape(
@ -182,23 +77,25 @@ class _IncrementStep(torch.autograd.Function):
class SubscriberManager:
"""
This class is used to manage all the subscribers and register the post-forward hook to the root module.
`subscribe()` is used to register a list of subscribers.
"""This class is used to manage all the subscribers and register subscribers' custom actions as PyTorch hooks
to the nn.Modules.
Currently, the hook handled here is post forward hook for nn.Module. The hook is registered for all nn.Modules
recursively. Each hook inserts a PythonOp for every tensor output generated by the corresponding module.
Each subscriber implementation is called in the PythonOp's forward function, and backward function.
For the module-level/tensor_level custom actions defined by subscribers, they are registered as corresponding
PyTorch hooks in the sequence of the subscribers' registration order.
There is one special handling for global step increment and state clear. A post-forward hook is registered
for the outside-most module, which is the root module. In that hook, _IncrementStep is called, which will
increase the step by 1 once the very first time its backward is called (check _IncrementStep for details).
increase the step by 1 once the post forward hook is called if running without no_increase_global_step().
`no_increase_global_step` is used to skip the step increment during ONNX model export.
"""
def __init__(self):
self._run_ctx: _RuntimeStates = _RuntimeStates()
self._run_ctx = RuntimeStates()
self._subscribers: Set[SubscriberBase] = set()
self._pre_forward_hooks = []
self._post_forward_hooks = []
def subscribe(self, module: Union[torch.nn.Module, ORTModule], subscribers: List[SubscriberBase]):
def subscribe(self, module: torch.nn.Module, subscribers: List[SubscriberBase]):
"""
The API is called externally to register hooks that are implicitly defined by subscribers.
Each time all global states will be cleaned up once called.
@ -207,52 +104,110 @@ class SubscriberManager:
raise ValueError("module must be a torch.nn.Module instance")
self._reset_all_states()
self._subscribers.clear()
if isinstance(module, ORTModule):
module = module.module
try:
# Put the import here to avoid the module level dependency on onnxruntime.training.ortmodule
from onnxruntime.training.ortmodule import ORTModule
if isinstance(module, ORTModule):
module = module.module
except ImportError:
pass
for subscriber in subscribers:
if not isinstance(subscriber, SubscriberBase):
raise ValueError("subscriber must be a SubscriberBase instance")
self._run_ctx.global_states.subscribers.add(subscriber)
self._subscribers.add(subscriber)
self._initialize(module)
def get_run_context(self) -> _RuntimeStates:
def get_subscriber(self, subscriber_type: type) -> SubscriberBase:
    """Return a registered subscriber that is an instance of `subscriber_type`.

    Raises:
        RuntimeError: when none of the registered subscribers is of that type.
    """
    not_found = object()
    candidates = (entry for entry in self._subscribers if isinstance(entry, subscriber_type))
    match = next(candidates, not_found)
    if match is not_found:
        raise RuntimeError(f"Subscriber {subscriber_type} is not registered.")
    return match
def get_run_context(self) -> RuntimeStates:
    """Return the `RuntimeStates` object currently held by this manager."""
    return self._run_ctx
def _reset_all_states(self):
self._run_ctx = _RuntimeStates()
self._pre_forward_hooks.clear()
self._post_forward_hooks.clear()
self._run_ctx = RuntimeStates()
def _initialize(self, module: torch.nn.Module):
"""
Register hooks for the specified module.
"""
if len(self._run_ctx.global_states.subscribers) == 0:
"""Register hooks for the specified module."""
if len(self._subscribers) == 0:
raise RuntimeError("No subscribers are registered.")
def _pre_forward_outmost_module_hook(module, module_inputs):
# This check is to support the case where module is first registered in the subscriber manager,
# then the module and hook are copied, when new module instance runs to the hook, the global states
# are not reset, so the logic depends on the global states will fail. So in the outer-most pre-forward hook
# we reset the global states.
# Be noted, the first run anyway will run in PyTorch.
if module not in self._run_ctx.global_states.module_to_module_index:
import warnings
warnings.warn(
"Initialize global states for the first time, this should only happen once for each outmost module."
)
self._initialize_one_time_global_states(module)
return module_inputs
module.register_forward_pre_hook(_pre_forward_outmost_module_hook)
next_module_index = [0]
# Register post forward hook for every module, inside the hook, we loop every tensor output of the module,
# and wrap it with an autograd Function called _InspectActivation (which takes in a tensor and returns the same
# tensor). In this way, we keep ORT and PyTorch run have the same boundary to check activation equality.
self._register_hooks_recursively(module, 1, next_module_index)
# Register post forward hook for the outside-most module, then we increase the dump step.
# Be noted, if backward is not triggered, the global dump step remains the original number,
# which means the subsequent run will override the previous dump files. This indeed happens to imagine ORTModule
# firstly export graph (run the forward only), after the gradient graph is built, another forward+backward is
# triggered, override the previous dump files.
def _post_forward_outmost_module_hook(module, _, module_outputs):
def _apply_to_tensors_func(_, outputs):
return _IncrementStep.apply(self._run_ctx, outputs)
def _post_forward_outmost_module_hook(module, module_inputs, module_outputs):
# Call post outmost module forward custom actions for subscribers
for sub in self._subscribers:
module_inputs, module_outputs = sub.post_forward_outmost_module_apply(
self._run_ctx, module, module_inputs, module_outputs
)
return self._apply_function_to_tensors(module, module_outputs, _apply_to_tensors_func)
flatten_output_tensor_list, output_schema = extract_data_and_schema(module_outputs)
output_tensors = _IncrementStep.apply(self._run_ctx, *flatten_output_tensor_list)
restored_outputs = unflatten_data_using_schema(output_tensors, output_schema)
return restored_outputs
module.register_forward_hook(_post_forward_outmost_module_hook)
def _initialize_one_time_global_states(self, module: torch.nn.Module):
    """Populate the module bookkeeping in the run context for `module` and its descendants.

    No hooks are registered here. The module tree is walked once, each module gets an `id`
    attribute, and `module_index_to_depth` / `module_to_module_index` in the global states are filled.
    """

    def _reset_recursively(module: torch.nn.Module, depth: int, next_module_index: List[int]):
        """
        Record index and depth for `module`, then recurse into children that are not recorded yet.
        Due to `Module` can contain child `Module`s, this function is called recursively by passing in
        `next_module_index` - a list of int to maintain a global incremental unique module id.

        Args:
            module: torch.nn.Module to record.
            depth: the indent of the module compared with the outside-most Module.
            next_module_index: list of int, carrying a global unique module index that can be used next.
        """
        module_index = next_module_index[0]
        module.id = module_index  # STAGE3WARN: needed by DeepSpeed
        self._run_ctx.global_states.module_index_to_depth[module_index] = depth
        self._run_ctx.global_states.module_to_module_index[module] = module_index

        for child in module.children():
            # A child that is already recorded (e.g. reachable through more than one parent) is not
            # given a second index.
            if (
                isinstance(child, torch.nn.Module)
                and child not in self._run_ctx.global_states.module_to_module_index
            ):
                next_module_index[0] += 1
                _reset_recursively(child, depth + 1, next_module_index)

    # The outside-most module gets index 0 and depth 1.
    next_module_index = [0]
    _reset_recursively(module, 1, next_module_index)
def _register_hooks_recursively(self, module: torch.nn.Module, depth: int, next_module_index: List[int]):
"""
Called to register hooks for every `torch.nn.Module`. Due to `Module` can contain child `Module`s,
"""Register hooks for every `torch.nn.Module`. Due to `Module` can contain child `Module`s,
this function is called recursively by passing in `next_module_index` - a list of int to maintain a
global incremental unique module id.
@ -262,6 +217,7 @@ class SubscriberManager:
next_module_index: list of int, carrying a global unique module index that can be used next.
"""
module_index = next_module_index[0]
module.id = module_index # STAGE3WARN: needed by DeepSpeed
self._run_ctx.global_states.module_index_to_depth[module_index] = depth
self._run_ctx.global_states.module_to_module_index[module] = module_index
@ -270,69 +226,54 @@ class SubscriberManager:
next_module_index[0] += 1
self._register_hooks_recursively(child, depth + 1, next_module_index)
def _post_forward_module_hook(module, _, module_outputs):
if module in self._run_ctx.global_states.module_to_module_index and isinstance(module, torch.nn.Module):
module_index = self._run_ctx.global_states.module_to_module_index[module]
def _pre_forward_module_with_kwargs_hook(module, module_inputs, kwargs):
# Module level hook
for sub in self._subscribers:
module_inputs, kwargs = sub.pre_forward_module_apply(self._run_ctx, module, module_inputs, kwargs)
def _apply_to_tensors_func(index, activation_tensor):
name = f"{module.__class__.__name__}_{module_index}_{index}th_output"
if name not in self._run_ctx.execution_step_states.observed_activation_names:
self._run_ctx.execution_step_states.observed_activation_names[name] = True
return _InspectActivation.apply(name, module_index, self._run_ctx, activation_tensor)
# Tensor level hook
flatten_positional_input_tensor_list, input_schema = extract_data_and_schema(module_inputs)
flatten_keyword_input_tensor_list, keyword_input_schema = extract_data_and_schema(kwargs)
return activation_tensor
for sub in self._subscribers:
tensor_list = []
for tensor_index, tensor in enumerate(flatten_positional_input_tensor_list):
tensor_list.append(sub.pre_forward_tensor_apply(self._run_ctx, module, tensor_index, tensor))
flatten_positional_input_tensor_list = tensor_list
return self._apply_function_to_tensors(module, module_outputs, _apply_to_tensors_func)
return module_outputs
tensor_list = []
for tensor_index, tensor in enumerate(flatten_keyword_input_tensor_list):
tensor_list.append(sub.pre_forward_tensor_apply(self._run_ctx, module, tensor_index, tensor))
flatten_keyword_input_tensor_list = tensor_list
module.register_forward_hook(_post_forward_module_hook)
module_inputs = unflatten_data_using_schema(flatten_positional_input_tensor_list, input_schema)
kwargs = unflatten_data_using_schema(flatten_keyword_input_tensor_list, keyword_input_schema)
def _is_builtin_type(self, obj):
# https://stackoverflow.com/a/17795199
return obj.__class__.__module__ in ["__builtin__", "builtins"]
return module_inputs, kwargs
def _apply_function_to_tensors(self, module: torch.nn.Module, data, func: Callable):
"""
Apply func to all tensors in the given object.
def _pre_forward_module_hook(module, module_inputs):
return _pre_forward_module_with_kwargs_hook(module, module_inputs, {})
Args:
module: the module that generates the tensors.
data: the object that contains activation tensors.
func: the function to apply to the tensors.
"""
tensor_output_idx: List[int] = [0]
def _post_forward_module_hook(module, module_inputs, module_outputs):
# Module level hook
for sub in self._subscribers:
_, module_outputs = sub.post_forward_module_apply(self._run_ctx, module, module_inputs, module_outputs)
def _apply_to_tensors_by_flatten(
module: torch.nn.Module,
index_for_tensor_output: List[int],
outputs,
func: Callable,
):
if isinstance(outputs, abc.Sequence):
touched_outputs = []
for output in outputs:
touched_output = _apply_to_tensors_by_flatten(module, index_for_tensor_output, output, func)
touched_outputs.append(touched_output)
return outputs.__class__(touched_outputs)
# Tensor level hook
flatten_output_tensor_list, output_schema = extract_data_and_schema(module_outputs)
for sub in self._subscribers:
tensor_list = []
for tensor_index, tensor in enumerate(flatten_output_tensor_list):
tensor_list.append(sub.post_forward_tensor_apply(self._run_ctx, module, tensor_index, tensor))
flatten_output_tensor_list = tensor_list
if isinstance(outputs, abc.Mapping):
# apply inplace to avoid recreating dict inherited objects
for key in outputs:
outputs[key] = _apply_to_tensors_by_flatten(
module,
index_for_tensor_output,
outputs[key],
func,
)
return outputs
return unflatten_data_using_schema(flatten_output_tensor_list, output_schema)
if isinstance(outputs, torch.Tensor):
cur_id = index_for_tensor_output[0]
index_for_tensor_output[0] += 1
return func(cur_id, outputs)
if not self._is_builtin_type(outputs):
raise RuntimeError(f"Unknown type {type(outputs)}")
return outputs
return _apply_to_tensors_by_flatten(module, tensor_output_idx, data, func)
# "with_kwargs" is not available for low versions of PyTorch.
if "with_kwargs" in inspect.signature(module.register_forward_pre_hook).parameters:
self._pre_forward_hooks.append(
module.register_forward_pre_hook(_pre_forward_module_with_kwargs_hook, with_kwargs=True)
)
else:
self._pre_forward_hooks.append(module.register_forward_pre_hook(_pre_forward_module_hook))
self._post_forward_hooks.append(module.register_forward_hook(_post_forward_module_hook))

View file

@ -0,0 +1,476 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import ctypes
import inspect
import warnings
from collections import OrderedDict
from types import CodeType, FunctionType
from typing import Callable, Dict, List, Optional, Tuple, Union
import onnx
import torch
from onnxruntime.training.utils import (
ORTModelInputOutputType,
extract_data_and_schema,
pytorch_dtype_to_onnx,
unflatten_data_using_schema,
)
from ._subscriber_base import RuntimeStates, SubscriberBase
# Used to monkey patch the original function
# Adapted from https://github.com/microsoft/DeepSpeed/blob/e8318634b4313eaad89842cf4322e1762d34ced3/deepspeed/runtime/zero/parameter_offload.py#L333
def _setup_zero_stage3_ort_compatible_hooks(self):
    """Replacement body for `DeepSpeedZeRoOffload.setup_zero_stage3_hooks`.

    `self` is the owning offload object. A `ZeROOffloadSubscriber` is subscribed for `self.module`
    through a dedicated `SubscriberManager`, the hook handles the manager collected are appended to
    `self.forward_hooks`, and the top module is pushed onto `FWD_MODULE_STACK`.
    """
    self.hierarchy = 0

    from onnxruntime.training.utils.hooks import SubscriberManager, ZeROOffloadSubscriber
    from onnxruntime.training.utils.hooks._zero_offload_subscriber import _zero_offload_one_time_initializer

    # Each DeepSpeed engine has a separate subscriber manager.
    manager = SubscriberManager()
    self._offload_subscriber_manager = manager
    manager.subscribe(self.module, [ZeROOffloadSubscriber(self, _zero_offload_one_time_initializer)])

    # Pre-forward handles first, then post-forward handles.
    for hook_handles in (manager._pre_forward_hooks, manager._post_forward_hooks):
        self.forward_hooks.extend(hook_handles)

    # Add top module to stack trace
    global FWD_MODULE_STACK  # noqa: PLW0602
    FWD_MODULE_STACK.append(self.module)
# Adapted from https://github.com/microsoft/DeepSpeed/blob/e8318634b4313eaad89842cf4322e1762d34ced3/deepspeed/runtime/zero/linear.py#L104
# In the original logic, if bias is None, after export to ONNX, None becomes a constant, so backward op complains
# output count more than needed.
def _zero3_linear_wrap_ort_compatible(input, weight, bias=None):
    """Call DeepSpeed's `LinearFunctionForZeroStage3` with `bias` always passed, even when it is None."""
    from deepspeed.runtime.zero.linear import LinearFunctionForZeroStage3

    zero3_linear = LinearFunctionForZeroStage3.apply
    return zero3_linear(input, weight, bias)
class _ZeROOffloadOneTimeInitializer:
    """Store the hook functions from DeepSpeed ZeRO offload.

    Hook functions code collected from DeepSpeed.
    """

    def __init__(self):
        # Maps the name of a nested function to its code object, in collection order.
        # The container now matches its annotation (it used to be a plain dict).
        self._code_store: OrderedDict[str, CodeType] = OrderedDict()

    def collect_code(self, function: Callable):
        """Collect the code objects nested directly inside `function`.

        Every code object found among the constants of `function.__code__` is stored under its
        `co_name`. An entry collected later under the same name replaces the earlier one.

        Args:
            function: a Python function whose inner definitions should be harvested.
        """
        for const in function.__code__.co_consts:
            if inspect.iscode(const):
                self._code_store[const.co_name] = const
# Module-level store of DeepSpeed hook code objects; stays None when DeepSpeed cannot be imported.
_zero_offload_one_time_initializer = None

try:
    # Have to import below explicitly, otherwise it complains about _apply_to_tensors_only not found.
    # The hooks reference functions or classes in that file. The hooks are rebuilt with this module's
    # globals() (see _ZeROOffloadFunctions), so the names they use must be resolvable from here.
    from deepspeed.runtime.zero.parameter_offload import *  # noqa: F403
    from deepspeed.runtime.zero.parameter_offload import DeepSpeedZeRoOffload, _apply_to_tensors_only  # noqa: F401
    from deepspeed.utils import instrument_w_nvtx  # noqa: F401

    # Used to collect the hook functions' code objects from DeepSpeed ZeRO offload, this should be initialized only once.
    if _zero_offload_one_time_initializer is None:
        _zero_offload_one_time_initializer = _ZeROOffloadOneTimeInitializer()
        # The hooks are the functions defined inside these two DeepSpeed methods.
        _zero_offload_one_time_initializer.collect_code(DeepSpeedZeRoOffload._register_hooks_recursively)
        _zero_offload_one_time_initializer.collect_code(DeepSpeedZeRoOffload.setup_zero_stage3_hooks)

    # This is the function to enable ORT ZeRO offload.
    def configure_ort_compatible_zero_stage3():
        """Configure ZeRO stage3 to be ORT compatible.

        This function will overwrite the original DeepSpeed ZeRO stage3 hooks to make it ORT compatible.
        """

        # Only done once no matter how many times this function is called for different modules.
        DeepSpeedZeRoOffload.setup_zero_stage3_hooks = _setup_zero_stage3_ort_compatible_hooks

        from deepspeed.runtime.zero.linear import zero3_linear_wrap

        # Replace the linear function only when DeepSpeed's wrapper is the one currently installed.
        if torch.nn.functional.linear is zero3_linear_wrap:
            torch.nn.functional.linear = _zero3_linear_wrap_ort_compatible

except ImportError as e:
    # Importing this module must not fail without DeepSpeed: warn now, fail only when the
    # feature is actually requested.
    warnings.warn(f"DeepSpeed import error {e}")

    def configure_ort_compatible_zero_stage3():
        """Fallback defined when the DeepSpeed imports fail; always raises RuntimeError."""
        raise RuntimeError("DeepSpeed is not installed, cannot configure ORT compatible ZeRO stage3.")
def _get_params_for_current_module(module: torch.nn.Module) -> List[torch.nn.parameter.Parameter]:
    """Return, as a list, the parameters DeepSpeed's `iter_params` yields for `module`.

    Logic adapted from
    https://github.com/microsoft/DeepSpeed/blob/9d79cfd1e90cae9306dc1b5837d374b2c9489ac8/deepspeed/runtime/zero/partitioned_param_coordinator.py#L267
    """
    from deepspeed.runtime.zero.partitioned_param_coordinator import iter_params

    # Materialize the iterable: callers take len() of the result and concatenate it with other lists.
    return list(iter_params(module))
def _get_all_offloaded_params(module: torch.nn.Module) -> Dict[str, torch.nn.parameter.Parameter]:
    """Map name -> parameter for every parameter of `module` whose `ds_status` is NOT_AVAILABLE.

    Parameters without a `ds_status` attribute are left out. Order follows `module.named_parameters()`.
    """
    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

    def _is_offloaded(candidate):
        return hasattr(candidate, "ds_status") and candidate.ds_status == ZeroParamStatus.NOT_AVAILABLE

    return OrderedDict(
        (param_name, param) for param_name, param in module.named_parameters() if _is_offloaded(param)
    )
class ORTZeROOffloadPreForwardFunction(torch.autograd.Function):
    """This function is a common bridge to call original PyTorch's
    pre_forward_function and post_backward_function.

    Inputs and outputs are flat tensor lists: the positional-argument tensors, then the keyword-argument
    tensors, then the parameters of the module.
    """

    @staticmethod
    def forward(
        ctx,
        module,
        pre_forward_with_kwargs_function,
        post_backward_function,
        args_schema,
        kwargs_schema,
        args_tensor_count,
        kwargs_tensor_count,
        *tensor_list,
    ):
        """
        Args:
            ctx: context object
            module: the module to be called
            pre_forward_with_kwargs_function: the function to be called before forward (PyTorch's pre_forward_function)
            post_backward_function: the function to be called after backward (PyTorch's post_backward_function)
            args_schema: the schema of the args, used to reconstruct the args in original form in
                PyTorch's pre_forward_function's inputs.
            kwargs_schema: the schema of the kwargs, used to reconstruct the kwargs in original form in
                PyTorch's pre_forward_function's inputs.
            args_tensor_count: the number of tensors in args.
            kwargs_tensor_count: the number of tensors in kwargs.
            tensor_list: the list of tensors, the first args_tensor_count tensors are args, the next
                kwargs_tensor_count tensors are kwargs, the rest are the parameters for offload.

        Returns:
            A non-empty tuple: the (possibly updated) args tensors, then the kwargs tensors, then one
            detached tensor per module parameter.
        """
        # Rebuild the original args/kwargs structures from the flat tensor list.
        args_tensors = tensor_list[:args_tensor_count]
        kwargs_tensors = tensor_list[args_tensor_count : args_tensor_count + kwargs_tensor_count]
        args = unflatten_data_using_schema(args_tensors, args_schema)
        kwargs = unflatten_data_using_schema(kwargs_tensors, kwargs_schema)

        # We re-retrieve the parameter tensors rather than using the ones passed in as inputs (of size 0 for
        # those partitioned params).
        # This is required for ORT run because in ORT graph, the tensor of size 0 will always be size 0
        # (this step is not necessary for PyTorch run, because PyTorch will re-use the same tensor
        # while .data got updated to full-sized data after pre_forward_with_kwargs_function is called).
        partitioned_params = _get_params_for_current_module(module)
        ctx.partitioned_params = partitioned_params

        f_ret = pre_forward_with_kwargs_function(module, args, kwargs)

        # A hook that returns nothing leaves args/kwargs untouched; otherwise it must return the pair.
        if f_ret is None:
            updated_args, updated_kwargs = args, kwargs
        else:
            assert isinstance(f_ret, tuple)
            updated_args, updated_kwargs = f_ret

        ctx.module = module
        ctx.post_backward_function = post_backward_function

        updated_args_tensors, _ = extract_data_and_schema(updated_args)
        updated_kwargs_tensors, _ = extract_data_and_schema(updated_kwargs)

        rets = tuple(updated_args_tensors + updated_kwargs_tensors)
        # The parameters are detached only here, i.e. after the pre-forward hook has run.
        rets += tuple([p.detach().requires_grad_(p.requires_grad) for p in partitioned_params])

        # PyTorch exporter does not support an empty list of tensors, so we have this check.
        assert len(rets) != 0
        return rets

    @staticmethod
    def backward(ctx, *grads):
        """Run the optional post-backward hook and return the gradients for the inputs of `forward`.

        The gradients for the trailing parameter slots are replaced by zero-element tensors (see TODO).
        """
        updated_grads = grads
        if ctx.post_backward_function is not None:
            ret = ctx.post_backward_function(ctx.module, grads)
            if ret is not None:
                updated_grads = ret

        # TODO(pengwa) Update grad for partitioned parameters.
        input_count = len(updated_grads) - len(ctx.partitioned_params)
        zeros = [torch.zeros(0, dtype=p.dtype, device=p.device) for p in ctx.partitioned_params]
        # NOTE(review): the concatenation below needs `updated_grads` to be a tuple; a hook that
        # returns a list would fail here - confirm against the hooks actually passed in.
        zero_grads = updated_grads[:input_count] + tuple(zeros)

        # Seven leading None: one per non-tensor argument of forward (module, the two hook functions,
        # the two schemas and the two tensor counts).
        return (None, None, None, None, None, None, None, *zero_grads)

    @staticmethod
    def infer_shape(
        node: onnx.NodeProto,
        tensor_input_shapes: List[Optional[List[Union[int, str]]]],
        tensor_input_dtypes: List[torch.onnx.TensorProtoDataType],
    ) -> Tuple[List[Optional[List[Union[int, str]]]], List[torch.onnx.TensorProtoDataType]]:
        """Return output shapes/dtypes: same as the inputs, except for the trailing parameter slots.

        The trailing slots (one per module parameter) get `param.ds_shape` and the ONNX dtype of the
        parameter. `ds_shape` is presumably the full, unpartitioned shape - confirm with DeepSpeed.
        """
        input_pointer_scalars_attr_name = "input_pointer_scalars"
        found = [attr for attr in node.attribute if attr.name == input_pointer_scalars_attr_name]

        assert len(found) == 1
        input_pointer_scalars = found[0].ints

        # Restore the nn.Module from the pointer.
        # The first entry is interpreted as the address of a Python object.
        module = ctypes.cast(input_pointer_scalars[0], ctypes.py_object).value

        partitioned_params = _get_params_for_current_module(module)
        # NOTE(review): these are the same list objects as the inputs, so the loop below updates the
        # caller's lists in place and the two length asserts at the end always hold.
        tensor_output_shapes = tensor_input_shapes
        tensor_output_dtypes = tensor_input_dtypes
        start_offset = len(tensor_input_shapes) - len(partitioned_params)
        for index, param in enumerate(partitioned_params):
            tensor_output_shapes[start_offset + index] = list(param.ds_shape)
            tensor_output_dtypes[start_offset + index] = pytorch_dtype_to_onnx(param.dtype)
        assert len(tensor_output_shapes) == len(tensor_input_shapes)
        assert len(tensor_output_dtypes) == len(tensor_input_dtypes)

        return tensor_output_shapes, tensor_output_dtypes
class ORTZeROOffloadPostForwardFunction(torch.autograd.Function):
    """Bridge that calls a post-forward hook in forward and an optional pre-backward hook in backward.

    Inputs and outputs are the flattened output tensors of a module.
    """

    @staticmethod
    def forward(
        ctx,
        module,
        post_forward_function,
        pre_backward_function,
        output_schema,
        *output_tensors,
    ):
        """
        Args:
            ctx: context object
            module: the module to be called
            post_forward_function: the function to be called after forward (PyTorch's post_forward_function)
            pre_backward_function: the function to be called before backward (PyTorch's pre_backward_function)
            output_schema: the schema of the output, used to reconstruct the output in original form in
                PyTorch's post_forward_function's inputs.
            output_tensors: the list of tensors.

        Returns:
            A tuple of detached tensors that keep the `requires_grad` flag of the tensors they come from.
        """
        restored_outputs = unflatten_data_using_schema(output_tensors, output_schema)

        # STAGE3WARN: _post_forward_module_hook's second argument `input` is not used, so we just pass a None here.
        hook_result = post_forward_function(module, None, restored_outputs)

        # A hook that returns nothing means "keep the incoming tensors".
        result_tensors = output_tensors if hook_result is None else extract_data_and_schema(hook_result)[0]

        ctx.module = module
        ctx.pre_backward_function = pre_backward_function

        return tuple(tensor.detach().requires_grad_(tensor.requires_grad) for tensor in result_tensors)

    @staticmethod
    def backward(ctx, *grads):
        """Run the optional pre-backward hook and return the gradients for the inputs of `forward`."""
        hook = ctx.pre_backward_function
        hook_result = None if hook is None else hook(ctx.module, grads)
        outgoing_grads = grads if hook_result is None else hook_result

        # Four leading None: module, the two hook functions and the output schema.
        return (None,) * 4 + tuple(outgoing_grads)

    @staticmethod
    def infer_shape(
        node: onnx.NodeProto,
        tensor_input_shapes: List[Optional[List[Union[int, str]]]],
        tensor_input_dtypes: List[torch.onnx.TensorProtoDataType],
    ) -> Tuple[List[Optional[List[Union[int, str]]]], List[torch.onnx.TensorProtoDataType]]:
        """Outputs have exactly the shapes and dtypes of the inputs."""
        return tensor_input_shapes, tensor_input_dtypes
class _ZeROOffloadFunctions:
    """Rebuild collected hook code objects as functions bound to one offloader instance.

    Each code object is turned into a function whose single closure cell holds `offloader`, and whose
    globals are the globals of this module.
    """

    def __init__(self, one_time_init: "_ZeROOffloadOneTimeInitializer", offloader) -> None:
        """
        Args:
            one_time_init: the store whose `_code_store` maps hook names to code objects.
            offloader: the object every rebuilt hook closes over.
        """
        # The container now matches its annotation (it used to be a plain dict).
        self._function_store: OrderedDict[str, FunctionType] = OrderedDict()
        self._one_time_init = one_time_init
        for name, code in self._one_time_init._code_store.items():
            # Each function gets its own cell. Exactly one cell is supplied, so the code object
            # must have exactly one free variable, otherwise FunctionType raises.
            cell = self._create_closure_for_ds_hook_function(offloader)
            self._function_store[name] = FunctionType(code, globals(), code.co_name, None, (cell,))

    def get(self, name: str) -> FunctionType:
        """Return the rebuilt function stored under `name`; raises KeyError when it was not collected."""
        return self._function_store[name]

    def _create_closure_for_ds_hook_function(self, offloader):
        """Return a closure cell whose content is `offloader`."""

        # https://stackoverflow.com/questions/17395338/how-to-access-a-function-inside-a-function
        def make_closure_cell(_self):
            def nested():
                return _self

            # `nested` closes over exactly one variable, so its first cell is the one holding `_self`.
            return nested.__closure__[0]

        return make_closure_cell(offloader)
class ZeROOffloadSubscriber(SubscriberBase):
"""This subscriber is used to enable ZeRO Offload feature in a way compatible with ORTModule."""
def __init__(self, offloader, one_time_init: _ZeROOffloadOneTimeInitializer, enable_debug_info: bool = False):
    """
    Args:
        offloader: the offload object the rebuilt DeepSpeed hook functions are bound to.
        one_time_init: store of the hook code objects collected from DeepSpeed.
        enable_debug_info: kept on the instance as `_enable_debug_info`.
    """
    # NOTE(review): the two None values are presumably the base class' start/end step - confirm
    # against SubscriberBase.
    super().__init__(None, None)
    self._enable_debug_info = enable_debug_info
    self._offloader = offloader
    self._functions = _ZeROOffloadFunctions(one_time_init, offloader)
def pre_forward_module_apply_impl(
    self,
    run_rtx: RuntimeStates,
    module: torch.nn.Module,
    args: ORTModelInputOutputType,
    kwargs: ORTModelInputOutputType,
) -> Tuple[ORTModelInputOutputType, ORTModelInputOutputType]:
    """This function is a dispatcher to call DeepSpeed stage3 pre forward hooks in sequence.

    All hook functions can be retrieved from the function store, due to exporter only supports a list of tensors as
    input and output for torch.autograd.Function, so we do flatten and unflatten here.

    Args:
        run_rtx: runtime states; not read by this method.
        module: the module about to run forward.
        args: positional inputs of the module.
        kwargs: keyword inputs of the module.

    Returns:
        The (possibly updated) args and kwargs, rebuilt with the schemas of the incoming ones.
    """
    args_tensors, args_schema = extract_data_and_schema(args)
    kwargs_tensors, kwargs_schema = extract_data_and_schema(kwargs)

    partitioned_params = _get_params_for_current_module(module)

    _pre_forward_module_hook = self._functions.get("_pre_forward_module_hook")

    args_tensor_count = len(args_tensors)
    kwargs_tensor_count = len(kwargs_tensors)

    def _wrap_pre_forward_module_hook(module, args, kwargs):
        # Adapter with the (module, args, kwargs) signature ORTZeROOffloadPreForwardFunction calls;
        # the DeepSpeed hook itself only receives module and args, kwargs pass through unchanged.
        rets = _pre_forward_module_hook(module, args)
        updated_args, updated_kwargs = args, kwargs
        if rets is not None:
            updated_args = rets

        # STAGE3WARN: Moved from _post_backward_module_hook to make sure ORT run will trigger every iteration.
        module.ds_grads_remaining = 0
        return updated_args, updated_kwargs

    # Flat input layout expected by ORTZeROOffloadPreForwardFunction: args, kwargs, then parameters.
    all_tensors = args_tensors + kwargs_tensors + partitioned_params

    self._check_all_tensor(all_tensors, module, "pre_forward_module_apply_impl input check")

    # The third argument (post_backward_function) is None: no hook is run in backward.
    rets = ORTZeROOffloadPreForwardFunction.apply(
        module,
        _wrap_pre_forward_module_hook,
        None,
        args_schema,
        kwargs_schema,
        args_tensor_count,
        kwargs_tensor_count,
        *all_tensors,
    )
    self._check_all_tensor(rets, module, "pre_forward_module_apply_impl output check")

    # Only the args/kwargs slices are used from here on; the trailing parameter outputs are not.
    updated_args_tensors = rets[:args_tensor_count]
    updated_kwargs_tensors = rets[args_tensor_count : args_tensor_count + kwargs_tensor_count]

    updated_args = unflatten_data_using_schema(updated_args_tensors, args_schema)
    updated_kwargs = unflatten_data_using_schema(updated_kwargs_tensors, kwargs_schema)

    _post_backward_module_hook = self._functions.get("_post_backward_module_hook")
    # STAGE3WARN: Other part of _post_backward_module_hook can be traced correctly so we don't need to
    # wrap with PythonOp.
    updated_args = _post_backward_module_hook(module, updated_args)

    return updated_args, updated_kwargs
def post_forward_module_apply_impl(
    self,
    run_rtx: RuntimeStates,
    module: torch.nn.Module,
    args: ORTModelInputOutputType,
    outputs: ORTModelInputOutputType,
) -> Tuple[ORTModelInputOutputType, ORTModelInputOutputType]:
    """This function is a dispatcher to call DeepSpeed stage3 post forward hooks in sequence.

    All hook functions can be retrieved from function store, due to exporter only supports a list of tensors as
    input and output for torch.autograd.Function, so we do flatten and unflatten here.

    Args:
        run_rtx: runtime states; not read by this method.
        module: the module that has just run forward.
        args: inputs of the module; returned unchanged.
        outputs: outputs of the module.

    Returns:
        `args` as passed in, and the (possibly updated) outputs rebuilt with the schema of `outputs`.
    """
    outputs_tensors, outputs_schema = extract_data_and_schema(outputs)

    _post_forward_module_hook = self._functions.get("_post_forward_module_hook")

    def _wrap_post_forward_module_hook(module, input, outputs):
        # STAGE3WARN: _post_forward_module_hook applied this for each tensor output, so we do a simple wrap here.
        from deepspeed.runtime.zero.partition_parameters import is_zero_param

        updated_outputs = _post_forward_module_hook(module, input, outputs)
        # NOTE(review): this is a truthiness test, so None and empty containers take the else branch;
        # a bare multi-element tensor returned by the hook would raise here - confirm the hook's
        # possible return values.
        if updated_outputs:
            for updated_output in updated_outputs:
                # restore zero param attributes if those get stripped by `backward_function`
                # NOTE(review): `is_zero_param` is applied to the whole `outputs` object, not to the
                # matching element - confirm this is intended.
                if not is_zero_param(updated_output) and is_zero_param(outputs):
                    updated_output.ds_param_alias = outputs
            return updated_outputs
        else:
            return outputs

    self._check_all_tensor(outputs_tensors, module, "post_forward_module_apply_impl input check")

    # The third argument (pre_backward_function) is None: no hook is run in backward.
    updated_outputs_tensors = ORTZeROOffloadPostForwardFunction.apply(
        module, _wrap_post_forward_module_hook, None, outputs_schema, *outputs_tensors
    )

    self._check_all_tensor(updated_outputs_tensors, module, "post_forward_module_apply_impl output check")

    assert len(updated_outputs_tensors) == len(outputs_tensors)

    # WARN: we assume updated_output_tensors can REUSE the outputs_schema.
    updated_outputs = unflatten_data_using_schema(updated_outputs_tensors, outputs_schema)

    _pre_backward_module_hook = self._functions.get("_pre_backward_module_hook")
    # STAGE3WARN: _pre_backward_module_hook's second argument `input` is not used, so we just pass a None here.
    # STAGE3WARN: part of the original _pre_backward_module_hook can be traced correctly so we moved them into
    # _wrap_post_forward_module_hook above.
    updated_outputs = _pre_backward_module_hook(module, None, updated_outputs)

    return args, updated_outputs
def post_forward_outmost_module_apply_impl(
    self,
    run_rtx: RuntimeStates,
    module: torch.nn.Module,
    args: ORTModelInputOutputType,
    outputs: ORTModelInputOutputType,
) -> Tuple[ORTModelInputOutputType, ORTModelInputOutputType]:
    """Apply the ``_end_of_forward_hook`` entry of the function store to ``outputs``.

    ``outputs`` is flattened into tensors plus a schema, the hook is routed through
    ``ORTZeROOffloadPostForwardFunction``, and the returned tensors are rebuilt with the same schema.

    Returns:
        ``(args, updated_outputs)``: ``args`` is handed back untouched.
    """
    flat_outputs, schema = extract_data_and_schema(outputs)

    end_of_forward_hook = self._functions.get("_end_of_forward_hook")

    self._check_all_tensor(flat_outputs, module, "post_forward_outmost_module_apply_impl input check")

    flat_updated_outputs = ORTZeROOffloadPostForwardFunction.apply(
        module,
        end_of_forward_hook,
        None,
        schema,
        *flat_outputs,
    )

    self._check_all_tensor(flat_updated_outputs, module, "post_forward_outmost_module_apply_impl output check")

    # The rebuild below reuses the input schema, so the tensor count must not change.
    assert len(flat_updated_outputs) == len(flat_outputs)

    return args, unflatten_data_using_schema(flat_updated_outputs, schema)
def _check_all_tensor(self, tensor_list: Tuple[torch.Tensor], module: torch.nn.Module, name: str):
    """Debug-only check that every element of ``tensor_list`` is a ``torch.Tensor``.

    Does nothing unless ``self._enable_debug_info`` is truthy.

    Raises:
        RuntimeError: on the first element that is not a tensor; the message carries ``name``, the class
            name of ``module`` and the offending element's type.
    """
    if self._enable_debug_info:
        for candidate in tensor_list:
            if isinstance(candidate, torch.Tensor):
                continue
            raise RuntimeError(f"{name} fail: {module.__class__.__name__}, input type: {type(candidate)}")

View file

@ -0,0 +1,47 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from typing import Union
import torch
# Scalar type name -> [onnx dtype, torch dtype].
_CAST_PYTORCH_TO_ONNX = {
    "Byte": [torch.onnx.TensorProtoDataType.UINT8, torch.uint8],
    "Char": [torch.onnx.TensorProtoDataType.INT8, torch.int8],
    "Double": [torch.onnx.TensorProtoDataType.DOUBLE, torch.double],
    "Float": [torch.onnx.TensorProtoDataType.FLOAT, torch.float],
    "Half": [torch.onnx.TensorProtoDataType.FLOAT16, torch.half],
    "Int": [torch.onnx.TensorProtoDataType.INT32, torch.int],
    "Long": [torch.onnx.TensorProtoDataType.INT64, torch.int64],
    "Short": [torch.onnx.TensorProtoDataType.INT16, torch.short],
    "Bool": [torch.onnx.TensorProtoDataType.BOOL, torch.bool],
    "ComplexFloat": [torch.onnx.TensorProtoDataType.COMPLEX64, torch.complex64],
    "ComplexDouble": [torch.onnx.TensorProtoDataType.COMPLEX128, torch.complex128],
    "BFloat16": [torch.onnx.TensorProtoDataType.BFLOAT16, torch.bfloat16],
    # Not yet defined in torch.
    # "Float8E4M3FN": torch.onnx.TensorProtoDataType.FLOAT8E4M3FN,
    # "Float8E4M3FNUZ": torch.onnx.TensorProtoDataType.FLOAT8E4M3FNUZ,
    # "Float8E5M2": torch.onnx.TensorProtoDataType.FLOAT8E5M2,
    # "Float8E5M2FNUZ": torch.onnx.TensorProtoDataType.FLOAT8E5M2FNUZ,
    "Undefined": [torch.onnx.TensorProtoDataType.UNDEFINED, None],
}

# Reverse view of the table above: torch dtype -> onnx dtype. The "Undefined" row contributes the
# `None -> UNDEFINED` entry.
_DTYPE_TO_ONNX = {torch_dtype: onnx_dtype for onnx_dtype, torch_dtype in _CAST_PYTORCH_TO_ONNX.values()}


def pytorch_dtype_to_onnx(dtype_or_scalar_type: Union[torch.dtype, str]) -> torch.onnx.TensorProtoDataType:
    """Translate a pytorch dtype, or a scalar type name such as "Float", into the onnx dtype.

    Strings are resolved through the scalar-type-name table; any other value is resolved through the
    torch-dtype table.

    Raises:
        RuntimeError: if the name or dtype has no entry in the corresponding table.
    """
    if isinstance(dtype_or_scalar_type, str):
        # Table rows are [onnx dtype, torch dtype] lists, so a missing name is the only way to get None.
        row = _CAST_PYTORCH_TO_ONNX.get(dtype_or_scalar_type)
        if row is None:
            raise RuntimeError(f"Unsupported dtype {dtype_or_scalar_type}")
        return row[0]

    # Every stored value is an onnx enum member, so None unambiguously means "no entry".
    onnx_dtype = _DTYPE_TO_ONNX.get(dtype_or_scalar_type)
    if onnx_dtype is None:
        raise RuntimeError(f"Unsupported dtype {dtype_or_scalar_type}")
    return onnx_dtype

View file

@ -8,7 +8,7 @@ import pytest
import torch
from onnxruntime.training.ortmodule import ORTModule
from onnxruntime.training.utils.hooks import GlobalSubscriberManager, StatisticsSubscriber, _InspectActivation
from onnxruntime.training.utils.hooks import GlobalSubscriberManager, StatisticsSubscriber, inspect_activation
class NeuralNetSingleOutput(torch.nn.Module):
@ -147,9 +147,9 @@ class NeuralNetUserAnnotateIntermediateTensor(torch.nn.Module):
def forward(self, input1, input2):
# NOTE(review): this span is a rendered diff hunk, not plain source. Each
# `_InspectActivation.apply(...)` line is directly followed by an
# `inspect_activation(...)` line carrying the same label, which looks like the
# removed/added pair of the change -- confirm against the real test file.
# Add the two inputs and feed the sum through fc1.
model_input = input1 + input2
out = self.fc1(model_input)
# Tag the fc1 result with the label "fc1_out".
out = _InspectActivation.apply("fc1_out", None, GlobalSubscriberManager.get_run_context(), out)
out = inspect_activation("fc1_out", out)
out = self.relu(out)
# Tag the relu result with the label "relu_out".
out = _InspectActivation.apply("relu_out", None, GlobalSubscriberManager.get_run_context(), out)
out = inspect_activation("relu_out", out)
out = self.fc2(out)
return out