diff --git a/docs/ORTModule_Convergence_Notes.md b/docs/ORTModule_Convergence_Notes.md index 8f54fd6b5a..791b6c32c9 100644 --- a/docs/ORTModule_Convergence_Notes.md +++ b/docs/ORTModule_Convergence_Notes.md @@ -83,12 +83,13 @@ Arguments: inspector node affects memory peak causing the original recipe run to fail with OOM. - `bucket_size`: the size of the bucket to split the statistic calculation. -### 2.2 Use `_InspectActivation` to collect intermediate tensors in a `nn.Module` forward() +### 2.2 Use `inspect_activation` to collect intermediate tensors in a `nn.Module` forward() The limitation of `GlobalSubscriberManager` is, only 'nn.Module's forward output tensors will be dumped, if you want to dump the intermediate tensors in a `nn.Module`'s forward function, refer to the following example: ```diff ++ from onnxruntime.training.utils import inspect_activation class BloomForCausalLM(BloomPreTrainedModel): def __init__(self, config: BloomConfig): ... @@ -98,10 +99,10 @@ class BloomForCausalLM(BloomPreTrainedModel): transformer_outputs = self.transformer(...) hidden_states = transformer_outputs[0] lm_logits = self.lm_head(hidden_states) -+ lm_logits = _InspectActivation.apply("lm_logits", None, GlobalSubscriberManager.get_run_context(), lm_logits) ++ lm_logits = inspect_activation("lm_logits", lm_logits) # Shift so that tokens < n predict n shift_logits = lm_logits[..., :-1, :].contiguous() -+ shift_logits = _InspectActivation.apply("shift_logits", None, GlobalSubscriberManager.get_run_context(), shift_logits) ++ shift_logits = inspect_activation("shift_logits", shift_logits) shift_labels = labels[..., 1:].contiguous() batch_size, seq_length, vocab_size = shift_logits.shape # Flatten the tokens @@ -113,7 +114,7 @@ class BloomForCausalLM(BloomPreTrainedModel): return loss ``` -Be noted, make sure the activation name (as the first argument of `_InspectActivation.apply`) is unique, otherwise +Be noted, make sure the activation name (as the first argument of `inspect_activation`) is unique, otherwise stat file using the activation name will be overwritten by the last write. The dumped data are stored in the `output_dir`. diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py index 8445934dcf..4c72b6d98a 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py @@ -6,16 +6,17 @@ import sys from typing import Callable, ClassVar, Dict, Optional -import onnx import torch import torch.utils.checkpoint +from onnx import ModelProto from packaging import version from torch.onnx import symbolic_helper from onnxruntime.capi._pybind_state import register_miscellaneous_const_input, register_torch_autograd_function from onnxruntime.training import ortmodule +from onnxruntime.training.utils import pytorch_dtype_to_onnx -from ._custom_op_symbolic_registry import pytorch_type_to_onnx, wrap_custom_export_function +from ._custom_op_symbolic_registry import wrap_custom_export_function from ._fallback import ORTModuleONNXModelException, wrap_exception from ._utils import get_fully_qualified_class_name, get_runtime_pytorch_version @@ -168,7 +169,7 @@ def _export_pt_1_10(g, n, *args, **kwargs): if call_type == "d": # Got a tensor variable. tensor_args.append(arg) - scalar_type = pytorch_type_to_onnx(arg.type().scalarType()) + scalar_type = pytorch_dtype_to_onnx(arg.type().scalarType()) input_tensor_types.append(scalar_type) input_tensor_ranks.append(arg.type().dim()) continue @@ -247,7 +248,7 @@ def _export_pt_1_10(g, n, *args, **kwargs): output_tensor_ranks = [] for arg in n.outputs(): # Type of tensor's elements. - scalar_type = pytorch_type_to_onnx(arg.type().scalarType()) + scalar_type = pytorch_dtype_to_onnx(arg.type().scalarType()) output_tensor_types.append(scalar_type) output_tensor_ranks.append(arg.type().dim()) @@ -306,16 +307,14 @@ def _export_pt_1_10(g, n, *args, **kwargs): _export = wrap_custom_export_function(_export_pt_1_10) -def _post_process_after_export( - exported_model: onnx.ModelProto, enable_custom_autograd_function: bool -) -> onnx.ModelProto: +def _post_process_after_export(exported_model: ModelProto, enable_custom_autograd_function: bool) -> ModelProto: """Post process the exported model.""" if enable_custom_autograd_function: exported_model = _post_process_enabling_autograd_function(exported_model) return exported_model -def _post_process_enabling_autograd_function(exported_model: onnx.ModelProto) -> onnx.ModelProto: +def _post_process_enabling_autograd_function(exported_model: ModelProto) -> ModelProto: # Loop all PythonOp, append "_ctx" as the first output. index = 0 for node in exported_model.graph.node: diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py index 8eeaee5bdf..845c7d83c2 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py @@ -232,7 +232,7 @@ def call_python_forward_function( for arg_index, (grad_flag, tensor_flag, arg) in enumerate(zip(requires_grad_flags, tensor_type_flags, args)): if tensor_flag: - # Assume it's a DLPack tensor# and convert it to PyTorch tensor. + # Assume it's a DLPack tensor and convert it to PyTorch tensor. # Note1: # If it's first-time kernel invocation, input_indices_to_save_in_ctx is None, we do the # copy for all tensor. Otherwise, we only copy the tensors whose indices are in diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py index ac87dc6abf..0dd33d493b 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py @@ -12,37 +12,10 @@ from packaging.version import Version from torch.onnx import register_custom_op_symbolic from torch.onnx.symbolic_helper import _get_tensor_dim_size, _get_tensor_sizes, parse_args +from onnxruntime.training.utils import pytorch_dtype_to_onnx + from ._utils import get_runtime_pytorch_version -# Mapping from pytorch scalar type to onnx scalar type. -_CAST_PYTORCH_TO_ONNX = { - "Byte": torch.onnx.TensorProtoDataType.UINT8, - "Char": torch.onnx.TensorProtoDataType.INT8, - "Double": torch.onnx.TensorProtoDataType.DOUBLE, - "Float": torch.onnx.TensorProtoDataType.FLOAT, - "Half": torch.onnx.TensorProtoDataType.FLOAT16, - "Int": torch.onnx.TensorProtoDataType.INT32, - "Long": torch.onnx.TensorProtoDataType.INT64, - "Short": torch.onnx.TensorProtoDataType.INT16, - "Bool": torch.onnx.TensorProtoDataType.BOOL, - "ComplexFloat": torch.onnx.TensorProtoDataType.COMPLEX64, - "ComplexDouble": torch.onnx.TensorProtoDataType.COMPLEX128, - "BFloat16": torch.onnx.TensorProtoDataType.BFLOAT16, - # Not yet defined in torch. - # "Float8E4M3FN": torch.onnx.TensorProtoDataType.FLOAT8E4M3FN, - # "Float8E4M3FNUZ": torch.onnx.TensorProtoDataType.FLOAT8E4M3FNUZ, - # "Float8E5M2": torch.onnx.TensorProtoDataType.FLOAT8E5M2, - # "Float8E5M2FNUZ": torch.onnx.TensorProtoDataType.FLOAT8E5M2FNUZ, - "Undefined": torch.onnx.TensorProtoDataType.UNDEFINED, -} - - -def pytorch_type_to_onnx(scalar_type: str) -> torch.onnx.TensorProtoDataType: - try: - return torch.onnx.JitScalarType.from_name(scalar_type).onnx_type() - except AttributeError: - return _CAST_PYTORCH_TO_ONNX[scalar_type] - def wrap_custom_export_function(original_func: Callable) -> Callable: """This function is to wrap the custom export function to make sure it can be used by different versions of PyTorch. @@ -172,7 +145,7 @@ def cross_entropy_loss(g, node, logits, target, weight, reduction, ignore_index, weight_casted, ignore_index, reduction_s=reduction, - output_type_i=pytorch_type_to_onnx(output_type.scalarType()), + output_type_i=pytorch_dtype_to_onnx(output_type.scalarType()), outputs=2, ) output.setType(output_type) @@ -199,10 +172,16 @@ def embedding(g, weight, indices, padding_idx, scale_grad_by_freq, sparse): output = g.op( "org.pytorch.aten::ATen", weight, indices, padding_idx, scale_grad_by_freq, sparse, operator_s="embedding" ) - indices_shape = _get_tensor_sizes(indices) - if indices_shape is not None and hasattr(weight.type(), "with_sizes"): - output_type = weight.type().with_sizes([*indices_shape, _get_tensor_dim_size(weight, 1)]) - output.setType(output_type) + + try: + # Tolerant to the case when sizes of indices are not available or not usable (for example + # when DeepSpeed stage3 enabled, all weights size is (0), this will fail.) + indices_shape = _get_tensor_sizes(indices) + if indices_shape is not None and hasattr(weight.type(), "with_sizes"): + output_type = weight.type().with_sizes([*indices_shape, _get_tensor_dim_size(weight, 1)]) + output.setType(output_type) + except IndexError: + output.setType(weight.type()) return output diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index 7f38428486..2227b630ae 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -20,6 +20,7 @@ import onnxruntime from onnxruntime.capi import _pybind_state as C from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference from onnxruntime.training.utils import ORTModelInputOutputSchemaType +from onnxruntime.training.utils.hooks import configure_ort_compatible_zero_stage3 from . import _are_deterministic_algorithms_enabled, _io, _logger, _onnx_models, _utils from ._custom_autograd_function_exporter import _post_process_after_export @@ -140,6 +141,10 @@ class GraphExecutionManager(GraphExecutionInterface): register_triton_op_executor() + if self._runtime_options.enable_zero_stage3_support: + # Cannot toggle feature enabling/disabling after the first time enabled. + configure_ort_compatible_zero_stage3() + def _get_torch_gpu_allocator_function_addresses(self): if self._runtime_options.use_external_gpu_allocator and torch.cuda.is_available(): # CPP extension to get torch GPU allocator's alloc and free function addresses @@ -281,7 +286,11 @@ class GraphExecutionManager(GraphExecutionInterface): # All required models have already been exported previously return False self._set_device_from_module(inputs, kwargs) - self._onnx_models.exported_model = self._get_exported_model(schema, *inputs, **kwargs) + + from onnxruntime.training.utils.hooks._subscriber_manager import no_increase_global_step + + with no_increase_global_step(): + self._onnx_models.exported_model = self._get_exported_model(schema, *inputs, **kwargs) if self._debug_options.save_onnx_models.save: self._onnx_models.save_exported_model( self._debug_options.save_onnx_models.path, @@ -311,7 +320,6 @@ class GraphExecutionManager(GraphExecutionInterface): # WARNING/ERROR -> [Rank 0] NO export verbose log + FILTERED torch other logs from stdout and stderr (C++ backend) # Be noted: rank 0 log only is controlled by logger configured in _logger.py torch_exporter_verbose_log = self._debug_options.logging.log_level <= LogLevel.INFO - self._logger.info("Exporting the PyTorch model to ONNX...") # Setup dynamic axes for onnx model self._input_info = _io.parse_inputs_for_onnx_export(self._module_parameters, None, input_schema, inputs, kwargs) @@ -327,6 +335,8 @@ class GraphExecutionManager(GraphExecutionInterface): # FlattenedModule needs _InputInfo to expand user input from *args to *args + **kwargs self._flattened_module._input_info = self._input_info + self._logger.info("Exporting the PyTorch model to ONNX...") + # Leverage cached model if available cache_dir = self._runtime_options.ortmodule_cache_dir if cache_dir: @@ -659,6 +669,14 @@ class GraphExecutionManager(GraphExecutionInterface): ) ) + feature_map.append( + ( + "ZeRO Stage3 Support", + self._runtime_options.enable_zero_stage3_support, + "Enable/Disable with env ORTMODULE_ENABLE_ZERO_STAGE3=1/0", + ) + ) + mode = "training" if self._export_mode == torch.onnx.TrainingMode.TRAINING else "inference" mode = f"{_logger.LogColor.UNDERLINE}{mode}{_logger.LogColor.ENDC}" diff --git a/orttraining/orttraining/python/training/ortmodule/_io.py b/orttraining/orttraining/python/training/ortmodule/_io.py index 86a06fe683..18b965c549 100644 --- a/orttraining/orttraining/python/training/ortmodule/_io.py +++ b/orttraining/orttraining/python/training/ortmodule/_io.py @@ -539,6 +539,7 @@ def parse_outputs_for_onnx_export_and_extract_schema( output_names = None output_dynamic_axes = None is_deepcopy = False + logger.info("Running model forward to infer output schema and dynamic axes...") with torch.no_grad(): # Deepcopy inputs, since input values may change after model run. sample_args_copy, sample_kwargs_copy = deepcopy_model_input(*args, **kwargs) diff --git a/orttraining/orttraining/python/training/ortmodule/options.py b/orttraining/orttraining/python/training/ortmodule/options.py index f8d3a6e779..0eb6790d7a 100644 --- a/orttraining/orttraining/python/training/ortmodule/options.py +++ b/orttraining/orttraining/python/training/ortmodule/options.py @@ -288,6 +288,9 @@ class _RuntimeOptions: # Cache exported model self.ortmodule_cache_dir = "" + # Experimental features. + self.enable_zero_stage3_support = False # Once enabled, cannot be disabled. + # Override the feature config if it exists in os env. self._override_from_env_vars() @@ -365,3 +368,7 @@ class _RuntimeOptions: if "ORTMODULE_CACHE_DIR" in os.environ: self._logger.info("ORTModule cache optimization is ON.") self.ortmodule_cache_dir = os.getenv("ORTMODULE_CACHE_DIR") + + # Experimental features. + if "ORTMODULE_ENABLE_ZERO_STAGE3" in os.environ and int(os.getenv("ORTMODULE_ENABLE_ZERO_STAGE3")) == 1: + self.enable_zero_stage3_support = True diff --git a/orttraining/orttraining/python/training/utils/__init__.py b/orttraining/orttraining/python/training/utils/__init__.py index c88c744ca9..acf2698d55 100644 --- a/orttraining/orttraining/python/training/utils/__init__.py +++ b/orttraining/orttraining/python/training/utils/__init__.py @@ -9,6 +9,7 @@ from onnxruntime.training.utils.torch_io_helper import ( extract_data_and_schema, unflatten_data_using_schema, ) +from onnxruntime.training.utils.torch_type_map import pytorch_dtype_to_onnx __all__ = [ "PrimitiveType", @@ -16,4 +17,5 @@ __all__ = [ "ORTModelInputOutputSchemaType", "extract_data_and_schema", "unflatten_data_using_schema", + "pytorch_dtype_to_onnx", ] diff --git a/orttraining/orttraining/python/training/utils/hooks/__init__.py b/orttraining/orttraining/python/training/utils/hooks/__init__.py index 91e919b1c5..89c0d44abb 100644 --- a/orttraining/orttraining/python/training/utils/hooks/__init__.py +++ b/orttraining/orttraining/python/training/utils/hooks/__init__.py @@ -3,14 +3,35 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- + +import torch + __all__ = [ "StatisticsSubscriber", "GlobalSubscriberManager", - "_InspectActivation", + "inspect_activation", + "ZeROOffloadSubscriber", + "configure_ort_compatible_zero_stage3", ] -from ._statistics_subscriber import StatisticsSubscriber -from ._subscriber_manager import SubscriberManager, _InspectActivation +from ._statistics_subscriber import StatisticsSubscriber, _InspectActivation +from ._subscriber_manager import SubscriberManager +from ._zero_offload_subscriber import ZeROOffloadSubscriber, configure_ort_compatible_zero_stage3 # Define a global uninitialized subscriber manager for usage where it is needed by different Python files. GlobalSubscriberManager = SubscriberManager() + + +def inspect_activation(activation_name: str, tensor: torch.Tensor) -> torch.Tensor: + for sub in GlobalSubscriberManager._subscribers: + if isinstance(sub, StatisticsSubscriber): + return _InspectActivation.apply( + activation_name, + None, + GlobalSubscriberManager.get_run_context(), + tensor, + sub.module_post_forward_impl, + sub.module_pre_backward_impl, + ) + + raise RuntimeError("StatisticsSubscriber is not registered, cannot inspect activation.") diff --git a/orttraining/orttraining/python/training/utils/hooks/_statistics_subscriber.py b/orttraining/orttraining/python/training/utils/hooks/_statistics_subscriber.py index 0dd06eee13..6c8027b2fe 100644 --- a/orttraining/orttraining/python/training/utils/hooks/_statistics_subscriber.py +++ b/orttraining/orttraining/python/training/utils/hooks/_statistics_subscriber.py @@ -7,11 +7,90 @@ import os import shutil import warnings from pathlib import Path -from typing import Union +from typing import List, Optional, Tuple, Union +import onnx import torch -from ._subscriber_base import SubscriberBase +from ._subscriber_base import RuntimeStates, SubscriberBase + + +class _InspectActivation(torch.autograd.Function): + """ + This class is used to run the subscriber's forward and backward functions. + The function will be called by two kinds of callers: + 1. SubscriberManager calls it for each registered nn.Module. + 2. Users who want to inspect the activation tensor at any place of model definition code. + """ + + @staticmethod + def forward( + ctx, + activation_name: str, + module_idx: Optional[int], + run_ctx: RuntimeStates, + input_tensor: torch.Tensor, + module_post_forward, + module_pre_backward, + ): + """ + Args: + ctx: context object to store intermediate information. + activation_name: the name of the activation tensor. + module_idx: unit id of the module (address of the module instance). + run_ctx: runtime context. + For call case 2 - need retrieve the runtime state from GlobalSubscriberManager. + input_tensor: the activation tensor. + + Make sure there is a same number of `tensor` type inputs and outputs. + This is enforced by ORT's PythonOp's schema check. + """ + depth = -1 + if module_idx is not None: + depth = run_ctx.global_states.module_index_to_depth[module_idx] + + input_tensor_copied = None + if input_tensor is None or not isinstance(input_tensor, torch.Tensor): + input_tensor_copied = input_tensor + else: + input_tensor_copied = input_tensor.detach().clone() + + ctx.current_step = run_ctx.global_states.execution_step + ctx.name = activation_name + ctx.id = module_idx + ctx.depth = depth + ctx.module_pre_backward = module_pre_backward + + module_post_forward(input_tensor_copied, depth, activation_name, ctx.current_step) + + return input_tensor.detach() if input_tensor is not None else None + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): + val = None + if grad_output is None or not isinstance(grad_output, torch.Tensor): + val = grad_output + else: + val = grad_output.detach().clone() + + ctx.module_pre_backward(val, ctx.depth, ctx.name, ctx.current_step) + + return ( + None, + None, + None, + grad_output.detach() if grad_output is not None else None, + None, + None, + ) + + @staticmethod + def infer_shape( + node: onnx.NodeProto, + tensor_input_shapes: List[Optional[List[Union[int, str]]]], + tensor_input_dtypes: List[torch.onnx.TensorProtoDataType], + ) -> Tuple[List[Optional[List[Union[int, str]]]], List[torch.onnx.TensorProtoDataType]]: + return tensor_input_shapes, tensor_input_dtypes class StatisticsSubscriber(SubscriberBase): @@ -68,6 +147,15 @@ class StatisticsSubscriber(SubscriberBase): "Set override_output_dir=True for StatisticsSubscriber if this is the intention." ) + def post_forward_tensor_apply_impl( + self, run_rtx: RuntimeStates, module: torch.nn.Module, tensor_index: int, tensor: torch.Tensor + ) -> torch.Tensor: + module_index = run_rtx.global_states.module_to_module_index[module] + name = f"{module.__class__.__name__}_{module_index}_{tensor_index}th_output" + return _InspectActivation.apply( + name, module_index, run_rtx, tensor, self.module_post_forward_impl, self.module_pre_backward_impl + ) + def module_post_forward_impl(self, activation: torch.Tensor, depth: int, name: str, step: int): output_file_path = os.path.join(f"{self._output_dir}", f"step_{step}") return self._summarize_activations(activation, depth, name, output_file_path, True) diff --git a/orttraining/orttraining/python/training/utils/hooks/_subscriber_base.py b/orttraining/orttraining/python/training/utils/hooks/_subscriber_base.py index 572f2fb99b..1b9a6fc91e 100644 --- a/orttraining/orttraining/python/training/utils/hooks/_subscriber_base.py +++ b/orttraining/orttraining/python/training/utils/hooks/_subscriber_base.py @@ -5,34 +5,54 @@ import sys -from typing import Union +from typing import Optional, Tuple import torch +from onnxruntime.training.utils import ORTModelInputOutputType + + +class RuntimeStates: + """ + A data struct holding states for runtime context. + > Global states that are one-time collected during model hook registration. A global execution step is + also initialized to reflect how many steps have been executed, it will get updated after each step + completes its forward path. + """ + + class _GlobalStates: + def __init__(self): + # Used to track current execution step, e.g. how many forward/backward path is called. + self.execution_step = 0 + # Used to store the depth of each module, which indicate the indentation level of the module. + self.module_index_to_depth = {} + # Used to store the unique id of each sequential activation. + self.module_to_module_index = {} + + def __init__(self): + self.global_states = RuntimeStates._GlobalStates() + class SubscriberBase: """ Base class for all module hook subscribers. - Currently, the hook here only means post-forward hook and pre-backward hook. - A module hook subscriber is a class that implements the `module_post_forward_impl` and `module_pre_backward_impl` - function. - > The post_forward hook is called after the activation is generated in the forward path. - > The pre_backward hook is called before the activation gradient is computed. + A module hook subscriber is a class that allow define custom actions to be executed during the nn.Module's hooks. + Two types of APIs can be used to define custom actions: + 1. Module level interfaces: + pre_forward_module_apply - called inside the nn.Module's pre-forward hook. + post_forward_module_apply - called inside the nn.Module's post-forward hook. + post_forward_outmost_module_apply - called inside the nn.Module's post-forward hook, but only for the outmost module. + 2. Tensor level interfaces: + pre_forward_tensor_apply - called inside the nn.Module's pre-forward hook, for each input tensor. + post_forward_tensor_apply - called inside the nn.Module's post-forward hook, for each output tensor. - The post_forward path: - Module_A generates activation tensor_a --> Post forward hook (calling subscribers' forward one by one) --> - Module_B generates activation tensor_b --> ... - - The pre_backward path: - Module_B backward run, tensor_b's gradient is computed as tensor_b_grad --> - Pre-backward hook (calling subscribers' backward one by one) --> - Module_A backward run, tensor_a's gradient is computed as tensor_a_grad - - Be noted: the "Pre"/"Post" is described from the perspective of Module_A. + For ORT runs, tensor's flows are important, that's the reason we have tensor input as function input, + and tensor output as function output for all the APIs. + With this, the overall flow can be traced as a data flow graph (DAG). """ - def __init__(self, start_step: Union[None, int], end_step: Union[None, int]): + def __init__(self, start_step: Optional[int], end_step: Optional[int]): """ Steps in [start_step, end_step) will run the subscriber's actions, and other steps will skip. If start_step is None, 0 is given; if end_step is None, sys.maxsize is given. @@ -40,34 +60,152 @@ class SubscriberBase: self._start_step: int = start_step if start_step is not None else 0 self._end_step: int = end_step if end_step is not None else sys.maxsize - def module_post_forward(self, activation: torch.Tensor, depth: int, name: str, step: int): - """ - This function will be run after the torch Module forward is completed. + def pre_forward_module_apply( + self, + run_rtx: RuntimeStates, + module: torch.nn.Module, + args: ORTModelInputOutputType, + kwargs: ORTModelInputOutputType, + ) -> Tuple[ORTModelInputOutputType, ORTModelInputOutputType]: + """This function is called inside the nn.Module's pre-forward hook. Args: - activation: Tensor to be inspected. - depth: The indent level of the torch Module generating `activation`. - name: The unique name for the `activation`. - step: Current execution step. - """ - if self._start_step <= step < self._end_step: - self.module_post_forward_impl(activation, depth, name, step) + run_rtx (RuntimeStates): The runtime states of SubscriberManager. + module (torch.nn.Module): The module that is being executed. + args (ORTModelInputOutputType): The positional arguments that are passed to the module's pre-forward hook. + kwargs (ORTModelInputOutputType): The keyword arguments that are passed to the module's pre-forward hook. + + Returns: + Tuple[ORTModelInputOutputType, ORTModelInputOutputType]: Updated args and kwargs. - def module_pre_backward(self, activation: torch.Tensor, depth: int, name: str, step: int): """ - This function will be run before the torch Module backward run. + if self._need_skip_step(run_rtx.global_states.execution_step): + return args, kwargs + + updated_args, updated_kwargs = self.pre_forward_module_apply_impl(run_rtx, module, args, kwargs) + return updated_args, updated_kwargs + + def pre_forward_module_apply_impl( + self, + run_rtx: RuntimeStates, + module: torch.nn.Module, + args: ORTModelInputOutputType, + kwargs: ORTModelInputOutputType, + ) -> Tuple[ORTModelInputOutputType, ORTModelInputOutputType]: + return args, kwargs + + def pre_forward_tensor_apply( + self, run_rtx: RuntimeStates, module: torch.nn.Module, tensor_index: int, tensor: torch.Tensor + ) -> torch.Tensor: + """This function is called inside the nn.Module's pre-forward hook. Args: - activation: Tensor to be inspected. - depth: The indent level of the torch Module generating `activation`. - name: The unique name for the `activation`. - step: Current execution step. + run_rtx (RuntimeStates): The runtime states of SubscriberManager. + module (torch.nn.Module): The module that is being executed. + tensor_index (int): The index of the tensor in the input tensor list. + tensor (torch.Tensor): The tensor is one of module's forward inputs. """ - if self._start_step <= step < self._end_step: - self.module_pre_backward_impl(activation, depth, name, step) + if self._need_skip_step(run_rtx.global_states.execution_step): + return tensor - def module_post_forward_impl(self, activation: torch.Tensor, depth: int, name: str, step: int): - raise NotImplementedError() + return self.pre_forward_tensor_apply_impl(run_rtx, module, tensor_index, tensor) - def module_pre_backward_impl(self, activation: torch.Tensor, depth: int, name: str, step: int): - raise NotImplementedError() + def pre_forward_tensor_apply_impl( + self, run_rtx: RuntimeStates, module: torch.nn.Module, tensor_index: int, tensor: torch.Tensor + ) -> torch.Tensor: + return tensor + + def post_forward_module_apply( + self, + run_rtx: RuntimeStates, + module: torch.nn.Module, + args: ORTModelInputOutputType, + outputs: ORTModelInputOutputType, + ) -> Tuple[ORTModelInputOutputType, ORTModelInputOutputType]: + """This function is called inside the nn.Module's post-forward hook. + + Args: + run_rtx (RuntimeStates): The runtime states of SubscriberManager. + module (torch.nn.Module): The module that is being executed. + args (ORTModelInputOutputType): The inputs arguments that are passed to the module's post-forward + hook as input. + outputs (ORTModelInputOutputType): The outputs arguments that are passed to the module's post-forward + hook as input. + + Returns: + Tuple[ORTModelInputOutputType, ORTModelInputOutputType]: Updated inputs and outputs. + """ + if self._need_skip_step(run_rtx.global_states.execution_step): + return args, outputs + + return self.post_forward_module_apply_impl(run_rtx, module, args, outputs) + + def post_forward_module_apply_impl( + self, + run_rtx: RuntimeStates, + module: torch.nn.Module, + args: ORTModelInputOutputType, + outputs: ORTModelInputOutputType, + ) -> Tuple[ORTModelInputOutputType, ORTModelInputOutputType]: + return args, outputs + + def post_forward_tensor_apply( + self, run_rtx: RuntimeStates, module: torch.nn.Module, tensor_index: int, tensor: torch.Tensor + ) -> torch.Tensor: + """This function is called inside the nn.Module's post-forward hook. + + Args: + run_rtx (RuntimeStates): The runtime states of SubscriberManager. + module (torch.nn.Module): The module that is being executed. + tensor_index (int): The index of the tensor in the output tensor list. + tensor (torch.Tensor): The tensor is one of module's forward outputs. + + Returns: + torch.Tensor: Updated tensor. + """ + if self._need_skip_step(run_rtx.global_states.execution_step): + return tensor + + return self.post_forward_tensor_apply_impl(run_rtx, module, tensor_index, tensor) + + def post_forward_tensor_apply_impl( + self, run_rtx: RuntimeStates, module: torch.nn.Module, tensor_index: int, tensor: torch.Tensor + ) -> torch.Tensor: + return tensor + + def post_forward_outmost_module_apply( + self, + run_rtx: RuntimeStates, + module: torch.nn.Module, + args: ORTModelInputOutputType, + outputs: ORTModelInputOutputType, + ) -> Tuple[ORTModelInputOutputType, ORTModelInputOutputType]: + """This function is called inside the outmost nn.Module's post-forward hook. + + Args: + run_rtx (RuntimeStates): The runtime states of SubscriberManager. + module (torch.nn.Module): The module that is being executed. + args (ORTModelInputOutputType): The inputs arguments that are passed to the module's post-forward + hook as input. + outputs (ORTModelInputOutputType): The outputs arguments that are passed to the module's post-forward + hook as input. + + Returns: + Tuple[ORTModelInputOutputType, ORTModelInputOutputType]: Updated inputs and outputs. + """ + if self._need_skip_step(run_rtx.global_states.execution_step): + return args, outputs + + return self.post_forward_outmost_module_apply_impl(run_rtx, module, args, outputs) + + def post_forward_outmost_module_apply_impl( + self, + run_rtx: RuntimeStates, + module: torch.nn.Module, + args: ORTModelInputOutputType, + outputs: ORTModelInputOutputType, + ) -> Tuple[ORTModelInputOutputType, ORTModelInputOutputType]: + return args, outputs + + def _need_skip_step(self, current_step: int) -> bool: + return current_step < self._start_step or current_step >= self._end_step diff --git a/orttraining/orttraining/python/training/utils/hooks/_subscriber_manager.py b/orttraining/orttraining/python/training/utils/hooks/_subscriber_manager.py index 72208b7228..db38f58d8f 100644 --- a/orttraining/orttraining/python/training/utils/hooks/_subscriber_manager.py +++ b/orttraining/orttraining/python/training/utils/hooks/_subscriber_manager.py @@ -3,114 +3,31 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- -from collections import abc -from typing import Callable, List, Optional, Tuple, Union + +import inspect +from contextlib import contextmanager +from typing import List, Optional, Set, Tuple, Union import onnx import torch -from onnxruntime.training.ortmodule import ORTModule +from onnxruntime.training.utils import extract_data_and_schema, unflatten_data_using_schema -from ._subscriber_base import SubscriberBase +from ._subscriber_base import RuntimeStates, SubscriberBase + +ORT_NO_INCREASE_GLOBAL_STEP = [False] -class _RuntimeStates: +@contextmanager +def no_increase_global_step(): + """During ONNX model export phase, forward run is triggered, but we don't want to increase the global step, then + Then the first iteration run will still start with 0, aligned with PyTorch's first iteration run. """ - A data struct holding states for runtime context. Tho kinds of states are included: - > Global states that are one-time collected during model hook registration. A global execution step is - also initialized to reflect how many steps have been executed, it will get updated after each step - completes its forward path. - > Intra-execution step states, initialized and cleaned up intended only for current execution step. - Usually, it carries intermediate information during the model execution. - """ - - class _GlobalStates: - def __init__(self): - # Used to track current execution step, e.g. how many forward/backward path is called. - self.execution_step = 0 - # Used to store the depth of each module, which indicate the indentation level of the module. - self.module_index_to_depth = {} - # Used to store the unique id of each sequential activation. - self.module_to_module_index = {} - - self.subscribers = set() - - class _ExecutionStepStates: - def __init__(self): - # Used to store the activation tensor names, if already handled, then skipped. - # Need to clear after each step. - self.observed_activation_names = {} - - def __init__(self): - self.global_states = _RuntimeStates._GlobalStates() - self.reset_step_states() - - def reset_step_states(self): - self.execution_step_states = _RuntimeStates._ExecutionStepStates() - - -class _InspectActivation(torch.autograd.Function): - """ - This class is used to run the subscriber's forward and backward functions. - The function will be called by two kinds of callers: - 1. SubscriberManager calls it for each registered nn.Module. - 2. Users who want to inspect the activation tensor at any place of model definition code. - """ - - @staticmethod - def forward( - ctx, activation_name: str, module_idx: Optional[int], run_ctx: _RuntimeStates, input_tensor: torch.Tensor - ): - """ - Args: - ctx: context object to store intermediate information. - activation_name: the name of the activation tensor. - module_idx: - For call case 1 - the unique id of the module that the activation belongs to, it is detected by the - SubscriberManager automatically. - For call case 2 - e.g, _InspectActivation is called by users (NOT by SubscriberManager), module_idx can - be None. - run_ctx: runtime context. - For call case 2 - need retrieve the runtime state from GlobalSubscriberManager. - input_tensor: the activation tensor. - - Make sure there is a same number of `tensor` type inputs and outputs. - This is enforced by ORT's PythonOp's schema check. - """ - depth = -1 - if module_idx is not None: - depth = run_ctx.global_states.module_index_to_depth[module_idx] - - input_tensor_copied = None - if input_tensor is None or not isinstance(input_tensor, torch.Tensor): - input_tensor_copied = input_tensor - else: - input_tensor_copied = input_tensor.detach().clone() - - ctx.current_step = run_ctx.global_states.execution_step - ctx.name = activation_name - ctx.id = module_idx - ctx.depth = depth - ctx.subscribers = run_ctx.global_states.subscribers - - # Run subscribers sequentially. - for subscriber in run_ctx.global_states.subscribers: - subscriber.module_post_forward(input_tensor_copied, depth, activation_name, ctx.current_step) - - return input_tensor.detach() if input_tensor is not None else None - - @staticmethod - def backward(ctx, grad_output: torch.Tensor): - val = None - if grad_output is None or not isinstance(grad_output, torch.Tensor): - val = grad_output - else: - val = grad_output.detach().clone() - - for subscriber in ctx.subscribers: - subscriber.module_pre_backward(val, ctx.depth, ctx.name, ctx.current_step) - - return None, None, None, grad_output.detach() if grad_output is not None else None + try: + ORT_NO_INCREASE_GLOBAL_STEP[0] = True + yield + finally: + ORT_NO_INCREASE_GLOBAL_STEP[0] = False @staticmethod def infer_shape( @@ -122,8 +39,7 @@ class _InspectActivation(torch.autograd.Function): class _IncrementStep(torch.autograd.Function): - """ - This class is used to manage the global execution step, e.g. + """This class is used to manage the global execution step, e.g. global step increment by one, once a full forward path is completed and the state clear. This autograd Function is registered as a post-forward hook to the root module. So once the root @@ -132,45 +48,24 @@ class _IncrementStep(torch.autograd.Function): """ @staticmethod - def forward(ctx, run_ctx: _RuntimeStates, input_tensor: torch.Tensor): - """ - Make sure there is a same number of `tensor` inputs and outputs. + def forward(ctx, run_ctx: RuntimeStates, *input_tensor_list: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]: + """Make sure there is the same number of `tensor` inputs and outputs. This is enforced by ORT's PythonOp's schema check. """ ctx.current_step = run_ctx.global_states.execution_step ctx.run_ctx = run_ctx - # We cannot do the step incremental here. Imagine the outside-most module has multiple outputs, - # we need to increase the step only at the very last output handling. - # We avoid the complexity to probe the last output handling, and instead, we assume once - # the very first backward of the outside-most module is called, then the forward pass MUST be completed. + if ctx.current_step >= 0: + print(f"{'='*6} Completed forward pass for STEP {ctx.current_step} {'='*6}") - # Be noted: it is not safe to register _IncrementStep only for one of the outputs of the outside-most module, - # because we are not sure which output branch is executed earlier, for example. - # OuterMostModuleOutputs - # / \ - # OuterMostModuleOutputs_0_0th_output OuterMostModuleOutputs_0_1th_output - # | | - # PythonOp(_InspectActivation) PythonOp(_InspectActivation) - # | | - # PythonOp(_IncrementStep) graph output - # | - # graph output - # The PythonOp(_InspectActivation) (who relies on global step) after 1th output is possible - # to run before or after PythonOp(_IncrementStep), so increasing the step is not safe. + if ORT_NO_INCREASE_GLOBAL_STEP[0] is False: + ctx.run_ctx.global_states.execution_step += 1 - return input_tensor.detach() if isinstance(input_tensor, torch.Tensor) else input_tensor + return tuple(t.detach().requires_grad_(t.requires_grad) for t in input_tensor_list) @staticmethod - def backward(ctx, grad_output: torch.Tensor): - # In case there are multiple backward calls for multiple outputs of the outside-most module. - if ctx.current_step == ctx.run_ctx.global_states.execution_step: - if ctx.current_step >= 0: - print(f"{'='*6} Completed forward pass for STEP {ctx.current_step} {'='*6}") - ctx.run_ctx.global_states.execution_step += 1 - ctx.run_ctx.reset_step_states() - - return None, grad_output.detach() if isinstance(grad_output, torch.Tensor) else grad_output + def backward(ctx, *grad_output: Tuple[Optional[torch.Tensor], ...]) -> Tuple[Optional[torch.Tensor], ...]: + return (None, *tuple(g for g in grad_output)) @staticmethod def infer_shape( @@ -182,23 +77,25 @@ class _IncrementStep(torch.autograd.Function): class SubscriberManager: - """ - This class is used to manage all the subscribers and register the post-forward hook to the root module. - `subscribe()` is used to register a list of subscribers. + """This class is used to manage all the subscribers and register subscribers' custom actions as PyTorch hooks + to the nn.Modules. - Currently, the hook handled here is post forward hook for nn.Module. The hook is registered for all nn.Modules - recursively. Each hook inserts a PythonOp for every tensor output generated by the corresponding module. - Each subscriber implementation is called in the PythonOp's forward function, and backward function. + For the module-level/tensor_level custom actions defined by subscribers, they are registered as corresponding + PyTorch hooks in the sequence of the subscribers' registration order. There is one special handling for global step increment and state clear. A post-forward hook is registered for the outside-most module, which is the root module. In that hook, _IncrementStep is called, which will - increase the step by 1 once the very first time its backward is called (check _IncrementStep for details). + increase the step by 1 once the post forward hook is called if running without no_increase_global_step(). + `no_increase_global_step` is used to skip the step increment during ONNX model export. """ def __init__(self): - self._run_ctx: _RuntimeStates = _RuntimeStates() + self._run_ctx = RuntimeStates() + self._subscribers: Set[SubscriberBase] = set() + self._pre_forward_hooks = [] + self._post_forward_hooks = [] - def subscribe(self, module: Union[torch.nn.Module, ORTModule], subscribers: List[SubscriberBase]): + def subscribe(self, module: torch.nn.Module, subscribers: List[SubscriberBase]): """ The API is called externally to register hooks that are implicitly defined by subscribers. Each time all global states will be cleaned up once called. @@ -207,52 +104,110 @@ class SubscriberManager: raise ValueError("module must be a torch.nn.Module instance") self._reset_all_states() + self._subscribers.clear() - if isinstance(module, ORTModule): - module = module.module + try: + # Put the import here to avoid the module level dependency on onnxruntime.training.ortmodule + from onnxruntime.training.ortmodule import ORTModule + + if isinstance(module, ORTModule): + module = module.module + except ImportError: + pass for subscriber in subscribers: if not isinstance(subscriber, SubscriberBase): raise ValueError("subscriber must be a SubscriberBase instance") - self._run_ctx.global_states.subscribers.add(subscriber) + self._subscribers.add(subscriber) self._initialize(module) - def get_run_context(self) -> _RuntimeStates: + def get_subscriber(self, subscriber_type: type) -> SubscriberBase: + for subscriber in self._subscribers: + if isinstance(subscriber, subscriber_type): + return subscriber + raise RuntimeError(f"Subscriber {subscriber_type} is not registered.") + + def get_run_context(self) -> RuntimeStates: return self._run_ctx def _reset_all_states(self): - self._run_ctx = _RuntimeStates() + self._pre_forward_hooks.clear() + self._post_forward_hooks.clear() + self._run_ctx = RuntimeStates() def _initialize(self, module: torch.nn.Module): - """ - Register hooks for the specified module. - """ - if len(self._run_ctx.global_states.subscribers) == 0: + """Register hooks for the specified module.""" + if len(self._subscribers) == 0: raise RuntimeError("No subscribers are registered.") + def _pre_forward_outmost_module_hook(module, module_inputs): + # This check is to support the case where module is first registered in the subscriber manager, + # then the module and hook are copied, when new module instance runs to the hook, the global states + # are not reset, so the logic depends on the global states will fail. So in the outer-most pre-forward hook + # we reset the global states. + + # Be noted, the first run anyway will run in PyTorch. + if module not in self._run_ctx.global_states.module_to_module_index: + import warnings + + warnings.warn( + "Initialize global states for the first time, this should only happen once for each outmost module." + ) + self._initialize_one_time_global_states(module) + return module_inputs + + module.register_forward_pre_hook(_pre_forward_outmost_module_hook) + next_module_index = [0] - # Register post forward hook for every module, inside the hook, we loop every tensor output of the module, - # and wrap it with an autograd Function called _InspectActivation (which takes in a tensor and returns the same - # tensor). In this way, we keep ORT and PyTorch run have the same boundary to check activation equality. self._register_hooks_recursively(module, 1, next_module_index) # Register post forward hook for the outside-most module, then we increase the dump step. - # Be noted, if backward is not triggered, the global dump step remains the original number, - # which means the subsequent run will override the previous dump files. This indeed happens to imagine ORTModule - # firstly export graph (run the forward only), after the gradient graph is built, another forward+backward is - # triggered, override the previous dump files. - def _post_forward_outmost_module_hook(module, _, module_outputs): - def _apply_to_tensors_func(_, outputs): - return _IncrementStep.apply(self._run_ctx, outputs) + def _post_forward_outmost_module_hook(module, module_inputs, module_outputs): + # Call post outmost module forward custom actions for subscribers + for sub in self._subscribers: + module_inputs, module_outputs = sub.post_forward_outmost_module_apply( + self._run_ctx, module, module_inputs, module_outputs + ) - return self._apply_function_to_tensors(module, module_outputs, _apply_to_tensors_func) + flatten_output_tensor_list, output_schema = extract_data_and_schema(module_outputs) + output_tensors = _IncrementStep.apply(self._run_ctx, *flatten_output_tensor_list) + restored_outputs = unflatten_data_using_schema(output_tensors, output_schema) + + return restored_outputs module.register_forward_hook(_post_forward_outmost_module_hook) + def _initialize_one_time_global_states(self, module: torch.nn.Module): + def _reset_recursively(module: torch.nn.Module, depth: int, next_module_index: List[int]): + """ + Called to register hooks for every `torch.nn.Module`. Due to `Module` can contain child `Module`s, + this function is called recursively by passing in `next_module_index` - a list of int to maintain a + global incremental unique module id. + + Args: + module: torch.nn.Module to register hook. + depth: the indent of the module compared with the outside-most Module. + next_module_index: list of int, carrying a global unique module index that can be used next. + """ + module_index = next_module_index[0] + module.id = module_index # STAGE3WARN: needed by DeepSpeed + self._run_ctx.global_states.module_index_to_depth[module_index] = depth + self._run_ctx.global_states.module_to_module_index[module] = module_index + + for child in module.children(): + if ( + isinstance(child, torch.nn.Module) + and child not in self._run_ctx.global_states.module_to_module_index + ): + next_module_index[0] += 1 + _reset_recursively(child, depth + 1, next_module_index) + + next_module_index = [0] + _reset_recursively(module, 1, next_module_index) + def _register_hooks_recursively(self, module: torch.nn.Module, depth: int, next_module_index: List[int]): - """ - Called to register hooks for every `torch.nn.Module`. Due to `Module` can contain child `Module`s, + """Register hooks for every `torch.nn.Module`. Due to `Module` can contain child `Module`s, this function is called recursively by passing in `next_module_index` - a list of int to maintain a global incremental unique module id. @@ -262,6 +217,7 @@ class SubscriberManager: next_module_index: list of int, carrying a global unique module index that can be used next. """ module_index = next_module_index[0] + module.id = module_index # STAGE3WARN: needed by DeepSpeed self._run_ctx.global_states.module_index_to_depth[module_index] = depth self._run_ctx.global_states.module_to_module_index[module] = module_index @@ -270,69 +226,54 @@ class SubscriberManager: next_module_index[0] += 1 self._register_hooks_recursively(child, depth + 1, next_module_index) - def _post_forward_module_hook(module, _, module_outputs): - if module in self._run_ctx.global_states.module_to_module_index and isinstance(module, torch.nn.Module): - module_index = self._run_ctx.global_states.module_to_module_index[module] + def _pre_forward_module_with_kwargs_hook(module, module_inputs, kwargs): + # Module level hook + for sub in self._subscribers: + module_inputs, kwargs = sub.pre_forward_module_apply(self._run_ctx, module, module_inputs, kwargs) - def _apply_to_tensors_func(index, activation_tensor): - name = f"{module.__class__.__name__}_{module_index}_{index}th_output" - if name not in self._run_ctx.execution_step_states.observed_activation_names: - self._run_ctx.execution_step_states.observed_activation_names[name] = True - return _InspectActivation.apply(name, module_index, self._run_ctx, activation_tensor) + # Tensor level hook + flatten_positional_input_tensor_list, input_schema = extract_data_and_schema(module_inputs) + flatten_keyword_input_tensor_list, keyword_input_schema = extract_data_and_schema(kwargs) - return activation_tensor + for sub in self._subscribers: + tensor_list = [] + for tensor_index, tensor in enumerate(flatten_positional_input_tensor_list): + tensor_list.append(sub.pre_forward_tensor_apply(self._run_ctx, module, tensor_index, tensor)) + flatten_positional_input_tensor_list = tensor_list - return self._apply_function_to_tensors(module, module_outputs, _apply_to_tensors_func) - return module_outputs + tensor_list = [] + for tensor_index, tensor in enumerate(flatten_keyword_input_tensor_list): + tensor_list.append(sub.pre_forward_tensor_apply(self._run_ctx, module, tensor_index, tensor)) + flatten_keyword_input_tensor_list = tensor_list - module.register_forward_hook(_post_forward_module_hook) + module_inputs = unflatten_data_using_schema(flatten_positional_input_tensor_list, input_schema) + kwargs = unflatten_data_using_schema(flatten_keyword_input_tensor_list, keyword_input_schema) - def _is_builtin_type(self, obj): - # https://stackoverflow.com/a/17795199 - return obj.__class__.__module__ in ["__builtin__", "builtins"] + return module_inputs, kwargs - def _apply_function_to_tensors(self, module: torch.nn.Module, data, func: Callable): - """ - Apply func to all tensors in the given object. + def _pre_forward_module_hook(module, module_inputs): + return _pre_forward_module_with_kwargs_hook(module, module_inputs, {}) - Args: - module: the module that generates the tensors. - data: the object that contains activation tensors. - func: the function to apply to the tensors. - """ - tensor_output_idx: List[int] = [0] + def _post_forward_module_hook(module, module_inputs, module_outputs): + # Module level hook + for sub in self._subscribers: + _, module_outputs = sub.post_forward_module_apply(self._run_ctx, module, module_inputs, module_outputs) - def _apply_to_tensors_by_flatten( - module: torch.nn.Module, - index_for_tensor_output: List[int], - outputs, - func: Callable, - ): - if isinstance(outputs, abc.Sequence): - touched_outputs = [] - for output in outputs: - touched_output = _apply_to_tensors_by_flatten(module, index_for_tensor_output, output, func) - touched_outputs.append(touched_output) - return outputs.__class__(touched_outputs) + # Tensor level hook + flatten_output_tensor_list, output_schema = extract_data_and_schema(module_outputs) + for sub in self._subscribers: + tensor_list = [] + for tensor_index, tensor in enumerate(flatten_output_tensor_list): + tensor_list.append(sub.post_forward_tensor_apply(self._run_ctx, module, tensor_index, tensor)) + flatten_output_tensor_list = tensor_list - if isinstance(outputs, abc.Mapping): - # apply inplace to avoid recreating dict inherited objects - for key in outputs: - outputs[key] = _apply_to_tensors_by_flatten( - module, - index_for_tensor_output, - outputs[key], - func, - ) - return outputs + return unflatten_data_using_schema(flatten_output_tensor_list, output_schema) - if isinstance(outputs, torch.Tensor): - cur_id = index_for_tensor_output[0] - index_for_tensor_output[0] += 1 - return func(cur_id, outputs) - - if not self._is_builtin_type(outputs): - raise RuntimeError(f"Unknown type {type(outputs)}") - return outputs - - return _apply_to_tensors_by_flatten(module, tensor_output_idx, data, func) + # "with_kwargs" is not available for low versions of PyTorch. + if "with_kwargs" in inspect.signature(module.register_forward_pre_hook).parameters: + self._pre_forward_hooks.append( + module.register_forward_pre_hook(_pre_forward_module_with_kwargs_hook, with_kwargs=True) + ) + else: + self._pre_forward_hooks.append(module.register_forward_pre_hook(_pre_forward_module_hook)) + self._post_forward_hooks.append(module.register_forward_hook(_post_forward_module_hook)) diff --git a/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py b/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py new file mode 100644 index 0000000000..3d42e172ee --- /dev/null +++ b/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py @@ -0,0 +1,476 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +import ctypes +import inspect +import warnings +from collections import OrderedDict +from types import CodeType, FunctionType +from typing import Callable, Dict, List, Optional, Tuple, Union + +import onnx +import torch + +from onnxruntime.training.utils import ( + ORTModelInputOutputType, + extract_data_and_schema, + pytorch_dtype_to_onnx, + unflatten_data_using_schema, +) + +from ._subscriber_base import RuntimeStates, SubscriberBase + + +# Used to monkey patch the original function +# Adapted from https://github.com/microsoft/DeepSpeed/blob/e8318634b4313eaad89842cf4322e1762d34ced3/deepspeed/runtime/zero/parameter_offload.py#L333 +def _setup_zero_stage3_ort_compatible_hooks(self): + self.hierarchy = 0 + + from onnxruntime.training.utils.hooks import SubscriberManager, ZeROOffloadSubscriber + from onnxruntime.training.utils.hooks._zero_offload_subscriber import _zero_offload_one_time_initializer + + # Each DeepSpeed engine has a separate subscriber manager. + self._offload_subscriber_manager = SubscriberManager() + self._offload_subscriber_manager.subscribe( + self.module, [ZeROOffloadSubscriber(self, _zero_offload_one_time_initializer)] + ) + self.forward_hooks.extend(self._offload_subscriber_manager._pre_forward_hooks) + self.forward_hooks.extend(self._offload_subscriber_manager._post_forward_hooks) + + # Add top module to stack trace + global FWD_MODULE_STACK # noqa: PLW0602 + FWD_MODULE_STACK.append(self.module) + + +# Adapted from https://github.com/microsoft/DeepSpeed/blob/e8318634b4313eaad89842cf4322e1762d34ced3/deepspeed/runtime/zero/linear.py#L104 +# In the original logic, if bias is None, after export to ONNX, None becomes a constant, so backward op complains +# output count more than needed. +def _zero3_linear_wrap_ort_compatible(input, weight, bias=None): + from deepspeed.runtime.zero.linear import LinearFunctionForZeroStage3 + + return LinearFunctionForZeroStage3.apply(input, weight, bias) + + +class _ZeROOffloadOneTimeInitializer: + """Store the hook functions from DeepSpeed ZeRO offload. + + Hook functions code collected from DeepSpeed. + """ + + def __init__(self): + self._code_store: OrderedDict[str, CodeType] = {} + + def collect_code(self, function: Callable): + """Collect the function `CodeType`, which is the code object of the function.""" + code_obj = function.__code__ + for c in code_obj.co_consts: + if inspect.iscode(c): + self._code_store[c.co_name] = c + + +_zero_offload_one_time_initializer = None + +try: + # Have to import below explicitly, otherwise it complains about _apply_to_tensors_only not found. + # The hooks reference functions or classes in that file. + from deepspeed.runtime.zero.parameter_offload import * # noqa: F403 + from deepspeed.runtime.zero.parameter_offload import DeepSpeedZeRoOffload, _apply_to_tensors_only # noqa: F401 + from deepspeed.utils import instrument_w_nvtx # noqa: F401 + + # Used to collect the hook functions's code object from DeepSpeed ZeRO offload, this should be initialized only once. + if _zero_offload_one_time_initializer is None: + _zero_offload_one_time_initializer = _ZeROOffloadOneTimeInitializer() + _zero_offload_one_time_initializer.collect_code(DeepSpeedZeRoOffload._register_hooks_recursively) + _zero_offload_one_time_initializer.collect_code(DeepSpeedZeRoOffload.setup_zero_stage3_hooks) + + # This is the function to enable ORT ZeRO offload. + def configure_ort_compatible_zero_stage3(): + """Configure ZeRO stage3 to be ORT compatible. + + This function will overwrite the original DeepSpeed ZeRO stage3 hooks to make it ORT compatible. + """ + + # Only done once no matter how many times this function is called for different modules. + DeepSpeedZeRoOffload.setup_zero_stage3_hooks = _setup_zero_stage3_ort_compatible_hooks + + from deepspeed.runtime.zero.linear import zero3_linear_wrap + + if torch.nn.functional.linear is zero3_linear_wrap: + torch.nn.functional.linear = _zero3_linear_wrap_ort_compatible + +except ImportError as e: + warnings.warn(f"DeepSpeed import error {e}") + + def configure_ort_compatible_zero_stage3(): + raise RuntimeError("DeepSpeed is not installed, cannot configure ORT compatible ZeRO stage3.") + + +def _get_params_for_current_module(module: torch.nn.Module) -> List[torch.nn.parameter.Parameter]: + """Retrieve the parameters for this module. + + Logic adapted from + https://github.com/microsoft/DeepSpeed/blob/9d79cfd1e90cae9306dc1b5837d374b2c9489ac8/deepspeed/runtime/zero/partitioned_param_coordinator.py#L267 + """ + from deepspeed.runtime.zero.partitioned_param_coordinator import iter_params + + # Retrive the parameters that are not available for this module. + partitioned_params = [param for param in iter_params(module)] + + return partitioned_params + + +def _get_all_offloaded_params(module: torch.nn.Module) -> Dict[str, torch.nn.parameter.Parameter]: + """Retrieve all the parameters that are offloaded.""" + from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus + + all_offloaed_params = OrderedDict() + for name, param in module.named_parameters(): + if hasattr(param, "ds_status") and param.ds_status == ZeroParamStatus.NOT_AVAILABLE: + all_offloaed_params[name] = param + + return all_offloaed_params + + +class ORTZeROOffloadPreForwardFunction(torch.autograd.Function): + """This function is a common bridge to call original PyTorch's + pre_forward_function and post_backward_function. + """ + + @staticmethod + def forward( + ctx, + module, + pre_forward_with_kwargs_function, + post_backward_function, + args_schema, + kwargs_schema, + args_tensor_count, + kwargs_tensor_count, + *tensor_list, + ): + """ + Args: + ctx: context object + module: the module to be called + pre_forward_with_kwargs_function: the function to be called before forward (PyTorch's pre_forward_function) + post_backward_function: the function to be called after backward (PyTorch's post_backward_function) + args_schema: the schema of the args, used to reconstruct the args in original form in + PyTorch's pre_forward_function's inputs. + kwargs_schema: the schema of the kwargs, used to reconstruct the kwargs in original form in + PyTorch's pre_forward_function's inputs. + args_tensor_count: the number of tensors in args. + kwargs_tensor_count: the number of tensors in kwargs. + tensor_list: the list of tensors, the first args_tensor_count tensors are args, the next + kwargs_tensor_count tensors are kwargs, the rest are the parameters for offload. + """ + args_tensors = tensor_list[:args_tensor_count] + kwargs_tensors = tensor_list[args_tensor_count : args_tensor_count + kwargs_tensor_count] + + args = unflatten_data_using_schema(args_tensors, args_schema) + kwargs = unflatten_data_using_schema(kwargs_tensors, kwargs_schema) + + # We will re-retrieve the parameter tensors other than use the one passed in input (of size 0 for + # those partitioned params). + # This is required for ORT run because in ORT graph, the tensor of size 0 will always be size 0 + # (this step is not necessary for PyTorch run, because PyTorch will re-use the same tensor + # while .data got updated to full-sized data after pre_forward_with_kwargs_function is called). + partitioned_params = _get_params_for_current_module(module) + ctx.partitioned_params = partitioned_params + + f_ret = pre_forward_with_kwargs_function(module, args, kwargs) + + if f_ret is None: + updated_args, updated_kwargs = args, kwargs + else: + assert isinstance(f_ret, tuple) + updated_args, updated_kwargs = f_ret + + ctx.module = module + ctx.post_backward_function = post_backward_function + + updated_args_tensors, _ = extract_data_and_schema(updated_args) + updated_kwargs_tensors, _ = extract_data_and_schema(updated_kwargs) + + rets = tuple(updated_args_tensors + updated_kwargs_tensors) + rets += tuple([p.detach().requires_grad_(p.requires_grad) for p in partitioned_params]) + + # PyTorch exporter does not support an empty list of tensors, so we have this check. + assert len(rets) != 0 + return rets + + @staticmethod + def backward(ctx, *grads): + updated_grads = grads + if ctx.post_backward_function is not None: + ret = ctx.post_backward_function(ctx.module, grads) + if ret is not None: + updated_grads = ret + + # TODO(pengwa) Update grad for partitioned parameters. + input_count = len(updated_grads) - len(ctx.partitioned_params) + zeros = [torch.zeros(0, dtype=p.dtype, device=p.device) for p in ctx.partitioned_params] + zero_grads = updated_grads[:input_count] + tuple(zeros) + + return (None, None, None, None, None, None, None, *zero_grads) + + @staticmethod + def infer_shape( + node: onnx.NodeProto, + tensor_input_shapes: List[Optional[List[Union[int, str]]]], + tensor_input_dtypes: List[torch.onnx.TensorProtoDataType], + ) -> Tuple[List[Optional[List[Union[int, str]]]], List[torch.onnx.TensorProtoDataType]]: + input_pointer_scalars_attr_name = "input_pointer_scalars" + found = [attr for attr in node.attribute if attr.name == input_pointer_scalars_attr_name] + assert len(found) == 1 + input_pointer_scalars = found[0].ints + + # Restore the nn.Module from the pointer. + module = ctypes.cast(input_pointer_scalars[0], ctypes.py_object).value + + partitioned_params = _get_params_for_current_module(module) + tensor_output_shapes = tensor_input_shapes + tensor_output_dtypes = tensor_input_dtypes + start_offset = len(tensor_input_shapes) - len(partitioned_params) + for index, param in enumerate(partitioned_params): + tensor_output_shapes[start_offset + index] = list(param.ds_shape) + tensor_output_dtypes[start_offset + index] = pytorch_dtype_to_onnx(param.dtype) + assert len(tensor_output_shapes) == len(tensor_input_shapes) + assert len(tensor_output_dtypes) == len(tensor_input_dtypes) + + return tensor_output_shapes, tensor_output_dtypes + + +class ORTZeROOffloadPostForwardFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + module, + post_forward_function, + pre_backward_function, + output_schema, + *output_tensors, + ): + """ + Args: + ctx: context object + module: the module to be called + post_forward_function: the function to be called after forward (PyTorch's post_forward_function) + pre_backward_function: the function to be called before backward (PyTorch's pre_backward_function) + output_schema: the schema of the output, used to reconstruct the output in original form in + PyTorch's post_forward_function's inputs. + output_tensors: the list of tensors. + + """ + outputs = unflatten_data_using_schema(output_tensors, output_schema) + + # STAGE3WARN: _post_forward_module_hook's second argument `input is not used, so we just pass a None here. + updated_outputs = post_forward_function(module, None, outputs) + + if updated_outputs is None: + updated_output_tensors = output_tensors + else: + updated_output_tensors, _ = extract_data_and_schema(updated_outputs) + + ctx.module = module + ctx.pre_backward_function = pre_backward_function + rets = [o.detach().requires_grad_(o.requires_grad) for o in updated_output_tensors] + return tuple(rets) + + @staticmethod + def backward(ctx, *grads): + updated_args = grads + if ctx.pre_backward_function is not None: + ret = ctx.pre_backward_function(ctx.module, grads) + if ret is not None: + updated_args = ret + return (None, None, None, None, *updated_args) + + @staticmethod + def infer_shape( + node: onnx.NodeProto, + tensor_input_shapes: List[Optional[List[Union[int, str]]]], + tensor_input_dtypes: List[torch.onnx.TensorProtoDataType], + ) -> Tuple[List[Optional[List[Union[int, str]]]], List[torch.onnx.TensorProtoDataType]]: + return tensor_input_shapes, tensor_input_dtypes + + +class _ZeROOffloadFunctions: + def __init__(self, one_time_init: _ZeROOffloadOneTimeInitializer, offloader) -> None: + self._function_store: OrderedDict[str, FunctionType] = {} + self._one_time_init = one_time_init + for name, code in self._one_time_init._code_store.items(): + cell = self._create_closure_for_ds_hook_function(offloader) + self._function_store[name] = FunctionType(code, globals(), code.co_name, None, (cell,)) + + def get(self, name: str) -> FunctionType: + return self._function_store[name] + + def _create_closure_for_ds_hook_function(self, offloader): + # https://stackoverflow.com/questions/17395338/how-to-access-a-function-inside-a-function + def make_closure_cell(_self): + def nested(): + return _self + + return nested.__closure__[0] + + cell = make_closure_cell(offloader) + return cell + + +class ZeROOffloadSubscriber(SubscriberBase): + """This subscriber is used to enable ZeRO Offload feature in a way compatible with ORTModule.""" + + def __init__(self, offloader, one_time_init: _ZeROOffloadOneTimeInitializer, enable_debug_info: bool = False): + super().__init__(None, None) + self._offloader = offloader + self._functions = _ZeROOffloadFunctions(one_time_init, self._offloader) + self._enable_debug_info = enable_debug_info + + def pre_forward_module_apply_impl( + self, + run_rtx: RuntimeStates, + module: torch.nn.Module, + args: ORTModelInputOutputType, + kwargs: ORTModelInputOutputType, + ) -> Tuple[ORTModelInputOutputType, ORTModelInputOutputType]: + """This function is a dispatcher to call DeepSpeed stage3 pre forward hooks in sequence. + + All hook functions can be retrieved from the function store, due to exporter only supports a list of tensors as + input and output for torch.autograd.Function, so we do flatten and unflatten here. + + """ + + args_tensors, args_schema = extract_data_and_schema(args) + kwargs_tensors, kwargs_schema = extract_data_and_schema(kwargs) + + partitioned_params = _get_params_for_current_module(module) + + _pre_forward_module_hook = self._functions.get("_pre_forward_module_hook") + + args_tensor_count = len(args_tensors) + kwargs_tensor_count = len(kwargs_tensors) + + def _wrap_pre_forward_module_hook(module, args, kwargs): + rets = _pre_forward_module_hook(module, args) + updated_args, updated_kwargs = args, kwargs + if rets is not None: + updated_args = rets + + # STAGE3WARN: Moved from _post_backward_module_hook to make sure ORT run will trigger every iteration. + module.ds_grads_remaining = 0 + return updated_args, updated_kwargs + + all_tensors = args_tensors + kwargs_tensors + partitioned_params + + self._check_all_tensor(all_tensors, module, "pre_forward_module_apply_impl input check") + + rets = ORTZeROOffloadPreForwardFunction.apply( + module, + _wrap_pre_forward_module_hook, + None, + args_schema, + kwargs_schema, + args_tensor_count, + kwargs_tensor_count, + *all_tensors, + ) + + self._check_all_tensor(rets, module, "pre_forward_module_apply_impl output check") + + updated_args_tensors = rets[:args_tensor_count] + updated_kwargs_tensors = rets[args_tensor_count : args_tensor_count + kwargs_tensor_count] + + updated_args = unflatten_data_using_schema(updated_args_tensors, args_schema) + updated_kwargs = unflatten_data_using_schema(updated_kwargs_tensors, kwargs_schema) + + _post_backward_module_hook = self._functions.get("_post_backward_module_hook") + # STAGE3WARN: Other part of _post_backward_module_hook can be traced correctly so we don't need to + # wrap with PythonOp. + updated_args = _post_backward_module_hook(module, updated_args) + + return updated_args, updated_kwargs + + def post_forward_module_apply_impl( + self, + run_rtx: RuntimeStates, + module: torch.nn.Module, + args: ORTModelInputOutputType, + outputs: ORTModelInputOutputType, + ) -> Tuple[ORTModelInputOutputType, ORTModelInputOutputType]: + """This function is a dispatcher to call DeepSpeed stage3 post forward hooks in sequence. + + All hook functions can be retrieved from function store, due to exporter only supports a list of tensors as + input and output for torch.autograd.Function, so we do flatten and unflatten here. + + """ + + outputs_tensors, outputs_schema = extract_data_and_schema(outputs) + + _post_forward_module_hook = self._functions.get("_post_forward_module_hook") + + def _wrap_post_forward_module_hook(module, input, outputs): + # STAGE3WARN: _post_forward_module_hook applied this for each tensor output, so we do a simple wrap here. + from deepspeed.runtime.zero.partition_parameters import is_zero_param + + updated_outputs = _post_forward_module_hook(module, input, outputs) + if updated_outputs: + for updated_output in updated_outputs: + # restore zero param attributes if those get stripped by `backward_function` + if not is_zero_param(updated_output) and is_zero_param(outputs): + updated_output.ds_param_alias = outputs + return updated_outputs + else: + return outputs + + self._check_all_tensor(outputs_tensors, module, "post_forward_module_apply_impl input check") + + updated_outputs_tensors = ORTZeROOffloadPostForwardFunction.apply( + module, _wrap_post_forward_module_hook, None, outputs_schema, *outputs_tensors + ) + + self._check_all_tensor(updated_outputs_tensors, module, "post_forward_module_apply_impl output check") + + assert len(updated_outputs_tensors) == len(outputs_tensors) + + # WARN: we assume updated_output_tensors can REUSE the outputs_schema. + updated_outputs = unflatten_data_using_schema(updated_outputs_tensors, outputs_schema) + + _pre_backward_module_hook = self._functions.get("_pre_backward_module_hook") + # STAGE3WARN: _pre_backward_module_hook's second argument `input is not used, so we just pass a None here. + # STAGE3WARN: part of the original _pre_backward_module_hook can be traced correctly so we moved them into + # _wrap_post_forward_module_hook above. + updated_outputs = _pre_backward_module_hook(module, None, updated_outputs) + + return args, updated_outputs + + def post_forward_outmost_module_apply_impl( + self, + run_rtx: RuntimeStates, + module: torch.nn.Module, + args: ORTModelInputOutputType, + outputs: ORTModelInputOutputType, + ) -> Tuple[ORTModelInputOutputType, ORTModelInputOutputType]: + outputs_tensors, outputs_schema = extract_data_and_schema(outputs) + + _end_of_forward_hook = self._functions.get("_end_of_forward_hook") + self._check_all_tensor(outputs_tensors, module, "post_forward_outmost_module_apply_impl input check") + + updated_outputs_tensors = ORTZeROOffloadPostForwardFunction.apply( + module, _end_of_forward_hook, None, outputs_schema, *outputs_tensors + ) + + self._check_all_tensor(updated_outputs_tensors, module, "post_forward_outmost_module_apply_impl output check") + + assert len(updated_outputs_tensors) == len(outputs_tensors) + updated_outputs = unflatten_data_using_schema(updated_outputs_tensors, outputs_schema) + return args, updated_outputs + + def _check_all_tensor(self, tensor_list: Tuple[torch.Tensor], module: torch.nn.Module, name: str): + if not self._enable_debug_info: + return + + for t in tensor_list: + if not isinstance(t, torch.Tensor): + raise RuntimeError(f"{name} fail: {module.__class__.__name__}, input type: {type(t)}") diff --git a/orttraining/orttraining/python/training/utils/torch_type_map.py b/orttraining/orttraining/python/training/utils/torch_type_map.py new file mode 100644 index 0000000000..699747723f --- /dev/null +++ b/orttraining/orttraining/python/training/utils/torch_type_map.py @@ -0,0 +1,47 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + + +from typing import Union + +import torch + +# Mapping from pytorch scalar type to onnx scalar type. +_CAST_PYTORCH_TO_ONNX = { + "Byte": [torch.onnx.TensorProtoDataType.UINT8, torch.uint8], + "Char": [torch.onnx.TensorProtoDataType.INT8, torch.int8], + "Double": [torch.onnx.TensorProtoDataType.DOUBLE, torch.double], + "Float": [torch.onnx.TensorProtoDataType.FLOAT, torch.float], + "Half": [torch.onnx.TensorProtoDataType.FLOAT16, torch.half], + "Int": [torch.onnx.TensorProtoDataType.INT32, torch.int], + "Long": [torch.onnx.TensorProtoDataType.INT64, torch.int64], + "Short": [torch.onnx.TensorProtoDataType.INT16, torch.short], + "Bool": [torch.onnx.TensorProtoDataType.BOOL, torch.bool], + "ComplexFloat": [torch.onnx.TensorProtoDataType.COMPLEX64, torch.complex64], + "ComplexDouble": [torch.onnx.TensorProtoDataType.COMPLEX128, torch.complex128], + "BFloat16": [torch.onnx.TensorProtoDataType.BFLOAT16, torch.bfloat16], + # Not yet defined in torch. + # "Float8E4M3FN": torch.onnx.TensorProtoDataType.FLOAT8E4M3FN, + # "Float8E4M3FNUZ": torch.onnx.TensorProtoDataType.FLOAT8E4M3FNUZ, + # "Float8E5M2": torch.onnx.TensorProtoDataType.FLOAT8E5M2, + # "Float8E5M2FNUZ": torch.onnx.TensorProtoDataType.FLOAT8E5M2FNUZ, + "Undefined": [torch.onnx.TensorProtoDataType.UNDEFINED, None], +} + + +_DTYPE_TO_ONNX = {torch_dtype: onnx_dtype for k, (onnx_dtype, torch_dtype) in _CAST_PYTORCH_TO_ONNX.items()} + + +def pytorch_dtype_to_onnx(dtype_or_scalar_type: Union[torch.dtype, str]) -> torch.onnx.TensorProtoDataType: + """Converts a pytorch dtype or scalar type string to an onnx dtype.""" + dtype = dtype_or_scalar_type + if isinstance(dtype, str): + if dtype not in _CAST_PYTORCH_TO_ONNX: + raise RuntimeError(f"Unsupported dtype {dtype}") + return _CAST_PYTORCH_TO_ONNX[dtype][0] + + if dtype not in _DTYPE_TO_ONNX: + raise RuntimeError(f"Unsupported dtype {dtype}") + return _DTYPE_TO_ONNX[dtype] diff --git a/orttraining/orttraining/test/python/orttraining_test_hooks.py b/orttraining/orttraining/test/python/orttraining_test_hooks.py index 4fb416e640..a58b3919c5 100644 --- a/orttraining/orttraining/test/python/orttraining_test_hooks.py +++ b/orttraining/orttraining/test/python/orttraining_test_hooks.py @@ -8,7 +8,7 @@ import pytest import torch from onnxruntime.training.ortmodule import ORTModule -from onnxruntime.training.utils.hooks import GlobalSubscriberManager, StatisticsSubscriber, _InspectActivation +from onnxruntime.training.utils.hooks import GlobalSubscriberManager, StatisticsSubscriber, inspect_activation class NeuralNetSingleOutput(torch.nn.Module): @@ -147,9 +147,9 @@ class NeuralNetUserAnnotateIntermediateTensor(torch.nn.Module): def forward(self, input1, input2): model_input = input1 + input2 out = self.fc1(model_input) - out = _InspectActivation.apply("fc1_out", None, GlobalSubscriberManager.get_run_context(), out) + out = inspect_activation("fc1_out", out) out = self.relu(out) - out = _InspectActivation.apply("relu_out", None, GlobalSubscriberManager.get_run_context(), out) + out = inspect_activation("relu_out", out) out = self.fc2(out) return out