From d90afc697be8fced404771f0eb64da8ce54d467d Mon Sep 17 00:00:00 2001 From: pengwa Date: Fri, 25 Aug 2023 00:15:22 +0800 Subject: [PATCH] Introduce ZeROOffloadSubscriber for ORTModule (#17006) ### Introduce ZeROOffloadSubscriber for ORTModule As part of the work: integrate ORTModule with DeepSpeed stage3, this PR mainly focus on moving original PyTorch-based (leveraging hooks) param partition/offload implementation to ORTModule compatible implementation. Changes include: 1. Refactor `SubscriberBase`/`SubcriberManager` to support pre-forward/post_forward hooks. 2. Implement new `ZeROOffloadSubscriber` by re-using DeepSpeed hook function as much as possible. Since all hook functions are defined in `DeepSpeedZeRoOffload._register_hooks_recursively` and `DeepSpeedZeRoOffload.setup_zero_stage3_hooks`, and the good thing is, the closure is not complex, all hooks are referencing the owning `DeepSpeedZeRoOffload` instance, so we can create new hook function with `FunctionType` by binding the owning `DeepSpeedZeRoOffload` instance, then call the new created function in subscriber's `pre_forward_module_apply_impl` and `post_forward_module_apply_impl` interfaces. 3. Monkey patch `DeepSpeedZeRoOffload.setup_zero_stage3_hooks` to register the `ZeROOffloadSubscriber` for the model, then we don't need change any code on the DeepSpeed repo (at least so far). 4. Fix the ATen embedding custom symbolic exporter function by tolerating weights size be (0) (changed by DeepSpeed zero stage 3). UT will be added once stage3 is fully supported. ### Motivation and Context --- docs/ORTModule_Convergence_Notes.md | 9 +- .../_custom_autograd_function_exporter.py | 15 +- .../_custom_autograd_function_runner.py | 2 +- .../ortmodule/_custom_op_symbolic_registry.py | 47 +- .../ortmodule/_graph_execution_manager.py | 22 +- .../python/training/ortmodule/_io.py | 1 + .../python/training/ortmodule/options.py | 7 + .../python/training/utils/__init__.py | 2 + .../python/training/utils/hooks/__init__.py | 27 +- .../utils/hooks/_statistics_subscriber.py | 92 +++- .../training/utils/hooks/_subscriber_base.py | 216 ++++++-- .../utils/hooks/_subscriber_manager.py | 379 ++++++-------- .../utils/hooks/_zero_offload_subscriber.py | 476 ++++++++++++++++++ .../python/training/utils/torch_type_map.py | 47 ++ .../test/python/orttraining_test_hooks.py | 6 +- 15 files changed, 1033 insertions(+), 315 deletions(-) create mode 100644 orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py create mode 100644 orttraining/orttraining/python/training/utils/torch_type_map.py diff --git a/docs/ORTModule_Convergence_Notes.md b/docs/ORTModule_Convergence_Notes.md index 8f54fd6b5a..791b6c32c9 100644 --- a/docs/ORTModule_Convergence_Notes.md +++ b/docs/ORTModule_Convergence_Notes.md @@ -83,12 +83,13 @@ Arguments: inspector node affects memory peak causing the original recipe run to fail with OOM. - `bucket_size`: the size of the bucket to split the statistic calculation. -### 2.2 Use `_InspectActivation` to collect intermediate tensors in a `nn.Module` forward() +### 2.2 Use `inspect_activation` to collect intermediate tensors in a `nn.Module` forward() The limitation of `GlobalSubscriberManager` is, only 'nn.Module's forward output tensors will be dumped, if you want to dump the intermediate tensors in a `nn.Module`'s forward function, refer to the following example: ```diff ++ from onnxruntime.training.utils import inspect_activation class BloomForCausalLM(BloomPreTrainedModel): def __init__(self, config: BloomConfig): ... @@ -98,10 +99,10 @@ class BloomForCausalLM(BloomPreTrainedModel): transformer_outputs = self.transformer(...) hidden_states = transformer_outputs[0] lm_logits = self.lm_head(hidden_states) -+ lm_logits = _InspectActivation.apply("lm_logits", None, GlobalSubscriberManager.get_run_context(), lm_logits) ++ lm_logits = inspect_activation("lm_logits", lm_logits) # Shift so that tokens < n predict n shift_logits = lm_logits[..., :-1, :].contiguous() -+ shift_logits = _InspectActivation.apply("shift_logits", None, GlobalSubscriberManager.get_run_context(), shift_logits) ++ shift_logits = inspect_activation("shift_logits", shift_logits) shift_labels = labels[..., 1:].contiguous() batch_size, seq_length, vocab_size = shift_logits.shape # Flatten the tokens @@ -113,7 +114,7 @@ class BloomForCausalLM(BloomPreTrainedModel): return loss ``` -Be noted, make sure the activation name (as the first argument of `_InspectActivation.apply`) is unique, otherwise +Be noted, make sure the activation name (as the first argument of `inspect_activation`) is unique, otherwise stat file using the activation name will be overwritten by the last write. The dumped data are stored in the `output_dir`. diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py index 8445934dcf..4c72b6d98a 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py @@ -6,16 +6,17 @@ import sys from typing import Callable, ClassVar, Dict, Optional -import onnx import torch import torch.utils.checkpoint +from onnx import ModelProto from packaging import version from torch.onnx import symbolic_helper from onnxruntime.capi._pybind_state import register_miscellaneous_const_input, register_torch_autograd_function from onnxruntime.training import ortmodule +from onnxruntime.training.utils import pytorch_dtype_to_onnx -from ._custom_op_symbolic_registry import pytorch_type_to_onnx, wrap_custom_export_function +from ._custom_op_symbolic_registry import wrap_custom_export_function from ._fallback import ORTModuleONNXModelException, wrap_exception from ._utils import get_fully_qualified_class_name, get_runtime_pytorch_version @@ -168,7 +169,7 @@ def _export_pt_1_10(g, n, *args, **kwargs): if call_type == "d": # Got a tensor variable. tensor_args.append(arg) - scalar_type = pytorch_type_to_onnx(arg.type().scalarType()) + scalar_type = pytorch_dtype_to_onnx(arg.type().scalarType()) input_tensor_types.append(scalar_type) input_tensor_ranks.append(arg.type().dim()) continue @@ -247,7 +248,7 @@ def _export_pt_1_10(g, n, *args, **kwargs): output_tensor_ranks = [] for arg in n.outputs(): # Type of tensor's elements. - scalar_type = pytorch_type_to_onnx(arg.type().scalarType()) + scalar_type = pytorch_dtype_to_onnx(arg.type().scalarType()) output_tensor_types.append(scalar_type) output_tensor_ranks.append(arg.type().dim()) @@ -306,16 +307,14 @@ def _export_pt_1_10(g, n, *args, **kwargs): _export = wrap_custom_export_function(_export_pt_1_10) -def _post_process_after_export( - exported_model: onnx.ModelProto, enable_custom_autograd_function: bool -) -> onnx.ModelProto: +def _post_process_after_export(exported_model: ModelProto, enable_custom_autograd_function: bool) -> ModelProto: """Post process the exported model.""" if enable_custom_autograd_function: exported_model = _post_process_enabling_autograd_function(exported_model) return exported_model -def _post_process_enabling_autograd_function(exported_model: onnx.ModelProto) -> onnx.ModelProto: +def _post_process_enabling_autograd_function(exported_model: ModelProto) -> ModelProto: # Loop all PythonOp, append "_ctx" as the first output. index = 0 for node in exported_model.graph.node: diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py index 8eeaee5bdf..845c7d83c2 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py @@ -232,7 +232,7 @@ def call_python_forward_function( for arg_index, (grad_flag, tensor_flag, arg) in enumerate(zip(requires_grad_flags, tensor_type_flags, args)): if tensor_flag: - # Assume it's a DLPack tensor# and convert it to PyTorch tensor. + # Assume it's a DLPack tensor and convert it to PyTorch tensor. # Note1: # If it's first-time kernel invocation, input_indices_to_save_in_ctx is None, we do the # copy for all tensor. Otherwise, we only copy the tensors whose indices are in diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py index ac87dc6abf..0dd33d493b 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py @@ -12,37 +12,10 @@ from packaging.version import Version from torch.onnx import register_custom_op_symbolic from torch.onnx.symbolic_helper import _get_tensor_dim_size, _get_tensor_sizes, parse_args +from onnxruntime.training.utils import pytorch_dtype_to_onnx + from ._utils import get_runtime_pytorch_version -# Mapping from pytorch scalar type to onnx scalar type. -_CAST_PYTORCH_TO_ONNX = { - "Byte": torch.onnx.TensorProtoDataType.UINT8, - "Char": torch.onnx.TensorProtoDataType.INT8, - "Double": torch.onnx.TensorProtoDataType.DOUBLE, - "Float": torch.onnx.TensorProtoDataType.FLOAT, - "Half": torch.onnx.TensorProtoDataType.FLOAT16, - "Int": torch.onnx.TensorProtoDataType.INT32, - "Long": torch.onnx.TensorProtoDataType.INT64, - "Short": torch.onnx.TensorProtoDataType.INT16, - "Bool": torch.onnx.TensorProtoDataType.BOOL, - "ComplexFloat": torch.onnx.TensorProtoDataType.COMPLEX64, - "ComplexDouble": torch.onnx.TensorProtoDataType.COMPLEX128, - "BFloat16": torch.onnx.TensorProtoDataType.BFLOAT16, - # Not yet defined in torch. - # "Float8E4M3FN": torch.onnx.TensorProtoDataType.FLOAT8E4M3FN, - # "Float8E4M3FNUZ": torch.onnx.TensorProtoDataType.FLOAT8E4M3FNUZ, - # "Float8E5M2": torch.onnx.TensorProtoDataType.FLOAT8E5M2, - # "Float8E5M2FNUZ": torch.onnx.TensorProtoDataType.FLOAT8E5M2FNUZ, - "Undefined": torch.onnx.TensorProtoDataType.UNDEFINED, -} - - -def pytorch_type_to_onnx(scalar_type: str) -> torch.onnx.TensorProtoDataType: - try: - return torch.onnx.JitScalarType.from_name(scalar_type).onnx_type() - except AttributeError: - return _CAST_PYTORCH_TO_ONNX[scalar_type] - def wrap_custom_export_function(original_func: Callable) -> Callable: """This function is to wrap the custom export function to make sure it can be used by different versions of PyTorch. @@ -172,7 +145,7 @@ def cross_entropy_loss(g, node, logits, target, weight, reduction, ignore_index, weight_casted, ignore_index, reduction_s=reduction, - output_type_i=pytorch_type_to_onnx(output_type.scalarType()), + output_type_i=pytorch_dtype_to_onnx(output_type.scalarType()), outputs=2, ) output.setType(output_type) @@ -199,10 +172,16 @@ def embedding(g, weight, indices, padding_idx, scale_grad_by_freq, sparse): output = g.op( "org.pytorch.aten::ATen", weight, indices, padding_idx, scale_grad_by_freq, sparse, operator_s="embedding" ) - indices_shape = _get_tensor_sizes(indices) - if indices_shape is not None and hasattr(weight.type(), "with_sizes"): - output_type = weight.type().with_sizes([*indices_shape, _get_tensor_dim_size(weight, 1)]) - output.setType(output_type) + + try: + # Tolerant to the case when sizes of indices are not available or not usable (for example + # when DeepSpeed stage3 enabled, all weights size is (0), this will fail.) + indices_shape = _get_tensor_sizes(indices) + if indices_shape is not None and hasattr(weight.type(), "with_sizes"): + output_type = weight.type().with_sizes([*indices_shape, _get_tensor_dim_size(weight, 1)]) + output.setType(output_type) + except IndexError: + output.setType(weight.type()) return output diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index 7f38428486..2227b630ae 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -20,6 +20,7 @@ import onnxruntime from onnxruntime.capi import _pybind_state as C from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference from onnxruntime.training.utils import ORTModelInputOutputSchemaType +from onnxruntime.training.utils.hooks import configure_ort_compatible_zero_stage3 from . import _are_deterministic_algorithms_enabled, _io, _logger, _onnx_models, _utils from ._custom_autograd_function_exporter import _post_process_after_export @@ -140,6 +141,10 @@ class GraphExecutionManager(GraphExecutionInterface): register_triton_op_executor() + if self._runtime_options.enable_zero_stage3_support: + # Cannot toggle feature enabling/disabling after the first time enabled. + configure_ort_compatible_zero_stage3() + def _get_torch_gpu_allocator_function_addresses(self): if self._runtime_options.use_external_gpu_allocator and torch.cuda.is_available(): # CPP extension to get torch GPU allocator's alloc and free function addresses @@ -281,7 +286,11 @@ class GraphExecutionManager(GraphExecutionInterface): # All required models have already been exported previously return False self._set_device_from_module(inputs, kwargs) - self._onnx_models.exported_model = self._get_exported_model(schema, *inputs, **kwargs) + + from onnxruntime.training.utils.hooks._subscriber_manager import no_increase_global_step + + with no_increase_global_step(): + self._onnx_models.exported_model = self._get_exported_model(schema, *inputs, **kwargs) if self._debug_options.save_onnx_models.save: self._onnx_models.save_exported_model( self._debug_options.save_onnx_models.path, @@ -311,7 +320,6 @@ class GraphExecutionManager(GraphExecutionInterface): # WARNING/ERROR -> [Rank 0] NO export verbose log + FILTERED torch other logs from stdout and stderr (C++ backend) # Be noted: rank 0 log only is controlled by logger configured in _logger.py torch_exporter_verbose_log = self._debug_options.logging.log_level <= LogLevel.INFO - self._logger.info("Exporting the PyTorch model to ONNX...") # Setup dynamic axes for onnx model self._input_info = _io.parse_inputs_for_onnx_export(self._module_parameters, None, input_schema, inputs, kwargs) @@ -327,6 +335,8 @@ class GraphExecutionManager(GraphExecutionInterface): # FlattenedModule needs _InputInfo to expand user input from *args to *args + **kwargs self._flattened_module._input_info = self._input_info + self._logger.info("Exporting the PyTorch model to ONNX...") + # Leverage cached model if available cache_dir = self._runtime_options.ortmodule_cache_dir if cache_dir: @@ -659,6 +669,14 @@ class GraphExecutionManager(GraphExecutionInterface): ) ) + feature_map.append( + ( + "ZeRO Stage3 Support", + self._runtime_options.enable_zero_stage3_support, + "Enable/Disable with env ORTMODULE_ENABLE_ZERO_STAGE3=1/0", + ) + ) + mode = "training" if self._export_mode == torch.onnx.TrainingMode.TRAINING else "inference" mode = f"{_logger.LogColor.UNDERLINE}{mode}{_logger.LogColor.ENDC}" diff --git a/orttraining/orttraining/python/training/ortmodule/_io.py b/orttraining/orttraining/python/training/ortmodule/_io.py index 86a06fe683..18b965c549 100644 --- a/orttraining/orttraining/python/training/ortmodule/_io.py +++ b/orttraining/orttraining/python/training/ortmodule/_io.py @@ -539,6 +539,7 @@ def parse_outputs_for_onnx_export_and_extract_schema( output_names = None output_dynamic_axes = None is_deepcopy = False + logger.info("Running model forward to infer output schema and dynamic axes...") with torch.no_grad(): # Deepcopy inputs, since input values may change after model run. sample_args_copy, sample_kwargs_copy = deepcopy_model_input(*args, **kwargs) diff --git a/orttraining/orttraining/python/training/ortmodule/options.py b/orttraining/orttraining/python/training/ortmodule/options.py index f8d3a6e779..0eb6790d7a 100644 --- a/orttraining/orttraining/python/training/ortmodule/options.py +++ b/orttraining/orttraining/python/training/ortmodule/options.py @@ -288,6 +288,9 @@ class _RuntimeOptions: # Cache exported model self.ortmodule_cache_dir = "" + # Experimental features. + self.enable_zero_stage3_support = False # Once enabled, cannot be disabled. + # Override the feature config if it exists in os env. self._override_from_env_vars() @@ -365,3 +368,7 @@ class _RuntimeOptions: if "ORTMODULE_CACHE_DIR" in os.environ: self._logger.info("ORTModule cache optimization is ON.") self.ortmodule_cache_dir = os.getenv("ORTMODULE_CACHE_DIR") + + # Experimental features. + if "ORTMODULE_ENABLE_ZERO_STAGE3" in os.environ and int(os.getenv("ORTMODULE_ENABLE_ZERO_STAGE3")) == 1: + self.enable_zero_stage3_support = True diff --git a/orttraining/orttraining/python/training/utils/__init__.py b/orttraining/orttraining/python/training/utils/__init__.py index c88c744ca9..acf2698d55 100644 --- a/orttraining/orttraining/python/training/utils/__init__.py +++ b/orttraining/orttraining/python/training/utils/__init__.py @@ -9,6 +9,7 @@ from onnxruntime.training.utils.torch_io_helper import ( extract_data_and_schema, unflatten_data_using_schema, ) +from onnxruntime.training.utils.torch_type_map import pytorch_dtype_to_onnx __all__ = [ "PrimitiveType", @@ -16,4 +17,5 @@ __all__ = [ "ORTModelInputOutputSchemaType", "extract_data_and_schema", "unflatten_data_using_schema", + "pytorch_dtype_to_onnx", ] diff --git a/orttraining/orttraining/python/training/utils/hooks/__init__.py b/orttraining/orttraining/python/training/utils/hooks/__init__.py index 91e919b1c5..89c0d44abb 100644 --- a/orttraining/orttraining/python/training/utils/hooks/__init__.py +++ b/orttraining/orttraining/python/training/utils/hooks/__init__.py @@ -3,14 +3,35 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- + +import torch + __all__ = [ "StatisticsSubscriber", "GlobalSubscriberManager", - "_InspectActivation", + "inspect_activation", + "ZeROOffloadSubscriber", + "configure_ort_compatible_zero_stage3", ] -from ._statistics_subscriber import StatisticsSubscriber -from ._subscriber_manager import SubscriberManager, _InspectActivation +from ._statistics_subscriber import StatisticsSubscriber, _InspectActivation +from ._subscriber_manager import SubscriberManager +from ._zero_offload_subscriber import ZeROOffloadSubscriber, configure_ort_compatible_zero_stage3 # Define a global uninitialized subscriber manager for usage where it is needed by different Python files. GlobalSubscriberManager = SubscriberManager() + + +def inspect_activation(activation_name: str, tensor: torch.Tensor) -> torch.Tensor: + for sub in GlobalSubscriberManager._subscribers: + if isinstance(sub, StatisticsSubscriber): + return _InspectActivation.apply( + activation_name, + None, + GlobalSubscriberManager.get_run_context(), + tensor, + sub.module_post_forward_impl, + sub.module_pre_backward_impl, + ) + + raise RuntimeError("StatisticsSubscriber is not registered, cannot inspect activation.") diff --git a/orttraining/orttraining/python/training/utils/hooks/_statistics_subscriber.py b/orttraining/orttraining/python/training/utils/hooks/_statistics_subscriber.py index 0dd06eee13..6c8027b2fe 100644 --- a/orttraining/orttraining/python/training/utils/hooks/_statistics_subscriber.py +++ b/orttraining/orttraining/python/training/utils/hooks/_statistics_subscriber.py @@ -7,11 +7,90 @@ import os import shutil import warnings from pathlib import Path -from typing import Union +from typing import List, Optional, Tuple, Union +import onnx import torch -from ._subscriber_base import SubscriberBase +from ._subscriber_base import RuntimeStates, SubscriberBase + + +class _InspectActivation(torch.autograd.Function): + """ + This class is used to run the subscriber's forward and backward functions. + The function will be called by two kinds of callers: + 1. SubscriberManager calls it for each registered nn.Module. + 2. Users who want to inspect the activation tensor at any place of model definition code. + """ + + @staticmethod + def forward( + ctx, + activation_name: str, + module_idx: Optional[int], + run_ctx: RuntimeStates, + input_tensor: torch.Tensor, + module_post_forward, + module_pre_backward, + ): + """ + Args: + ctx: context object to store intermediate information. + activation_name: the name of the activation tensor. + module_idx: unit id of the module (address of the module instance). + run_ctx: runtime context. + For call case 2 - need retrieve the runtime state from GlobalSubscriberManager. + input_tensor: the activation tensor. + + Make sure there is a same number of `tensor` type inputs and outputs. + This is enforced by ORT's PythonOp's schema check. + """ + depth = -1 + if module_idx is not None: + depth = run_ctx.global_states.module_index_to_depth[module_idx] + + input_tensor_copied = None + if input_tensor is None or not isinstance(input_tensor, torch.Tensor): + input_tensor_copied = input_tensor + else: + input_tensor_copied = input_tensor.detach().clone() + + ctx.current_step = run_ctx.global_states.execution_step + ctx.name = activation_name + ctx.id = module_idx + ctx.depth = depth + ctx.module_pre_backward = module_pre_backward + + module_post_forward(input_tensor_copied, depth, activation_name, ctx.current_step) + + return input_tensor.detach() if input_tensor is not None else None + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): + val = None + if grad_output is None or not isinstance(grad_output, torch.Tensor): + val = grad_output + else: + val = grad_output.detach().clone() + + ctx.module_pre_backward(val, ctx.depth, ctx.name, ctx.current_step) + + return ( + None, + None, + None, + grad_output.detach() if grad_output is not None else None, + None, + None, + ) + + @staticmethod + def infer_shape( + node: onnx.NodeProto, + tensor_input_shapes: List[Optional[List[Union[int, str]]]], + tensor_input_dtypes: List[torch.onnx.TensorProtoDataType], + ) -> Tuple[List[Optional[List[Union[int, str]]]], List[torch.onnx.TensorProtoDataType]]: + return tensor_input_shapes, tensor_input_dtypes class StatisticsSubscriber(SubscriberBase): @@ -68,6 +147,15 @@ class StatisticsSubscriber(SubscriberBase): "Set override_output_dir=True for StatisticsSubscriber if this is the intention." ) + def post_forward_tensor_apply_impl( + self, run_rtx: RuntimeStates, module: torch.nn.Module, tensor_index: int, tensor: torch.Tensor + ) -> torch.Tensor: + module_index = run_rtx.global_states.module_to_module_index[module] + name = f"{module.__class__.__name__}_{module_index}_{tensor_index}th_output" + return _InspectActivation.apply( + name, module_index, run_rtx, tensor, self.module_post_forward_impl, self.module_pre_backward_impl + ) + def module_post_forward_impl(self, activation: torch.Tensor, depth: int, name: str, step: int): output_file_path = os.path.join(f"{self._output_dir}", f"step_{step}") return self._summarize_activations(activation, depth, name, output_file_path, True) diff --git a/orttraining/orttraining/python/training/utils/hooks/_subscriber_base.py b/orttraining/orttraining/python/training/utils/hooks/_subscriber_base.py index 572f2fb99b..1b9a6fc91e 100644 --- a/orttraining/orttraining/python/training/utils/hooks/_subscriber_base.py +++ b/orttraining/orttraining/python/training/utils/hooks/_subscriber_base.py @@ -5,34 +5,54 @@ import sys -from typing import Union +from typing import Optional, Tuple import torch +from onnxruntime.training.utils import ORTModelInputOutputType + + +class RuntimeStates: + """ + A data struct holding states for runtime context. + > Global states that are one-time collected during model hook registration. A global execution step is + also initialized to reflect how many steps have been executed, it will get updated after each step + completes its forward path. + """ + + class _GlobalStates: + def __init__(self): + # Used to track current execution step, e.g. how many forward/backward path is called. + self.execution_step = 0 + # Used to store the depth of each module, which indicate the indentation level of the module. + self.module_index_to_depth = {} + # Used to store the unique id of each sequential activation. + self.module_to_module_index = {} + + def __init__(self): + self.global_states = RuntimeStates._GlobalStates() + class SubscriberBase: """ Base class for all module hook subscribers. - Currently, the hook here only means post-forward hook and pre-backward hook. - A module hook subscriber is a class that implements the `module_post_forward_impl` and `module_pre_backward_impl` - function. - > The post_forward hook is called after the activation is generated in the forward path. - > The pre_backward hook is called before the activation gradient is computed. + A module hook subscriber is a class that allow define custom actions to be executed during the nn.Module's hooks. + Two types of APIs can be used to define custom actions: + 1. Module level interfaces: + pre_forward_module_apply - called inside the nn.Module's pre-forward hook. + post_forward_module_apply - called inside the nn.Module's post-forward hook. + post_forward_outmost_module_apply - called inside the nn.Module's post-forward hook, but only for the outmost module. + 2. Tensor level interfaces: + pre_forward_tensor_apply - called inside the nn.Module's pre-forward hook, for each input tensor. + post_forward_tensor_apply - called inside the nn.Module's post-forward hook, for each output tensor. - The post_forward path: - Module_A generates activation tensor_a --> Post forward hook (calling subscribers' forward one by one) --> - Module_B generates activation tensor_b --> ... - - The pre_backward path: - Module_B backward run, tensor_b's gradient is computed as tensor_b_grad --> - Pre-backward hook (calling subscribers' backward one by one) --> - Module_A backward run, tensor_a's gradient is computed as tensor_a_grad - - Be noted: the "Pre"/"Post" is described from the perspective of Module_A. + For ORT runs, tensor's flows are important, that's the reason we have tensor input as function input, + and tensor output as function output for all the APIs. + With this, the overall flow can be traced as a data flow graph (DAG). """ - def __init__(self, start_step: Union[None, int], end_step: Union[None, int]): + def __init__(self, start_step: Optional[int], end_step: Optional[int]): """ Steps in [start_step, end_step) will run the subscriber's actions, and other steps will skip. If start_step is None, 0 is given; if end_step is None, sys.maxsize is given. @@ -40,34 +60,152 @@ class SubscriberBase: self._start_step: int = start_step if start_step is not None else 0 self._end_step: int = end_step if end_step is not None else sys.maxsize - def module_post_forward(self, activation: torch.Tensor, depth: int, name: str, step: int): - """ - This function will be run after the torch Module forward is completed. + def pre_forward_module_apply( + self, + run_rtx: RuntimeStates, + module: torch.nn.Module, + args: ORTModelInputOutputType, + kwargs: ORTModelInputOutputType, + ) -> Tuple[ORTModelInputOutputType, ORTModelInputOutputType]: + """This function is called inside the nn.Module's pre-forward hook. Args: - activation: Tensor to be inspected. - depth: The indent level of the torch Module generating `activation`. - name: The unique name for the `activation`. - step: Current execution step. - """ - if self._start_step <= step < self._end_step: - self.module_post_forward_impl(activation, depth, name, step) + run_rtx (RuntimeStates): The runtime states of SubscriberManager. + module (torch.nn.Module): The module that is being executed. + args (ORTModelInputOutputType): The positional arguments that are passed to the module's pre-forward hook. + kwargs (ORTModelInputOutputType): The keyword arguments that are passed to the module's pre-forward hook. + + Returns: + Tuple[ORTModelInputOutputType, ORTModelInputOutputType]: Updated args and kwargs. - def module_pre_backward(self, activation: torch.Tensor, depth: int, name: str, step: int): """ - This function will be run before the torch Module backward run. + if self._need_skip_step(run_rtx.global_states.execution_step): + return args, kwargs + + updated_args, updated_kwargs = self.pre_forward_module_apply_impl(run_rtx, module, args, kwargs) + return updated_args, updated_kwargs + + def pre_forward_module_apply_impl( + self, + run_rtx: RuntimeStates, + module: torch.nn.Module, + args: ORTModelInputOutputType, + kwargs: ORTModelInputOutputType, + ) -> Tuple[ORTModelInputOutputType, ORTModelInputOutputType]: + return args, kwargs + + def pre_forward_tensor_apply( + self, run_rtx: RuntimeStates, module: torch.nn.Module, tensor_index: int, tensor: torch.Tensor + ) -> torch.Tensor: + """This function is called inside the nn.Module's pre-forward hook. Args: - activation: Tensor to be inspected. - depth: The indent level of the torch Module generating `activation`. - name: The unique name for the `activation`. - step: Current execution step. + run_rtx (RuntimeStates): The runtime states of SubscriberManager. + module (torch.nn.Module): The module that is being executed. + tensor_index (int): The index of the tensor in the input tensor list. + tensor (torch.Tensor): The tensor is one of module's forward inputs. """ - if self._start_step <= step < self._end_step: - self.module_pre_backward_impl(activation, depth, name, step) + if self._need_skip_step(run_rtx.global_states.execution_step): + return tensor - def module_post_forward_impl(self, activation: torch.Tensor, depth: int, name: str, step: int): - raise NotImplementedError() + return self.pre_forward_tensor_apply_impl(run_rtx, module, tensor_index, tensor) - def module_pre_backward_impl(self, activation: torch.Tensor, depth: int, name: str, step: int): - raise NotImplementedError() + def pre_forward_tensor_apply_impl( + self, run_rtx: RuntimeStates, module: torch.nn.Module, tensor_index: int, tensor: torch.Tensor + ) -> torch.Tensor: + return tensor + + def post_forward_module_apply( + self, + run_rtx: RuntimeStates, + module: torch.nn.Module, + args: ORTModelInputOutputType, + outputs: ORTModelInputOutputType, + ) -> Tuple[ORTModelInputOutputType, ORTModelInputOutputType]: + """This function is called inside the nn.Module's post-forward hook. + + Args: + run_rtx (RuntimeStates): The runtime states of SubscriberManager. + module (torch.nn.Module): The module that is being executed. + args (ORTModelInputOutputType): The inputs arguments that are passed to the module's post-forward + hook as input. + outputs (ORTModelInputOutputType): The outputs arguments that are passed to the module's post-forward + hook as input. + + Returns: + Tuple[ORTModelInputOutputType, ORTModelInputOutputType]: Updated inputs and outputs. + """ + if self._need_skip_step(run_rtx.global_states.execution_step): + return args, outputs + + return self.post_forward_module_apply_impl(run_rtx, module, args, outputs) + + def post_forward_module_apply_impl( + self, + run_rtx: RuntimeStates, + module: torch.nn.Module, + args: ORTModelInputOutputType, + outputs: ORTModelInputOutputType, + ) -> Tuple[ORTModelInputOutputType, ORTModelInputOutputType]: + return args, outputs + + def post_forward_tensor_apply( + self, run_rtx: RuntimeStates, module: torch.nn.Module, tensor_index: int, tensor: torch.Tensor + ) -> torch.Tensor: + """This function is called inside the nn.Module's post-forward hook. + + Args: + run_rtx (RuntimeStates): The runtime states of SubscriberManager. + module (torch.nn.Module): The module that is being executed. + tensor_index (int): The index of the tensor in the output tensor list. + tensor (torch.Tensor): The tensor is one of module's forward outputs. + + Returns: + torch.Tensor: Updated tensor. + """ + if self._need_skip_step(run_rtx.global_states.execution_step): + return tensor + + return self.post_forward_tensor_apply_impl(run_rtx, module, tensor_index, tensor) + + def post_forward_tensor_apply_impl( + self, run_rtx: RuntimeStates, module: torch.nn.Module, tensor_index: int, tensor: torch.Tensor + ) -> torch.Tensor: + return tensor + + def post_forward_outmost_module_apply( + self, + run_rtx: RuntimeStates, + module: torch.nn.Module, + args: ORTModelInputOutputType, + outputs: ORTModelInputOutputType, + ) -> Tuple[ORTModelInputOutputType, ORTModelInputOutputType]: + """This function is called inside the outmost nn.Module's post-forward hook. + + Args: + run_rtx (RuntimeStates): The runtime states of SubscriberManager. + module (torch.nn.Module): The module that is being executed. + args (ORTModelInputOutputType): The inputs arguments that are passed to the module's post-forward + hook as input. + outputs (ORTModelInputOutputType): The outputs arguments that are passed to the module's post-forward + hook as input. + + Returns: + Tuple[ORTModelInputOutputType, ORTModelInputOutputType]: Updated inputs and outputs. + """ + if self._need_skip_step(run_rtx.global_states.execution_step): + return args, outputs + + return self.post_forward_outmost_module_apply_impl(run_rtx, module, args, outputs) + + def post_forward_outmost_module_apply_impl( + self, + run_rtx: RuntimeStates, + module: torch.nn.Module, + args: ORTModelInputOutputType, + outputs: ORTModelInputOutputType, + ) -> Tuple[ORTModelInputOutputType, ORTModelInputOutputType]: + return args, outputs + + def _need_skip_step(self, current_step: int) -> bool: + return current_step < self._start_step or current_step >= self._end_step diff --git a/orttraining/orttraining/python/training/utils/hooks/_subscriber_manager.py b/orttraining/orttraining/python/training/utils/hooks/_subscriber_manager.py index 72208b7228..db38f58d8f 100644 --- a/orttraining/orttraining/python/training/utils/hooks/_subscriber_manager.py +++ b/orttraining/orttraining/python/training/utils/hooks/_subscriber_manager.py @@ -3,114 +3,31 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- -from collections import abc -from typing import Callable, List, Optional, Tuple, Union + +import inspect +from contextlib import contextmanager +from typing import List, Optional, Set, Tuple, Union import onnx import torch -from onnxruntime.training.ortmodule import ORTModule +from onnxruntime.training.utils import extract_data_and_schema, unflatten_data_using_schema -from ._subscriber_base import SubscriberBase +from ._subscriber_base import RuntimeStates, SubscriberBase + +ORT_NO_INCREASE_GLOBAL_STEP = [False] -class _RuntimeStates: +@contextmanager +def no_increase_global_step(): + """During ONNX model export phase, forward run is triggered, but we don't want to increase the global step, then + Then the first iteration run will still start with 0, aligned with PyTorch's first iteration run. """ - A data struct holding states for runtime context. Tho kinds of states are included: - > Global states that are one-time collected during model hook registration. A global execution step is - also initialized to reflect how many steps have been executed, it will get updated after each step - completes its forward path. - > Intra-execution step states, initialized and cleaned up intended only for current execution step. - Usually, it carries intermediate information during the model execution. - """ - - class _GlobalStates: - def __init__(self): - # Used to track current execution step, e.g. how many forward/backward path is called. - self.execution_step = 0 - # Used to store the depth of each module, which indicate the indentation level of the module. - self.module_index_to_depth = {} - # Used to store the unique id of each sequential activation. - self.module_to_module_index = {} - - self.subscribers = set() - - class _ExecutionStepStates: - def __init__(self): - # Used to store the activation tensor names, if already handled, then skipped. - # Need to clear after each step. - self.observed_activation_names = {} - - def __init__(self): - self.global_states = _RuntimeStates._GlobalStates() - self.reset_step_states() - - def reset_step_states(self): - self.execution_step_states = _RuntimeStates._ExecutionStepStates() - - -class _InspectActivation(torch.autograd.Function): - """ - This class is used to run the subscriber's forward and backward functions. - The function will be called by two kinds of callers: - 1. SubscriberManager calls it for each registered nn.Module. - 2. Users who want to inspect the activation tensor at any place of model definition code. - """ - - @staticmethod - def forward( - ctx, activation_name: str, module_idx: Optional[int], run_ctx: _RuntimeStates, input_tensor: torch.Tensor - ): - """ - Args: - ctx: context object to store intermediate information. - activation_name: the name of the activation tensor. - module_idx: - For call case 1 - the unique id of the module that the activation belongs to, it is detected by the - SubscriberManager automatically. - For call case 2 - e.g, _InspectActivation is called by users (NOT by SubscriberManager), module_idx can - be None. - run_ctx: runtime context. - For call case 2 - need retrieve the runtime state from GlobalSubscriberManager. - input_tensor: the activation tensor. - - Make sure there is a same number of `tensor` type inputs and outputs. - This is enforced by ORT's PythonOp's schema check. - """ - depth = -1 - if module_idx is not None: - depth = run_ctx.global_states.module_index_to_depth[module_idx] - - input_tensor_copied = None - if input_tensor is None or not isinstance(input_tensor, torch.Tensor): - input_tensor_copied = input_tensor - else: - input_tensor_copied = input_tensor.detach().clone() - - ctx.current_step = run_ctx.global_states.execution_step - ctx.name = activation_name - ctx.id = module_idx - ctx.depth = depth - ctx.subscribers = run_ctx.global_states.subscribers - - # Run subscribers sequentially. - for subscriber in run_ctx.global_states.subscribers: - subscriber.module_post_forward(input_tensor_copied, depth, activation_name, ctx.current_step) - - return input_tensor.detach() if input_tensor is not None else None - - @staticmethod - def backward(ctx, grad_output: torch.Tensor): - val = None - if grad_output is None or not isinstance(grad_output, torch.Tensor): - val = grad_output - else: - val = grad_output.detach().clone() - - for subscriber in ctx.subscribers: - subscriber.module_pre_backward(val, ctx.depth, ctx.name, ctx.current_step) - - return None, None, None, grad_output.detach() if grad_output is not None else None + try: + ORT_NO_INCREASE_GLOBAL_STEP[0] = True + yield + finally: + ORT_NO_INCREASE_GLOBAL_STEP[0] = False @staticmethod def infer_shape( @@ -122,8 +39,7 @@ class _InspectActivation(torch.autograd.Function): class _IncrementStep(torch.autograd.Function): - """ - This class is used to manage the global execution step, e.g. + """This class is used to manage the global execution step, e.g. global step increment by one, once a full forward path is completed and the state clear. This autograd Function is registered as a post-forward hook to the root module. So once the root @@ -132,45 +48,24 @@ class _IncrementStep(torch.autograd.Function): """ @staticmethod - def forward(ctx, run_ctx: _RuntimeStates, input_tensor: torch.Tensor): - """ - Make sure there is a same number of `tensor` inputs and outputs. + def forward(ctx, run_ctx: RuntimeStates, *input_tensor_list: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]: + """Make sure there is the same number of `tensor` inputs and outputs. This is enforced by ORT's PythonOp's schema check. """ ctx.current_step = run_ctx.global_states.execution_step ctx.run_ctx = run_ctx - # We cannot do the step incremental here. Imagine the outside-most module has multiple outputs, - # we need to increase the step only at the very last output handling. - # We avoid the complexity to probe the last output handling, and instead, we assume once - # the very first backward of the outside-most module is called, then the forward pass MUST be completed. + if ctx.current_step >= 0: + print(f"{'='*6} Completed forward pass for STEP {ctx.current_step} {'='*6}") - # Be noted: it is not safe to register _IncrementStep only for one of the outputs of the outside-most module, - # because we are not sure which output branch is executed earlier, for example. - # OuterMostModuleOutputs - # / \ - # OuterMostModuleOutputs_0_0th_output OuterMostModuleOutputs_0_1th_output - # | | - # PythonOp(_InspectActivation) PythonOp(_InspectActivation) - # | | - # PythonOp(_IncrementStep) graph output - # | - # graph output - # The PythonOp(_InspectActivation) (who relies on global step) after 1th output is possible - # to run before or after PythonOp(_IncrementStep), so increasing the step is not safe. + if ORT_NO_INCREASE_GLOBAL_STEP[0] is False: + ctx.run_ctx.global_states.execution_step += 1 - return input_tensor.detach() if isinstance(input_tensor, torch.Tensor) else input_tensor + return tuple(t.detach().requires_grad_(t.requires_grad) for t in input_tensor_list) @staticmethod - def backward(ctx, grad_output: torch.Tensor): - # In case there are multiple backward calls for multiple outputs of the outside-most module. - if ctx.current_step == ctx.run_ctx.global_states.execution_step: - if ctx.current_step >= 0: - print(f"{'='*6} Completed forward pass for STEP {ctx.current_step} {'='*6}") - ctx.run_ctx.global_states.execution_step += 1 - ctx.run_ctx.reset_step_states() - - return None, grad_output.detach() if isinstance(grad_output, torch.Tensor) else grad_output + def backward(ctx, *grad_output: Tuple[Optional[torch.Tensor], ...]) -> Tuple[Optional[torch.Tensor], ...]: + return (None, *tuple(g for g in grad_output)) @staticmethod def infer_shape( @@ -182,23 +77,25 @@ class _IncrementStep(torch.autograd.Function): class SubscriberManager: - """ - This class is used to manage all the subscribers and register the post-forward hook to the root module. - `subscribe()` is used to register a list of subscribers. + """This class is used to manage all the subscribers and register subscribers' custom actions as PyTorch hooks + to the nn.Modules. - Currently, the hook handled here is post forward hook for nn.Module. The hook is registered for all nn.Modules - recursively. Each hook inserts a PythonOp for every tensor output generated by the corresponding module. - Each subscriber implementation is called in the PythonOp's forward function, and backward function. + For the module-level/tensor_level custom actions defined by subscribers, they are registered as corresponding + PyTorch hooks in the sequence of the subscribers' registration order. There is one special handling for global step increment and state clear. A post-forward hook is registered for the outside-most module, which is the root module. In that hook, _IncrementStep is called, which will - increase the step by 1 once the very first time its backward is called (check _IncrementStep for details). + increase the step by 1 once the post forward hook is called if running without no_increase_global_step(). + `no_increase_global_step` is used to skip the step increment during ONNX model export. """ def __init__(self): - self._run_ctx: _RuntimeStates = _RuntimeStates() + self._run_ctx = RuntimeStates() + self._subscribers: Set[SubscriberBase] = set() + self._pre_forward_hooks = [] + self._post_forward_hooks = [] - def subscribe(self, module: Union[torch.nn.Module, ORTModule], subscribers: List[SubscriberBase]): + def subscribe(self, module: torch.nn.Module, subscribers: List[SubscriberBase]): """ The API is called externally to register hooks that are implicitly defined by subscribers. Each time all global states will be cleaned up once called. @@ -207,52 +104,110 @@ class SubscriberManager: raise ValueError("module must be a torch.nn.Module instance") self._reset_all_states() + self._subscribers.clear() - if isinstance(module, ORTModule): - module = module.module + try: + # Put the import here to avoid the module level dependency on onnxruntime.training.ortmodule + from onnxruntime.training.ortmodule import ORTModule + + if isinstance(module, ORTModule): + module = module.module + except ImportError: + pass for subscriber in subscribers: if not isinstance(subscriber, SubscriberBase): raise ValueError("subscriber must be a SubscriberBase instance") - self._run_ctx.global_states.subscribers.add(subscriber) + self._subscribers.add(subscriber) self._initialize(module) - def get_run_context(self) -> _RuntimeStates: + def get_subscriber(self, subscriber_type: type) -> SubscriberBase: + for subscriber in self._subscribers: + if isinstance(subscriber, subscriber_type): + return subscriber + raise RuntimeError(f"Subscriber {subscriber_type} is not registered.") + + def get_run_context(self) -> RuntimeStates: return self._run_ctx def _reset_all_states(self): - self._run_ctx = _RuntimeStates() + self._pre_forward_hooks.clear() + self._post_forward_hooks.clear() + self._run_ctx = RuntimeStates() def _initialize(self, module: torch.nn.Module): - """ - Register hooks for the specified module. - """ - if len(self._run_ctx.global_states.subscribers) == 0: + """Register hooks for the specified module.""" + if len(self._subscribers) == 0: raise RuntimeError("No subscribers are registered.") + def _pre_forward_outmost_module_hook(module, module_inputs): + # This check is to support the case where module is first registered in the subscriber manager, + # then the module and hook are copied, when new module instance runs to the hook, the global states + # are not reset, so the logic depends on the global states will fail. So in the outer-most pre-forward hook + # we reset the global states. + + # Be noted, the first run anyway will run in PyTorch. + if module not in self._run_ctx.global_states.module_to_module_index: + import warnings + + warnings.warn( + "Initialize global states for the first time, this should only happen once for each outmost module." + ) + self._initialize_one_time_global_states(module) + return module_inputs + + module.register_forward_pre_hook(_pre_forward_outmost_module_hook) + next_module_index = [0] - # Register post forward hook for every module, inside the hook, we loop every tensor output of the module, - # and wrap it with an autograd Function called _InspectActivation (which takes in a tensor and returns the same - # tensor). In this way, we keep ORT and PyTorch run have the same boundary to check activation equality. self._register_hooks_recursively(module, 1, next_module_index) # Register post forward hook for the outside-most module, then we increase the dump step. - # Be noted, if backward is not triggered, the global dump step remains the original number, - # which means the subsequent run will override the previous dump files. This indeed happens to imagine ORTModule - # firstly export graph (run the forward only), after the gradient graph is built, another forward+backward is - # triggered, override the previous dump files. - def _post_forward_outmost_module_hook(module, _, module_outputs): - def _apply_to_tensors_func(_, outputs): - return _IncrementStep.apply(self._run_ctx, outputs) + def _post_forward_outmost_module_hook(module, module_inputs, module_outputs): + # Call post outmost module forward custom actions for subscribers + for sub in self._subscribers: + module_inputs, module_outputs = sub.post_forward_outmost_module_apply( + self._run_ctx, module, module_inputs, module_outputs + ) - return self._apply_function_to_tensors(module, module_outputs, _apply_to_tensors_func) + flatten_output_tensor_list, output_schema = extract_data_and_schema(module_outputs) + output_tensors = _IncrementStep.apply(self._run_ctx, *flatten_output_tensor_list) + restored_outputs = unflatten_data_using_schema(output_tensors, output_schema) + + return restored_outputs module.register_forward_hook(_post_forward_outmost_module_hook) + def _initialize_one_time_global_states(self, module: torch.nn.Module): + def _reset_recursively(module: torch.nn.Module, depth: int, next_module_index: List[int]): + """ + Called to register hooks for every `torch.nn.Module`. Due to `Module` can contain child `Module`s, + this function is called recursively by passing in `next_module_index` - a list of int to maintain a + global incremental unique module id. + + Args: + module: torch.nn.Module to register hook. + depth: the indent of the module compared with the outside-most Module. + next_module_index: list of int, carrying a global unique module index that can be used next. + """ + module_index = next_module_index[0] + module.id = module_index # STAGE3WARN: needed by DeepSpeed + self._run_ctx.global_states.module_index_to_depth[module_index] = depth + self._run_ctx.global_states.module_to_module_index[module] = module_index + + for child in module.children(): + if ( + isinstance(child, torch.nn.Module) + and child not in self._run_ctx.global_states.module_to_module_index + ): + next_module_index[0] += 1 + _reset_recursively(child, depth + 1, next_module_index) + + next_module_index = [0] + _reset_recursively(module, 1, next_module_index) + def _register_hooks_recursively(self, module: torch.nn.Module, depth: int, next_module_index: List[int]): - """ - Called to register hooks for every `torch.nn.Module`. Due to `Module` can contain child `Module`s, + """Register hooks for every `torch.nn.Module`. Due to `Module` can contain child `Module`s, this function is called recursively by passing in `next_module_index` - a list of int to maintain a global incremental unique module id. @@ -262,6 +217,7 @@ class SubscriberManager: next_module_index: list of int, carrying a global unique module index that can be used next. """ module_index = next_module_index[0] + module.id = module_index # STAGE3WARN: needed by DeepSpeed self._run_ctx.global_states.module_index_to_depth[module_index] = depth self._run_ctx.global_states.module_to_module_index[module] = module_index @@ -270,69 +226,54 @@ class SubscriberManager: next_module_index[0] += 1 self._register_hooks_recursively(child, depth + 1, next_module_index) - def _post_forward_module_hook(module, _, module_outputs): - if module in self._run_ctx.global_states.module_to_module_index and isinstance(module, torch.nn.Module): - module_index = self._run_ctx.global_states.module_to_module_index[module] + def _pre_forward_module_with_kwargs_hook(module, module_inputs, kwargs): + # Module level hook + for sub in self._subscribers: + module_inputs, kwargs = sub.pre_forward_module_apply(self._run_ctx, module, module_inputs, kwargs) - def _apply_to_tensors_func(index, activation_tensor): - name = f"{module.__class__.__name__}_{module_index}_{index}th_output" - if name not in self._run_ctx.execution_step_states.observed_activation_names: - self._run_ctx.execution_step_states.observed_activation_names[name] = True - return _InspectActivation.apply(name, module_index, self._run_ctx, activation_tensor) + # Tensor level hook + flatten_positional_input_tensor_list, input_schema = extract_data_and_schema(module_inputs) + flatten_keyword_input_tensor_list, keyword_input_schema = extract_data_and_schema(kwargs) - return activation_tensor + for sub in self._subscribers: + tensor_list = [] + for tensor_index, tensor in enumerate(flatten_positional_input_tensor_list): + tensor_list.append(sub.pre_forward_tensor_apply(self._run_ctx, module, tensor_index, tensor)) + flatten_positional_input_tensor_list = tensor_list - return self._apply_function_to_tensors(module, module_outputs, _apply_to_tensors_func) - return module_outputs + tensor_list = [] + for tensor_index, tensor in enumerate(flatten_keyword_input_tensor_list): + tensor_list.append(sub.pre_forward_tensor_apply(self._run_ctx, module, tensor_index, tensor)) + flatten_keyword_input_tensor_list = tensor_list - module.register_forward_hook(_post_forward_module_hook) + module_inputs = unflatten_data_using_schema(flatten_positional_input_tensor_list, input_schema) + kwargs = unflatten_data_using_schema(flatten_keyword_input_tensor_list, keyword_input_schema) - def _is_builtin_type(self, obj): - # https://stackoverflow.com/a/17795199 - return obj.__class__.__module__ in ["__builtin__", "builtins"] + return module_inputs, kwargs - def _apply_function_to_tensors(self, module: torch.nn.Module, data, func: Callable): - """ - Apply func to all tensors in the given object. + def _pre_forward_module_hook(module, module_inputs): + return _pre_forward_module_with_kwargs_hook(module, module_inputs, {}) - Args: - module: the module that generates the tensors. - data: the object that contains activation tensors. - func: the function to apply to the tensors. - """ - tensor_output_idx: List[int] = [0] + def _post_forward_module_hook(module, module_inputs, module_outputs): + # Module level hook + for sub in self._subscribers: + _, module_outputs = sub.post_forward_module_apply(self._run_ctx, module, module_inputs, module_outputs) - def _apply_to_tensors_by_flatten( - module: torch.nn.Module, - index_for_tensor_output: List[int], - outputs, - func: Callable, - ): - if isinstance(outputs, abc.Sequence): - touched_outputs = [] - for output in outputs: - touched_output = _apply_to_tensors_by_flatten(module, index_for_tensor_output, output, func) - touched_outputs.append(touched_output) - return outputs.__class__(touched_outputs) + # Tensor level hook + flatten_output_tensor_list, output_schema = extract_data_and_schema(module_outputs) + for sub in self._subscribers: + tensor_list = [] + for tensor_index, tensor in enumerate(flatten_output_tensor_list): + tensor_list.append(sub.post_forward_tensor_apply(self._run_ctx, module, tensor_index, tensor)) + flatten_output_tensor_list = tensor_list - if isinstance(outputs, abc.Mapping): - # apply inplace to avoid recreating dict inherited objects - for key in outputs: - outputs[key] = _apply_to_tensors_by_flatten( - module, - index_for_tensor_output, - outputs[key], - func, - ) - return outputs + return unflatten_data_using_schema(flatten_output_tensor_list, output_schema) - if isinstance(outputs, torch.Tensor): - cur_id = index_for_tensor_output[0] - index_for_tensor_output[0] += 1 - return func(cur_id, outputs) - - if not self._is_builtin_type(outputs): - raise RuntimeError(f"Unknown type {type(outputs)}") - return outputs - - return _apply_to_tensors_by_flatten(module, tensor_output_idx, data, func) + # "with_kwargs" is not available for low versions of PyTorch. + if "with_kwargs" in inspect.signature(module.register_forward_pre_hook).parameters: + self._pre_forward_hooks.append( + module.register_forward_pre_hook(_pre_forward_module_with_kwargs_hook, with_kwargs=True) + ) + else: + self._pre_forward_hooks.append(module.register_forward_pre_hook(_pre_forward_module_hook)) + self._post_forward_hooks.append(module.register_forward_hook(_post_forward_module_hook)) diff --git a/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py b/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py new file mode 100644 index 0000000000..3d42e172ee --- /dev/null +++ b/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py @@ -0,0 +1,476 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +import ctypes +import inspect +import warnings +from collections import OrderedDict +from types import CodeType, FunctionType +from typing import Callable, Dict, List, Optional, Tuple, Union + +import onnx +import torch + +from onnxruntime.training.utils import ( + ORTModelInputOutputType, + extract_data_and_schema, + pytorch_dtype_to_onnx, + unflatten_data_using_schema, +) + +from ._subscriber_base import RuntimeStates, SubscriberBase + + +# Used to monkey patch the original function +# Adapted from https://github.com/microsoft/DeepSpeed/blob/e8318634b4313eaad89842cf4322e1762d34ced3/deepspeed/runtime/zero/parameter_offload.py#L333 +def _setup_zero_stage3_ort_compatible_hooks(self): + self.hierarchy = 0 + + from onnxruntime.training.utils.hooks import SubscriberManager, ZeROOffloadSubscriber + from onnxruntime.training.utils.hooks._zero_offload_subscriber import _zero_offload_one_time_initializer + + # Each DeepSpeed engine has a separate subscriber manager. + self._offload_subscriber_manager = SubscriberManager() + self._offload_subscriber_manager.subscribe( + self.module, [ZeROOffloadSubscriber(self, _zero_offload_one_time_initializer)] + ) + self.forward_hooks.extend(self._offload_subscriber_manager._pre_forward_hooks) + self.forward_hooks.extend(self._offload_subscriber_manager._post_forward_hooks) + + # Add top module to stack trace + global FWD_MODULE_STACK # noqa: PLW0602 + FWD_MODULE_STACK.append(self.module) + + +# Adapted from https://github.com/microsoft/DeepSpeed/blob/e8318634b4313eaad89842cf4322e1762d34ced3/deepspeed/runtime/zero/linear.py#L104 +# In the original logic, if bias is None, after export to ONNX, None becomes a constant, so backward op complains +# output count more than needed. +def _zero3_linear_wrap_ort_compatible(input, weight, bias=None): + from deepspeed.runtime.zero.linear import LinearFunctionForZeroStage3 + + return LinearFunctionForZeroStage3.apply(input, weight, bias) + + +class _ZeROOffloadOneTimeInitializer: + """Store the hook functions from DeepSpeed ZeRO offload. + + Hook functions code collected from DeepSpeed. + """ + + def __init__(self): + self._code_store: OrderedDict[str, CodeType] = {} + + def collect_code(self, function: Callable): + """Collect the function `CodeType`, which is the code object of the function.""" + code_obj = function.__code__ + for c in code_obj.co_consts: + if inspect.iscode(c): + self._code_store[c.co_name] = c + + +_zero_offload_one_time_initializer = None + +try: + # Have to import below explicitly, otherwise it complains about _apply_to_tensors_only not found. + # The hooks reference functions or classes in that file. + from deepspeed.runtime.zero.parameter_offload import * # noqa: F403 + from deepspeed.runtime.zero.parameter_offload import DeepSpeedZeRoOffload, _apply_to_tensors_only # noqa: F401 + from deepspeed.utils import instrument_w_nvtx # noqa: F401 + + # Used to collect the hook functions's code object from DeepSpeed ZeRO offload, this should be initialized only once. + if _zero_offload_one_time_initializer is None: + _zero_offload_one_time_initializer = _ZeROOffloadOneTimeInitializer() + _zero_offload_one_time_initializer.collect_code(DeepSpeedZeRoOffload._register_hooks_recursively) + _zero_offload_one_time_initializer.collect_code(DeepSpeedZeRoOffload.setup_zero_stage3_hooks) + + # This is the function to enable ORT ZeRO offload. + def configure_ort_compatible_zero_stage3(): + """Configure ZeRO stage3 to be ORT compatible. + + This function will overwrite the original DeepSpeed ZeRO stage3 hooks to make it ORT compatible. + """ + + # Only done once no matter how many times this function is called for different modules. + DeepSpeedZeRoOffload.setup_zero_stage3_hooks = _setup_zero_stage3_ort_compatible_hooks + + from deepspeed.runtime.zero.linear import zero3_linear_wrap + + if torch.nn.functional.linear is zero3_linear_wrap: + torch.nn.functional.linear = _zero3_linear_wrap_ort_compatible + +except ImportError as e: + warnings.warn(f"DeepSpeed import error {e}") + + def configure_ort_compatible_zero_stage3(): + raise RuntimeError("DeepSpeed is not installed, cannot configure ORT compatible ZeRO stage3.") + + +def _get_params_for_current_module(module: torch.nn.Module) -> List[torch.nn.parameter.Parameter]: + """Retrieve the parameters for this module. + + Logic adapted from + https://github.com/microsoft/DeepSpeed/blob/9d79cfd1e90cae9306dc1b5837d374b2c9489ac8/deepspeed/runtime/zero/partitioned_param_coordinator.py#L267 + """ + from deepspeed.runtime.zero.partitioned_param_coordinator import iter_params + + # Retrive the parameters that are not available for this module. + partitioned_params = [param for param in iter_params(module)] + + return partitioned_params + + +def _get_all_offloaded_params(module: torch.nn.Module) -> Dict[str, torch.nn.parameter.Parameter]: + """Retrieve all the parameters that are offloaded.""" + from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus + + all_offloaed_params = OrderedDict() + for name, param in module.named_parameters(): + if hasattr(param, "ds_status") and param.ds_status == ZeroParamStatus.NOT_AVAILABLE: + all_offloaed_params[name] = param + + return all_offloaed_params + + +class ORTZeROOffloadPreForwardFunction(torch.autograd.Function): + """This function is a common bridge to call original PyTorch's + pre_forward_function and post_backward_function. + """ + + @staticmethod + def forward( + ctx, + module, + pre_forward_with_kwargs_function, + post_backward_function, + args_schema, + kwargs_schema, + args_tensor_count, + kwargs_tensor_count, + *tensor_list, + ): + """ + Args: + ctx: context object + module: the module to be called + pre_forward_with_kwargs_function: the function to be called before forward (PyTorch's pre_forward_function) + post_backward_function: the function to be called after backward (PyTorch's post_backward_function) + args_schema: the schema of the args, used to reconstruct the args in original form in + PyTorch's pre_forward_function's inputs. + kwargs_schema: the schema of the kwargs, used to reconstruct the kwargs in original form in + PyTorch's pre_forward_function's inputs. + args_tensor_count: the number of tensors in args. + kwargs_tensor_count: the number of tensors in kwargs. + tensor_list: the list of tensors, the first args_tensor_count tensors are args, the next + kwargs_tensor_count tensors are kwargs, the rest are the parameters for offload. + """ + args_tensors = tensor_list[:args_tensor_count] + kwargs_tensors = tensor_list[args_tensor_count : args_tensor_count + kwargs_tensor_count] + + args = unflatten_data_using_schema(args_tensors, args_schema) + kwargs = unflatten_data_using_schema(kwargs_tensors, kwargs_schema) + + # We will re-retrieve the parameter tensors other than use the one passed in input (of size 0 for + # those partitioned params). + # This is required for ORT run because in ORT graph, the tensor of size 0 will always be size 0 + # (this step is not necessary for PyTorch run, because PyTorch will re-use the same tensor + # while .data got updated to full-sized data after pre_forward_with_kwargs_function is called). + partitioned_params = _get_params_for_current_module(module) + ctx.partitioned_params = partitioned_params + + f_ret = pre_forward_with_kwargs_function(module, args, kwargs) + + if f_ret is None: + updated_args, updated_kwargs = args, kwargs + else: + assert isinstance(f_ret, tuple) + updated_args, updated_kwargs = f_ret + + ctx.module = module + ctx.post_backward_function = post_backward_function + + updated_args_tensors, _ = extract_data_and_schema(updated_args) + updated_kwargs_tensors, _ = extract_data_and_schema(updated_kwargs) + + rets = tuple(updated_args_tensors + updated_kwargs_tensors) + rets += tuple([p.detach().requires_grad_(p.requires_grad) for p in partitioned_params]) + + # PyTorch exporter does not support an empty list of tensors, so we have this check. + assert len(rets) != 0 + return rets + + @staticmethod + def backward(ctx, *grads): + updated_grads = grads + if ctx.post_backward_function is not None: + ret = ctx.post_backward_function(ctx.module, grads) + if ret is not None: + updated_grads = ret + + # TODO(pengwa) Update grad for partitioned parameters. + input_count = len(updated_grads) - len(ctx.partitioned_params) + zeros = [torch.zeros(0, dtype=p.dtype, device=p.device) for p in ctx.partitioned_params] + zero_grads = updated_grads[:input_count] + tuple(zeros) + + return (None, None, None, None, None, None, None, *zero_grads) + + @staticmethod + def infer_shape( + node: onnx.NodeProto, + tensor_input_shapes: List[Optional[List[Union[int, str]]]], + tensor_input_dtypes: List[torch.onnx.TensorProtoDataType], + ) -> Tuple[List[Optional[List[Union[int, str]]]], List[torch.onnx.TensorProtoDataType]]: + input_pointer_scalars_attr_name = "input_pointer_scalars" + found = [attr for attr in node.attribute if attr.name == input_pointer_scalars_attr_name] + assert len(found) == 1 + input_pointer_scalars = found[0].ints + + # Restore the nn.Module from the pointer. + module = ctypes.cast(input_pointer_scalars[0], ctypes.py_object).value + + partitioned_params = _get_params_for_current_module(module) + tensor_output_shapes = tensor_input_shapes + tensor_output_dtypes = tensor_input_dtypes + start_offset = len(tensor_input_shapes) - len(partitioned_params) + for index, param in enumerate(partitioned_params): + tensor_output_shapes[start_offset + index] = list(param.ds_shape) + tensor_output_dtypes[start_offset + index] = pytorch_dtype_to_onnx(param.dtype) + assert len(tensor_output_shapes) == len(tensor_input_shapes) + assert len(tensor_output_dtypes) == len(tensor_input_dtypes) + + return tensor_output_shapes, tensor_output_dtypes + + +class ORTZeROOffloadPostForwardFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + module, + post_forward_function, + pre_backward_function, + output_schema, + *output_tensors, + ): + """ + Args: + ctx: context object + module: the module to be called + post_forward_function: the function to be called after forward (PyTorch's post_forward_function) + pre_backward_function: the function to be called before backward (PyTorch's pre_backward_function) + output_schema: the schema of the output, used to reconstruct the output in original form in + PyTorch's post_forward_function's inputs. + output_tensors: the list of tensors. + + """ + outputs = unflatten_data_using_schema(output_tensors, output_schema) + + # STAGE3WARN: _post_forward_module_hook's second argument `input is not used, so we just pass a None here. + updated_outputs = post_forward_function(module, None, outputs) + + if updated_outputs is None: + updated_output_tensors = output_tensors + else: + updated_output_tensors, _ = extract_data_and_schema(updated_outputs) + + ctx.module = module + ctx.pre_backward_function = pre_backward_function + rets = [o.detach().requires_grad_(o.requires_grad) for o in updated_output_tensors] + return tuple(rets) + + @staticmethod + def backward(ctx, *grads): + updated_args = grads + if ctx.pre_backward_function is not None: + ret = ctx.pre_backward_function(ctx.module, grads) + if ret is not None: + updated_args = ret + return (None, None, None, None, *updated_args) + + @staticmethod + def infer_shape( + node: onnx.NodeProto, + tensor_input_shapes: List[Optional[List[Union[int, str]]]], + tensor_input_dtypes: List[torch.onnx.TensorProtoDataType], + ) -> Tuple[List[Optional[List[Union[int, str]]]], List[torch.onnx.TensorProtoDataType]]: + return tensor_input_shapes, tensor_input_dtypes + + +class _ZeROOffloadFunctions: + def __init__(self, one_time_init: _ZeROOffloadOneTimeInitializer, offloader) -> None: + self._function_store: OrderedDict[str, FunctionType] = {} + self._one_time_init = one_time_init + for name, code in self._one_time_init._code_store.items(): + cell = self._create_closure_for_ds_hook_function(offloader) + self._function_store[name] = FunctionType(code, globals(), code.co_name, None, (cell,)) + + def get(self, name: str) -> FunctionType: + return self._function_store[name] + + def _create_closure_for_ds_hook_function(self, offloader): + # https://stackoverflow.com/questions/17395338/how-to-access-a-function-inside-a-function + def make_closure_cell(_self): + def nested(): + return _self + + return nested.__closure__[0] + + cell = make_closure_cell(offloader) + return cell + + +class ZeROOffloadSubscriber(SubscriberBase): + """This subscriber is used to enable ZeRO Offload feature in a way compatible with ORTModule.""" + + def __init__(self, offloader, one_time_init: _ZeROOffloadOneTimeInitializer, enable_debug_info: bool = False): + super().__init__(None, None) + self._offloader = offloader + self._functions = _ZeROOffloadFunctions(one_time_init, self._offloader) + self._enable_debug_info = enable_debug_info + + def pre_forward_module_apply_impl( + self, + run_rtx: RuntimeStates, + module: torch.nn.Module, + args: ORTModelInputOutputType, + kwargs: ORTModelInputOutputType, + ) -> Tuple[ORTModelInputOutputType, ORTModelInputOutputType]: + """This function is a dispatcher to call DeepSpeed stage3 pre forward hooks in sequence. + + All hook functions can be retrieved from the function store, due to exporter only supports a list of tensors as + input and output for torch.autograd.Function, so we do flatten and unflatten here. + + """ + + args_tensors, args_schema = extract_data_and_schema(args) + kwargs_tensors, kwargs_schema = extract_data_and_schema(kwargs) + + partitioned_params = _get_params_for_current_module(module) + + _pre_forward_module_hook = self._functions.get("_pre_forward_module_hook") + + args_tensor_count = len(args_tensors) + kwargs_tensor_count = len(kwargs_tensors) + + def _wrap_pre_forward_module_hook(module, args, kwargs): + rets = _pre_forward_module_hook(module, args) + updated_args, updated_kwargs = args, kwargs + if rets is not None: + updated_args = rets + + # STAGE3WARN: Moved from _post_backward_module_hook to make sure ORT run will trigger every iteration. + module.ds_grads_remaining = 0 + return updated_args, updated_kwargs + + all_tensors = args_tensors + kwargs_tensors + partitioned_params + + self._check_all_tensor(all_tensors, module, "pre_forward_module_apply_impl input check") + + rets = ORTZeROOffloadPreForwardFunction.apply( + module, + _wrap_pre_forward_module_hook, + None, + args_schema, + kwargs_schema, + args_tensor_count, + kwargs_tensor_count, + *all_tensors, + ) + + self._check_all_tensor(rets, module, "pre_forward_module_apply_impl output check") + + updated_args_tensors = rets[:args_tensor_count] + updated_kwargs_tensors = rets[args_tensor_count : args_tensor_count + kwargs_tensor_count] + + updated_args = unflatten_data_using_schema(updated_args_tensors, args_schema) + updated_kwargs = unflatten_data_using_schema(updated_kwargs_tensors, kwargs_schema) + + _post_backward_module_hook = self._functions.get("_post_backward_module_hook") + # STAGE3WARN: Other part of _post_backward_module_hook can be traced correctly so we don't need to + # wrap with PythonOp. + updated_args = _post_backward_module_hook(module, updated_args) + + return updated_args, updated_kwargs + + def post_forward_module_apply_impl( + self, + run_rtx: RuntimeStates, + module: torch.nn.Module, + args: ORTModelInputOutputType, + outputs: ORTModelInputOutputType, + ) -> Tuple[ORTModelInputOutputType, ORTModelInputOutputType]: + """This function is a dispatcher to call DeepSpeed stage3 post forward hooks in sequence. + + All hook functions can be retrieved from function store, due to exporter only supports a list of tensors as + input and output for torch.autograd.Function, so we do flatten and unflatten here. + + """ + + outputs_tensors, outputs_schema = extract_data_and_schema(outputs) + + _post_forward_module_hook = self._functions.get("_post_forward_module_hook") + + def _wrap_post_forward_module_hook(module, input, outputs): + # STAGE3WARN: _post_forward_module_hook applied this for each tensor output, so we do a simple wrap here. + from deepspeed.runtime.zero.partition_parameters import is_zero_param + + updated_outputs = _post_forward_module_hook(module, input, outputs) + if updated_outputs: + for updated_output in updated_outputs: + # restore zero param attributes if those get stripped by `backward_function` + if not is_zero_param(updated_output) and is_zero_param(outputs): + updated_output.ds_param_alias = outputs + return updated_outputs + else: + return outputs + + self._check_all_tensor(outputs_tensors, module, "post_forward_module_apply_impl input check") + + updated_outputs_tensors = ORTZeROOffloadPostForwardFunction.apply( + module, _wrap_post_forward_module_hook, None, outputs_schema, *outputs_tensors + ) + + self._check_all_tensor(updated_outputs_tensors, module, "post_forward_module_apply_impl output check") + + assert len(updated_outputs_tensors) == len(outputs_tensors) + + # WARN: we assume updated_output_tensors can REUSE the outputs_schema. + updated_outputs = unflatten_data_using_schema(updated_outputs_tensors, outputs_schema) + + _pre_backward_module_hook = self._functions.get("_pre_backward_module_hook") + # STAGE3WARN: _pre_backward_module_hook's second argument `input is not used, so we just pass a None here. + # STAGE3WARN: part of the original _pre_backward_module_hook can be traced correctly so we moved them into + # _wrap_post_forward_module_hook above. + updated_outputs = _pre_backward_module_hook(module, None, updated_outputs) + + return args, updated_outputs + + def post_forward_outmost_module_apply_impl( + self, + run_rtx: RuntimeStates, + module: torch.nn.Module, + args: ORTModelInputOutputType, + outputs: ORTModelInputOutputType, + ) -> Tuple[ORTModelInputOutputType, ORTModelInputOutputType]: + outputs_tensors, outputs_schema = extract_data_and_schema(outputs) + + _end_of_forward_hook = self._functions.get("_end_of_forward_hook") + self._check_all_tensor(outputs_tensors, module, "post_forward_outmost_module_apply_impl input check") + + updated_outputs_tensors = ORTZeROOffloadPostForwardFunction.apply( + module, _end_of_forward_hook, None, outputs_schema, *outputs_tensors + ) + + self._check_all_tensor(updated_outputs_tensors, module, "post_forward_outmost_module_apply_impl output check") + + assert len(updated_outputs_tensors) == len(outputs_tensors) + updated_outputs = unflatten_data_using_schema(updated_outputs_tensors, outputs_schema) + return args, updated_outputs + + def _check_all_tensor(self, tensor_list: Tuple[torch.Tensor], module: torch.nn.Module, name: str): + if not self._enable_debug_info: + return + + for t in tensor_list: + if not isinstance(t, torch.Tensor): + raise RuntimeError(f"{name} fail: {module.__class__.__name__}, input type: {type(t)}") diff --git a/orttraining/orttraining/python/training/utils/torch_type_map.py b/orttraining/orttraining/python/training/utils/torch_type_map.py new file mode 100644 index 0000000000..699747723f --- /dev/null +++ b/orttraining/orttraining/python/training/utils/torch_type_map.py @@ -0,0 +1,47 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + + +from typing import Union + +import torch + +# Mapping from pytorch scalar type to onnx scalar type. +_CAST_PYTORCH_TO_ONNX = { + "Byte": [torch.onnx.TensorProtoDataType.UINT8, torch.uint8], + "Char": [torch.onnx.TensorProtoDataType.INT8, torch.int8], + "Double": [torch.onnx.TensorProtoDataType.DOUBLE, torch.double], + "Float": [torch.onnx.TensorProtoDataType.FLOAT, torch.float], + "Half": [torch.onnx.TensorProtoDataType.FLOAT16, torch.half], + "Int": [torch.onnx.TensorProtoDataType.INT32, torch.int], + "Long": [torch.onnx.TensorProtoDataType.INT64, torch.int64], + "Short": [torch.onnx.TensorProtoDataType.INT16, torch.short], + "Bool": [torch.onnx.TensorProtoDataType.BOOL, torch.bool], + "ComplexFloat": [torch.onnx.TensorProtoDataType.COMPLEX64, torch.complex64], + "ComplexDouble": [torch.onnx.TensorProtoDataType.COMPLEX128, torch.complex128], + "BFloat16": [torch.onnx.TensorProtoDataType.BFLOAT16, torch.bfloat16], + # Not yet defined in torch. + # "Float8E4M3FN": torch.onnx.TensorProtoDataType.FLOAT8E4M3FN, + # "Float8E4M3FNUZ": torch.onnx.TensorProtoDataType.FLOAT8E4M3FNUZ, + # "Float8E5M2": torch.onnx.TensorProtoDataType.FLOAT8E5M2, + # "Float8E5M2FNUZ": torch.onnx.TensorProtoDataType.FLOAT8E5M2FNUZ, + "Undefined": [torch.onnx.TensorProtoDataType.UNDEFINED, None], +} + + +_DTYPE_TO_ONNX = {torch_dtype: onnx_dtype for k, (onnx_dtype, torch_dtype) in _CAST_PYTORCH_TO_ONNX.items()} + + +def pytorch_dtype_to_onnx(dtype_or_scalar_type: Union[torch.dtype, str]) -> torch.onnx.TensorProtoDataType: + """Converts a pytorch dtype or scalar type string to an onnx dtype.""" + dtype = dtype_or_scalar_type + if isinstance(dtype, str): + if dtype not in _CAST_PYTORCH_TO_ONNX: + raise RuntimeError(f"Unsupported dtype {dtype}") + return _CAST_PYTORCH_TO_ONNX[dtype][0] + + if dtype not in _DTYPE_TO_ONNX: + raise RuntimeError(f"Unsupported dtype {dtype}") + return _DTYPE_TO_ONNX[dtype] diff --git a/orttraining/orttraining/test/python/orttraining_test_hooks.py b/orttraining/orttraining/test/python/orttraining_test_hooks.py index 4fb416e640..a58b3919c5 100644 --- a/orttraining/orttraining/test/python/orttraining_test_hooks.py +++ b/orttraining/orttraining/test/python/orttraining_test_hooks.py @@ -8,7 +8,7 @@ import pytest import torch from onnxruntime.training.ortmodule import ORTModule -from onnxruntime.training.utils.hooks import GlobalSubscriberManager, StatisticsSubscriber, _InspectActivation +from onnxruntime.training.utils.hooks import GlobalSubscriberManager, StatisticsSubscriber, inspect_activation class NeuralNetSingleOutput(torch.nn.Module): @@ -147,9 +147,9 @@ class NeuralNetUserAnnotateIntermediateTensor(torch.nn.Module): def forward(self, input1, input2): model_input = input1 + input2 out = self.fc1(model_input) - out = _InspectActivation.apply("fc1_out", None, GlobalSubscriberManager.get_run_context(), out) + out = inspect_activation("fc1_out", out) out = self.relu(out) - out = _InspectActivation.apply("relu_out", None, GlobalSubscriberManager.get_run_context(), out) + out = inspect_activation("relu_out", out) out = self.fc2(out) return out