Model post process for zero stage3 training (#17187)

### Model post process for zero stage3 training

This is the last change needed to make single-GPU and multi-GPU runs pass.

Design details:
https://microsoft.sharepoint.com/:p:/t/ONNX2/EfNfJ43necpIoPI6x5M2zvYBVbfjoPQmG4Boc_F7-tHm1w?e=ekQwA6&nav=eyJzSWQiOjMxNiwiY0lkIjoxMDE1Nzg3NDZ9

`PyTorch` runs with ZeROOffloadSubscriber:

```
  model = prepare_model(...)
  from onnxruntime.training.utils.hooks import configure_ort_compatible_zero_stage3
  configure_ort_compatible_zero_stage3()
```

`ORTModule` runs with ZeROOffloadSubscriber:

```
  os.environ['ORTMODULE_ENABLE_ZERO_STAGE3'] = '1'
  from onnxruntime.training.ortmodule import ORTModule
  model = ORTModule(self.model)
```

It will be fairly easy to debug convergence issues if both ORT and
PyTorch can run the same offload path.

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
pengwa 2023-09-22 08:54:25 +08:00 committed by GitHub
parent 498b60d8a4
commit 6b7bce5ec9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 618 additions and 170 deletions

View file

@ -28,7 +28,8 @@ class PythonOpShapeInferStore:
@classmethod
def register(cls, kclass: torch.autograd.Function) -> None:
"""Register a shape inference function for a torch.autograd.Function if there is staticmethod "infer_shape" defined.
"""Register a shape inference function for a torch.autograd.Function if there is staticmethod
"infer_shape" defined.
The signature of the shape inference function should be:
@staticmethod
@ -51,6 +52,11 @@ class PythonOpShapeInferStore:
if hasattr(kclass, "infer_shape") and kclass_name not in cls._CLASS_MAP:
cls._CLASS_MAP[kclass_name] = kclass.infer_shape
@classmethod
def register_func(cls, name: str, func: Callable) -> None:
    """Store `func` as the shape inference function looked up under `name`.

    Args:
        name: Key under which the function is kept in `_CLASS_MAP`.
        func: The shape inference callable to store.

    An entry already stored under the same name is replaced.
    """
    cls._CLASS_MAP[name] = func
@classmethod
def get_shape_infer(cls, name: str) -> Optional[Callable]:
    """Return the shape inference function stored under `name`, or None when there is none."""
    return cls._CLASS_MAP.get(name)
@ -228,9 +234,9 @@ def _export_pt_1_10(g, n, *args, **kwargs):
input_float_tuples.extend(list(arg))
continue
is_inspect_activation = (
func_full_qual_name == "onnxruntime.training.utils.hooks._subscriber_manager._InspectActivation"
)
from onnxruntime.training.utils.hooks._statistics_subscriber import _InspectActivation
is_inspect_activation = func_full_qual_name == get_fully_qualified_class_name(_InspectActivation)
if is_inspect_activation and isinstance(arg, str):
# _InspectActivation is a special case where the first argument is a string
# that is used to determine the activation name to be inspected.
@ -307,14 +313,7 @@ def _export_pt_1_10(g, n, *args, **kwargs):
_export = wrap_custom_export_function(_export_pt_1_10)
def _post_process_after_export(exported_model: ModelProto, enable_custom_autograd_function: bool) -> ModelProto:
"""Post process the exported model."""
if enable_custom_autograd_function:
exported_model = _post_process_enabling_autograd_function(exported_model)
return exported_model
def _post_process_enabling_autograd_function(exported_model: ModelProto) -> ModelProto:
def post_process_enabling_autograd_function(exported_model: ModelProto) -> ModelProto:
# Loop all PythonOp, append "_ctx" as the first output.
index = 0
for node in exported_model.graph.node:
@ -330,8 +329,7 @@ def _post_process_enabling_autograd_function(exported_model: ModelProto) -> Mode
op_name_prefix = kclass_name
break
if not node.name:
node.name = f"{op_name_prefix}_id_{index}"
index += 1
node.name = f"{op_name_prefix}_id_{index}"
index += 1
return exported_model

View file

@ -376,6 +376,16 @@ def call_python_backward_function(
result = backward_function(*wrapped_args)
# Extract results as DLPack tensor list.
if isinstance(result, torch.Tensor):
result = [result]
elif isinstance(result, (tuple, list)):
result = list(result)
else:
raise wrap_exception(
ORTModuleIOError,
TypeError(f"ORTModule does not support the following model output type {type(result)}."),
)
wrapped_returned_args = wrap_all_outputs(result)
torch_interop_utils.unregister_grad_fn(id(ctx))

View file

@ -19,11 +19,10 @@ from torch.utils.cpp_extension import ROCM_HOME
import onnxruntime
from onnxruntime.capi import _pybind_state as C
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference
from onnxruntime.training.utils import ORTModelInputOutputSchemaType
from onnxruntime.training.utils import ORTModelInputOutputSchemaType, onnx_dtype_to_pytorch
from onnxruntime.training.utils.hooks import configure_ort_compatible_zero_stage3
from . import _are_deterministic_algorithms_enabled, _io, _logger, _onnx_models, _utils
from ._custom_autograd_function_exporter import _post_process_after_export
from ._fallback import (
ORTModuleDeviceException,
ORTModuleONNXModelException,
@ -141,9 +140,14 @@ class GraphExecutionManager(GraphExecutionInterface):
register_triton_op_executor()
self._zero_stage3_param_map = {}
if self._runtime_options.enable_zero_stage3_support:
# Cannot toggle feature enabling/disabling after the first time enabled.
configure_ort_compatible_zero_stage3()
from onnxruntime.training.utils.hooks._zero_offload_subscriber import _get_all_zero_stage3_params
self._zero_stage3_param_map = _get_all_zero_stage3_params(self._flattened_module)
configure_ort_compatible_zero_stage3(debug=False, stats_output_dir="ort_output", stats_overwrite=True)
def _get_torch_gpu_allocator_function_addresses(self):
if self._runtime_options.use_external_gpu_allocator and torch.cuda.is_available():
@ -345,7 +349,8 @@ class GraphExecutionManager(GraphExecutionInterface):
)
if os.path.exists(cache_dir) and os.path.isfile(filename):
self._logger.info(
f"Cached model detected! Cached model will be used to save export and initialization time. If you want the model to be re-exported then DELETE {filename}."
f"Cached model detected! Cached model will be used to save export and initialization time."
f"If you want the model to be re-exported then DELETE {filename}."
)
exported_model = onnx.load(filename)
return exported_model
@ -409,9 +414,24 @@ class GraphExecutionManager(GraphExecutionInterface):
)
exported_model = onnx.load_model_from_string(f.getvalue())
exported_model = _post_process_after_export(
exported_model, self._runtime_options.enable_custom_autograd_function
)
if self._runtime_options.enable_custom_autograd_function:
from ._custom_autograd_function_exporter import post_process_enabling_autograd_function
exported_model = post_process_enabling_autograd_function(exported_model)
if self._runtime_options.enable_zero_stage3_support:
from ._zero_stage3_compatibility import post_processing_enable_zero_stage3_compat
exported_model = post_processing_enable_zero_stage3_compat(
exported_model,
self._zero_stage3_param_map,
[name for name, _ in self._flattened_module.named_parameters()],
)
# Cannot append pull weight trigger name to input names as following, otherwise, the later check (
# https://github.com/microsoft/onnxruntime/blob/068300d97eb25e5b52324e7af54a45ed1fa6a4c3/orttraining/orttraining/python/training/ortmodule/_training_manager.py#L466C18-L466C18)
# find input info mismatch, will re-initialize the graph builder.
# self._input_info.require_grad_names.append(STAGE3_PULL_WEIGHT_TRIGGER_NAME)
# Cache model for future runs
if cache_dir:
@ -477,7 +497,14 @@ class GraphExecutionManager(GraphExecutionInterface):
grad_builder_config = C.OrtModuleGraphBuilderConfiguration()
grad_builder_config.initializer_names = initializer_names
grad_builder_config.initializer_names_to_train = initializer_names_to_train
grad_builder_config.input_names_require_grad = self._input_info.require_grad_names
input_names_require_grad = self._input_info.require_grad_names
if self._runtime_options.enable_zero_stage3_support:
from ._zero_stage3_compatibility import STAGE3_PULL_WEIGHT_TRIGGER_NAME
# Add stage3 pull weight trigger name to require_grad_names, so that it will be included in the gradient graph.
input_names_require_grad.append(STAGE3_PULL_WEIGHT_TRIGGER_NAME)
grad_builder_config.input_names_require_grad = input_names_require_grad
grad_builder_config.build_gradient_graph = self._export_mode == torch.onnx.TrainingMode.TRAINING
grad_builder_config.enable_caching = self._runtime_options.enable_grad_acc_optimization
grad_builder_config.loglevel = _logger.ortmodule_loglevel_to_onnxruntime_c_loglevel(
@ -553,6 +580,9 @@ class GraphExecutionManager(GraphExecutionInterface):
inputs, kwargs
)
if self._runtime_options.enable_zero_stage3_support:
self._append_pull_weight_trigger_as_input(kwargs, detected_device)
_, embed_sparsity_results, label_sparsity_results = _io._combine_input_buffers_initializers(
self._graph_initializers,
self._graph_builder.get_graph_info().user_input_names,
@ -562,6 +592,7 @@ class GraphExecutionManager(GraphExecutionInterface):
kwargs,
detected_device,
self._runtime_inspector,
self._zero_stage3_param_map,
)
# Enable sparsity-based optimization when applicable.
@ -587,6 +618,21 @@ class GraphExecutionManager(GraphExecutionInterface):
if self._runtime_options.print_memory_stat:
self._runtime_inspector.enable_memory_inspector(self._original_module)
def _append_pull_weight_trigger_as_input(self, kwargs: Dict, device: torch.device):
    """Add the ZeRO stage3 pull-weight trigger tensor to `kwargs`.

    A zero-filled tensor with the trigger's shape and dtype is created on `device`,
    marked as requiring grad, and stored under STAGE3_PULL_WEIGHT_TRIGGER_NAME.

    Args:
        kwargs: Keyword inputs of the current call; updated in place.
        device: Device on which the trigger tensor is allocated.

    Returns:
        The same `kwargs` object, now containing the trigger entry.
    """
    from ._zero_stage3_compatibility import (
        STAGE3_PULL_WEIGHT_TRIGGER_NAME,
        STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE,
        STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE,
    )

    trigger_dtype = onnx_dtype_to_pytorch(STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE)
    trigger = torch.zeros(STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE, dtype=trigger_dtype, device=device)
    # requires_grad_() flips the flag in place and returns the same tensor.
    kwargs[STAGE3_PULL_WEIGHT_TRIGGER_NAME] = trigger.requires_grad_()
    return kwargs
def _log_feature_stats(self):
if get_rank() != 0:
return

View file

@ -159,6 +159,9 @@ class InferenceManager(GraphExecutionManager):
# Assert that the input and model device match
_utils._check_same_device(self._device, "Input argument to forward", *inputs)
if self._runtime_options.enable_zero_stage3_support:
self._append_pull_weight_trigger_as_input(kwargs, self._device)
prepared_input_list, _, _ = _io._combine_input_buffers_initializers(
self._graph_initializers,
self._graph_info.user_input_names,
@ -168,6 +171,7 @@ class InferenceManager(GraphExecutionManager):
kwargs,
self._device,
self._runtime_inspector,
self._zero_stage3_param_map,
)
user_outputs, _ = InferenceManager.execution_session_run_forward(

View file

@ -168,6 +168,7 @@ def _combine_input_buffers_initializers(
kwargs: Mapping[str, ORTModelInputOutputType],
device: torch.device,
rt_inspector: RuntimeInspector,
zero_stage3_offload_param_map: Optional[Dict[str, torch.nn.parameter.Parameter]],
):
"""Creates forward `*inputs` list from user input and PyTorch initializers
@ -254,7 +255,12 @@ def _combine_input_buffers_initializers(
)
# params is a list of all initializers known to the onnx graph
result.extend(params)
if zero_stage3_offload_param_map:
for p in params:
if p not in zero_stage3_offload_param_map.values():
result.append(p)
else:
result.extend(params)
return result, embed_sparsity_results, label_sparsity_results

View file

@ -311,6 +311,9 @@ class TrainingManager(GraphExecutionManager):
self._gradient_accumulation_manager.maybe_update_cache_before_run()
if self._runtime_options.enable_zero_stage3_support:
self._append_pull_weight_trigger_as_input(kwargs, self._device)
prepared_input_list, _, _ = _io._combine_input_buffers_initializers(
self._graph_initializers,
self._graph_info.user_input_names,
@ -320,6 +323,7 @@ class TrainingManager(GraphExecutionManager):
kwargs,
self._device,
self._runtime_inspector,
self._zero_stage3_param_map,
)
outputs = unflatten_user_output(

View file

@ -0,0 +1,312 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from typing import Dict, List, Optional, Tuple, Union
import torch
from onnx import ModelProto, NodeProto, TensorProto, ValueInfoProto, helper
from onnxruntime.capi._pybind_state import register_torch_autograd_function
from onnxruntime.training.utils import pytorch_dtype_to_onnx
from ._custom_autograd_function_exporter import PythonOpShapeInferStore
from ._utils import get_fully_qualified_class_name
# Name of the extra graph input added for ZeRO stage3 support; it is the single input of the
# weight retrieval PythonOp and also the name of that node.
STAGE3_PULL_WEIGHT_TRIGGER_NAME = "pull_weight_trigger"
# ONNX element type declared for the trigger input and for every output of the weight retrieval PythonOp.
STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE = TensorProto.FLOAT
# Shape declared for the trigger input; its length is the rank recorded in the PythonOp rank attributes.
STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE = [1]
def post_processing_enable_zero_stage3_compat(
    exported_model: ModelProto,
    zero_stage3_named_params: Dict[str, torch.nn.parameter.Parameter],
    all_param_names: List[str],
) -> ModelProto:
    """Rewire the exported graph so ZeRO stage3 parameters are no longer graph inputs.

    The graph is modified in place:
      * a weight retrieval PythonOp is inserted as the first node; it takes the
        `pull_weight_trigger` graph input and emits one `pull_<param>` output per
        stage3 parameter;
      * for each stage3 parameter that is a consumed graph input, the
        ORTZeROOffloadPreForwardFunction PythonOp consuming it is switched to read
        `pull_<param>` instead, and its output at the matching position (counted from
        the end) is renamed to the parameter name, so the other consumers read that
        output;
      * the stage3 parameters are removed from the graph inputs and the trigger input is
        inserted in front of the first remaining parameter input.

    Args:
        exported_model (ModelProto): The exported model.
        zero_stage3_named_params (Dict[str, torch.nn.parameter.Parameter]): The offload named parameters.
        all_param_names (List[str]): All parameter names.

    Returns:
        The same `exported_model` instance after modification.

    Raises:
        RuntimeError: If a consumed stage3 parameter input has no
            ORTZeROOffloadPreForwardFunction PythonOp among its consumers.
    """
    # Register symbolic shape inference functions for PythonOp used in DeepSpeed ZeRO stage3.
    _register_symbolic_shape_infer_functions()

    # Create weight retrieving function using zero_stage3_named_params.
    func_full_qual_name = _create_weight_retrieval_function(zero_stage3_named_params)

    graph = exported_model.graph

    # Map every tensor name to the distinct nodes that read it.
    consumer_map = {}
    for node in graph.node:
        for tensor_name in node.input:
            users = consumer_map.setdefault(tensor_name, [])
            if node not in users:
                users.append(node)

    def _get_param_pull_trigger_name(param_name: str) -> str:
        return f"pull_{param_name}"

    def _get_func_name(node: NodeProto) -> Optional[str]:
        for attr in node.attribute:
            if attr.name != "func_name":
                continue
            raw_name = attr.s
            return raw_name.decode("utf-8") if isinstance(raw_name, bytes) else raw_name
        return None

    # Create weight retrieving PythonOp.
    new_input, weight_pull_node = _create_weight_retrieval_pythonop(
        zero_stage3_named_params,
        func_full_qual_name,
        STAGE3_PULL_WEIGHT_TRIGGER_NAME,
        [_get_param_pull_trigger_name(pname) for pname in zero_stage3_named_params],
        STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE,
        STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE,
    )

    from onnxruntime.training.utils.hooks._zero_offload_subscriber import ORTZeROOffloadPreForwardFunction

    pre_forward_function_name = get_fully_qualified_class_name(ORTZeROOffloadPreForwardFunction)

    # Connect weight consumers to use the full-sized parameter output of ORTZeROOffloadPreForwardFunction.
    for graph_input in graph.input:
        param_name = graph_input.name
        if param_name not in zero_stage3_named_params or param_name not in consumer_map:
            continue

        consumers = consumer_map[param_name]

        # Exactly one pre-forward PythonOp is expected among the consumers.
        pre_forward_pythonop_node = None
        for consumer in consumers:
            if consumer.op_type != "PythonOp":
                continue
            if _get_func_name(consumer) != pre_forward_function_name:
                continue
            assert (
                pre_forward_pythonop_node is None
            ), "Multiple ORTZeROOffloadPreForwardFunction nodes found, it should not happen"
            pre_forward_pythonop_node = consumer

        if pre_forward_pythonop_node is None:
            raise RuntimeError(
                "Fail to find ORTZeROOffloadPreForwardFunction for partitioned param: " + graph_input.name
            )

        index_offset_on_python_op_input = [
            position
            for position, input_name in enumerate(pre_forward_pythonop_node.input)
            if input_name == param_name
        ]
        assert (
            len(index_offset_on_python_op_input) == 1
        ), f"index_offset_on_python_op_input length is not 1: {index_offset_on_python_op_input}"

        input_position = index_offset_on_python_op_input[0]
        # Negative offset of the parameter measured from the end of the input list; the same
        # offset from the end of the output list selects the output to rename below.
        reverse_index_among_inputs = input_position - len(pre_forward_pythonop_node.input)

        new_input_name = _get_param_pull_trigger_name(param_name)
        pre_forward_pythonop_node.input[input_position] = new_input_name

        _update_python_op_input_related_attributes(
            pre_forward_pythonop_node,
            new_input_name,
            len(STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE),  # new rank
            STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE,  # new data type
        )

        output_index = reverse_index_among_inputs + len(pre_forward_pythonop_node.output)
        pre_forward_pythonop_node.output[output_index] = param_name

        # If the consumer of original `graph_input.name` is PythonOp, we need also update its attributes because now
        # `graph_input.name` as output of pre_forward_pythonop_node, is full-sized parameter, the rank might differ
        # from the original one.
        for consumer in consumers:
            if consumer == pre_forward_pythonop_node or consumer.op_type != "PythonOp":
                continue

            _update_python_op_input_related_attributes(
                consumer,
                param_name,
                len(zero_stage3_named_params[param_name].ds_shape),  # new rank
                pytorch_dtype_to_onnx(zero_stage3_named_params[param_name].dtype),  # new data type
            )

    # Delete exported_model.graph.input
    stale_inputs = [candidate for candidate in graph.input if candidate.name in zero_stage3_named_params]
    for stale_input in stale_inputs:
        graph.input.remove(stale_input)

    # Re-order graph input to make sure the weight pull trigger is before all parameter inputs.
    offset = next(
        (position for position, candidate in enumerate(graph.input) if candidate.name in all_param_names),
        len(graph.input),
    )

    graph.input.insert(offset, new_input)
    graph.node.insert(0, weight_pull_node)

    return exported_model
def _create_weight_retrieval_function(
    zero_stage3_named_params: Optional[Dict[str, torch.nn.parameter.Parameter]]
) -> str:
    """Define and register the autograd function backing the weight retrieval PythonOp.

    The function class is created per call and closes over `zero_stage3_named_params`.
    It is registered through `register_torch_autograd_function` and its `infer_shape` is
    registered with `PythonOpShapeInferStore`.

    Args:
        zero_stage3_named_params: The offload named parameters; only their count (and, in
            forward, the list of values saved on the context) is used.

    Returns:
        The fully qualified class name under which the function was registered.
    """

    class WeightRetrievalFunction(torch.autograd.Function):
        """Emits one zero tensor shaped like the trigger input per stage3 parameter."""

        @staticmethod
        def forward(ctx, weight_in_trigger):
            ctx.params = list(zero_stage3_named_params.values())
            ctx.dtype = weight_in_trigger.dtype
            ctx.device = weight_in_trigger.device
            ctx.shape = weight_in_trigger.shape
            placeholder = torch.zeros(ctx.shape, device=ctx.device, dtype=ctx.dtype)
            # Every output slot refers to the same placeholder tensor.
            return (placeholder,) * len(ctx.params)

        @staticmethod
        def backward(ctx, *grad_outputs):
            # The trigger's gradient is always zeros of the trigger's own shape, dtype and device.
            return torch.zeros(ctx.shape, device=ctx.device, dtype=ctx.dtype)

        @staticmethod
        def infer_shape(
            node: NodeProto,
            tensor_input_shapes: List[Optional[List[Union[int, str]]]],
            tensor_input_dtypes: List[torch.onnx.TensorProtoDataType],
        ) -> Tuple[List[Optional[List[Union[int, str]]]], List[torch.onnx.TensorProtoDataType]]:
            # Each output repeats the shape and dtype of the single trigger input.
            param_count = len(zero_stage3_named_params)
            trigger_shape = tensor_input_shapes[0]
            trigger_dtype = tensor_input_dtypes[0]
            return [trigger_shape] * param_count, [trigger_dtype] * param_count

    func_full_qual_name = get_fully_qualified_class_name(WeightRetrievalFunction)
    register_torch_autograd_function(func_full_qual_name, WeightRetrievalFunction)
    PythonOpShapeInferStore.register(WeightRetrievalFunction)

    return func_full_qual_name
def _register_symbolic_shape_infer_functions():
    """Register symbolic shape inference functions for PythonOp used in DeepSpeed ZeRO stage3.

    Three entries are stored in `PythonOpShapeInferStore` by fully qualified function name:
      * `PreBackwardFunction` and `PostBackwardFunction`: outputs mirror the inputs.
      * `LinearFunctionForZeroStage3`: the output has the input's shape with the last
        dimension replaced by the weight's second-to-last dimension.
    """

    def _simple_pass_through_infer_shape(
        node: NodeProto,
        tensor_input_shapes: List[Optional[List[Union[int, str]]]],
        tensor_input_dtypes: List[torch.onnx.TensorProtoDataType],
    ) -> Tuple[List[Optional[List[Union[int, str]]]], List[torch.onnx.TensorProtoDataType]]:
        return tensor_input_shapes, tensor_input_dtypes

    PythonOpShapeInferStore.register_func(
        "deepspeed.runtime.zero.parameter_offload.PreBackwardFunction", _simple_pass_through_infer_shape
    )
    PythonOpShapeInferStore.register_func(
        "deepspeed.runtime.zero.parameter_offload.PostBackwardFunction", _simple_pass_through_infer_shape
    )

    def _linear_infer_shape(
        node: NodeProto,
        tensor_input_shapes: List[Optional[List[Union[int, str]]]],
        tensor_input_dtypes: List[torch.onnx.TensorProtoDataType],
    ) -> Tuple[List[Optional[List[Union[int, str]]]], List[torch.onnx.TensorProtoDataType]]:
        # output = input.matmul(weight.t())
        input_shape = tensor_input_shapes[0]
        weight_shape = tensor_input_shapes[1]
        # Work on a copy: assigning into the input's own shape list would silently change
        # the shape the caller holds for the input tensor.
        output_shape = list(input_shape)
        output_shape[-1] = weight_shape[-2]
        return [output_shape], [tensor_input_dtypes[0]]

    PythonOpShapeInferStore.register_func(
        "deepspeed.runtime.zero.linear.LinearFunctionForZeroStage3", _linear_infer_shape
    )
def _create_weight_retrieval_pythonop(
    zero_stage3_named_params: Optional[Dict[str, torch.nn.parameter.Parameter]],
    func_full_qual_name: str,
    input_name: str,
    output_names: List[str],
    STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE,
    STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE: List[int],
) -> Tuple[ValueInfoProto, NodeProto]:
    """Build the trigger graph input and the PythonOp node that retrieves weights.

    Args:
        zero_stage3_named_params: The offload named parameters; only their count is used
            (None counts as zero).
        func_full_qual_name: Value stored in the node's `func_name` attribute.
        input_name: Name of the trigger value info and of the node's single input.
        output_names: Names of the node outputs that follow the leading context output.
        STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE: ONNX element type used for the trigger
            input and recorded for every output.
        STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE: Shape of the trigger input; its length is
            the rank recorded for the input and for every output.

    Returns:
        A tuple of (trigger value info, weight retrieval PythonOp node).
    """
    trigger_dtype = STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE
    trigger_shape = STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE
    trigger_rank = len(trigger_shape)

    offload_param_count = len(zero_stage3_named_params) if zero_stage3_named_params is not None else 0

    new_input = helper.make_tensor_value_info(input_name, trigger_dtype, trigger_shape)

    weight_pull_node = helper.make_node(
        "PythonOp",
        [input_name],
        ["pull_weight_trigger_ctx", *output_names],
        "pull_weight_trigger",  # node name
        "PythonOp for weight retrieving.",
        "com.microsoft",
        comment="",
        inplace=0,
        input_convention="d",
        input_tensor_ranks=[trigger_rank],
        input_tensor_types=[trigger_dtype],
        # One rank/type entry per offloaded parameter output.
        output_tensor_ranks=[trigger_rank] * offload_param_count,
        output_tensor_types=[trigger_dtype] * offload_param_count,
        training_mode=1,
        func_name=func_full_qual_name,
    )

    return new_input, weight_pull_node
def _update_python_op_input_related_attributes(node: NodeProto, input_name: str, new_rank: int, new_dtype: int):
    """Refresh a PythonOp's `input_tensor_ranks` / `input_tensor_types` for one input name.

    Every position in `node.input` whose name equals `input_name` gets `new_rank` and
    `new_dtype` written at the same index of the two attributes. Both attributes are then
    removed from the node and appended again with the updated values.

    Args:
        node (NodeProto): The PythonOp node.
        input_name (str): The input name to be updated.
        new_rank (int): The new rank of the input, to be used in input_tensor_ranks.
        new_dtype (int): The new data type of the input, to be used in input_tensor_types.
    """
    rank_attr = None
    dtype_attr = None
    for candidate in node.attribute:
        if candidate.name == "input_tensor_ranks":
            rank_attr = candidate
        elif candidate.name == "input_tensor_types":
            dtype_attr = candidate

    input_tensor_ranks = None if rank_attr is None else rank_attr.ints
    input_tensor_dtypes = None if dtype_attr is None else dtype_attr.ints

    assert input_tensor_ranks is not None, "input_tensor_ranks is None"
    assert input_tensor_dtypes is not None, "input_tensor_dtypes is None"

    for position, current_name in enumerate(node.input):
        if current_name != input_name:
            continue
        input_tensor_ranks[position] = new_rank
        input_tensor_dtypes[position] = new_dtype

    # Drop both old attributes first, then re-create them from the updated values.
    node.attribute.remove(rank_attr)
    node.attribute.remove(dtype_attr)
    node.attribute.append(helper.make_attribute("input_tensor_ranks", input_tensor_ranks))
    node.attribute.append(helper.make_attribute("input_tensor_types", input_tensor_dtypes))

View file

@ -9,7 +9,7 @@ from onnxruntime.training.utils.torch_io_helper import (
extract_data_and_schema,
unflatten_data_using_schema,
)
from onnxruntime.training.utils.torch_type_map import pytorch_dtype_to_onnx
from onnxruntime.training.utils.torch_type_map import onnx_dtype_to_pytorch, pytorch_dtype_to_onnx
__all__ = [
"PrimitiveType",
@ -18,4 +18,5 @@ __all__ = [
"extract_data_and_schema",
"unflatten_data_using_schema",
"pytorch_dtype_to_onnx",
"onnx_dtype_to_pytorch",
]

View file

@ -6,6 +6,7 @@
import os
import shutil
import warnings
from io import TextIOWrapper
from pathlib import Path
from typing import List, Optional, Tuple, Union
@ -178,87 +179,97 @@ class StatisticsSubscriber(SubscriberBase):
order_file_path = step_path / "order.txt"
tensor_file_path = step_path / output_file_name
# This is to try the best effort to align the count of numbers per line for easier comparison in diff views,
# though it does not always guarantee to do this way.
torch.set_printoptions(precision=6, linewidth=128)
tensor_shape = tensor.shape
tensor_dtype = tensor.dtype
flatten_array = tensor.flatten().view(-1)
if self._run_on_cpu:
flatten_array = flatten_array.to("cpu")
if self._run_on_cpu:
num_nan = torch.isnan(flatten_array).sum()
num_inf = torch.isinf(flatten_array).sum()
num_neg = (flatten_array < 0).sum()
num_pos = (flatten_array > 0).sum()
num_zero = (flatten_array == 0).sum()
min_value = flatten_array.min()
max_value = flatten_array.max()
mean_value = flatten_array.mean()
std_value = flatten_array.std()
else:
# Split the calculation for each bucket, then do another round of calculation on the bucket results.
# This can at the best effort reduce the peak memory impact.
bucket_size = self._bucket_size
element_count = flatten_array.numel()
ceil_bucket_count = (element_count + bucket_size - 1) // (bucket_size)
nan_buckets = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device)
inf_buckets = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device)
neg_buckets = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device)
pos_buckets = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device)
zero_buckets = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device)
min_buckets = torch.zeros(ceil_bucket_count, dtype=flatten_array.dtype, device=flatten_array.device)
max_buckets = torch.zeros(ceil_bucket_count, dtype=flatten_array.dtype, device=flatten_array.device)
mean_buckets = torch.zeros(ceil_bucket_count, dtype=flatten_array.dtype, device=flatten_array.device)
std_buckets = torch.zeros(ceil_bucket_count, dtype=flatten_array.dtype, device=flatten_array.device)
# Summary for each bucket
element_count_per_bucket = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device)
for i in range(ceil_bucket_count):
end = min((i + 1) * bucket_size, element_count)
bucket = flatten_array[i * bucket_size : end]
element_count_per_bucket[i] = bucket.numel()
nan_buckets[i] = torch.isnan(bucket).sum()
inf_buckets[i] = torch.isinf(bucket).sum()
neg_buckets[i] = (bucket < 0).sum()
pos_buckets[i] = (bucket > 0).sum()
zero_buckets[i] = (bucket == 0).sum()
min_buckets[i] = bucket.min()
max_buckets[i] = bucket.max()
mean_buckets[i] = bucket.sum()
std_buckets[i] = bucket.std()
# Reduction across all buckets
num_nan = nan_buckets.sum()
num_inf = inf_buckets.sum()
num_neg = neg_buckets.sum()
num_pos = pos_buckets.sum()
num_zero = zero_buckets.sum()
min_value = min_buckets.min()
max_value = max_buckets.max()
mean_value = float(mean_buckets.sum()) / float(element_count)
# Here we refer to
# https://math.stackexchange.com/questions/2971315/how-do-i-combine-standard-deviations-of-two-groups
# to calculate the combined standard deviation of all buckets.
s = (element_count_per_bucket - 1) * (std_buckets**2) + element_count_per_bucket * (
(mean_buckets - mean_value) ** 2
)
std_value = torch.sqrt(s.sum() / (element_count - 1))
with order_file_path.open(mode="a", encoding="utf-8") as f:
f.write(f"{output_file_name}\n")
with tensor_file_path.open(mode="w", encoding="utf-8") as f:
f.write(
f"{'>'*max(0, depth) + display_name} shape: {tensor_shape} dtype: {tensor_dtype} size: {flatten_array.size()} \n"
f"min: {min_value} max: {max_value}, mean: {mean_value}, "
f"std: {std_value} \n"
f"nan: {num_nan}, inf: {num_inf}\n"
)
f.write(f"samples(top 128): {flatten_array[:128]}\n")
f.write(f"neg: {num_neg}, pos: {num_pos}, zero: {num_zero},\n")
f.write(f"{'='*16}\n")
_summarize_tensor(display_name, tensor, f, depth, self._run_on_cpu, self._bucket_size)
def _summarize_tensor(
display_name: str,
tensor: torch.Tensor,
f: TextIOWrapper,
depth: int = 0,
run_on_cpu: bool = False,
bucket_size: int = 1024 * 1024 * 1024 // 2,
):
# This is to try the best effort to align the count of numbers per line for easier comparison in diff views,
# though it does not always guarantee to do this way.
torch.set_printoptions(precision=6, linewidth=128)
tensor_shape = tensor.shape
tensor_dtype = tensor.dtype
flatten_array = tensor.flatten().view(-1)
if run_on_cpu:
flatten_array = flatten_array.to("cpu")
if run_on_cpu:
num_nan = torch.isnan(flatten_array).sum()
num_inf = torch.isinf(flatten_array).sum()
num_neg = (flatten_array < 0).sum()
num_pos = (flatten_array > 0).sum()
num_zero = (flatten_array == 0).sum()
min_value = flatten_array.min()
max_value = flatten_array.max()
mean_value = flatten_array.mean()
std_value = flatten_array.std()
else:
# Split the calculation for each bucket, then do another round of calculation on the bucket results.
# This can at the best effort reduce the peak memory impact.
element_count = flatten_array.numel()
ceil_bucket_count = (element_count + bucket_size - 1) // (bucket_size)
nan_buckets = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device)
inf_buckets = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device)
neg_buckets = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device)
pos_buckets = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device)
zero_buckets = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device)
min_buckets = torch.zeros(ceil_bucket_count, dtype=flatten_array.dtype, device=flatten_array.device)
max_buckets = torch.zeros(ceil_bucket_count, dtype=flatten_array.dtype, device=flatten_array.device)
mean_buckets = torch.zeros(ceil_bucket_count, dtype=flatten_array.dtype, device=flatten_array.device)
std_buckets = torch.zeros(ceil_bucket_count, dtype=flatten_array.dtype, device=flatten_array.device)
# Summary for each bucket
element_count_per_bucket = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device)
for i in range(ceil_bucket_count):
end = min((i + 1) * bucket_size, element_count)
bucket = flatten_array[i * bucket_size : end]
element_count_per_bucket[i] = bucket.numel()
nan_buckets[i] = torch.isnan(bucket).sum()
inf_buckets[i] = torch.isinf(bucket).sum()
neg_buckets[i] = (bucket < 0).sum()
pos_buckets[i] = (bucket > 0).sum()
zero_buckets[i] = (bucket == 0).sum()
min_buckets[i] = bucket.min()
max_buckets[i] = bucket.max()
mean_buckets[i] = bucket.sum()
std_buckets[i] = bucket.std()
# Reduction across all buckets
num_nan = nan_buckets.sum()
num_inf = inf_buckets.sum()
num_neg = neg_buckets.sum()
num_pos = pos_buckets.sum()
num_zero = zero_buckets.sum()
min_value = min_buckets.min()
max_value = max_buckets.max()
mean_value = float(mean_buckets.sum()) / float(element_count)
# Here we refer to
# https://math.stackexchange.com/questions/2971315/how-do-i-combine-standard-deviations-of-two-groups
# to calculate the combined standard deviation of all buckets.
s = (element_count_per_bucket - 1) * (std_buckets**2) + element_count_per_bucket * (
(mean_buckets - mean_value) ** 2
)
std_value = torch.sqrt(s.sum() / (element_count - 1))
f.write(
f"{'>'*max(0, depth) + display_name} shape: {tensor_shape} dtype: {tensor_dtype} size: {flatten_array.size()} \n"
f"min: {min_value} max: {max_value}, mean: {mean_value}, "
f"std: {std_value} \n"
f"nan: {num_nan}, inf: {num_inf}\n"
)
f.write(f"samples(top 128): {flatten_array[:128]}\n")
f.write(f"neg: {num_neg}, pos: {num_pos}, zero: {num_zero},\n")
f.write(f"{'='*16}\n")

View file

@ -29,14 +29,6 @@ def no_increase_global_step():
finally:
ORT_NO_INCREASE_GLOBAL_STEP[0] = False
@staticmethod
def infer_shape(
node: onnx.NodeProto,
tensor_input_shapes: List[Optional[List[Union[int, str]]]],
tensor_input_dtypes: List[torch.onnx.TensorProtoDataType],
) -> Tuple[List[Optional[List[Union[int, str]]]], List[torch.onnx.TensorProtoDataType]]:
return tensor_input_shapes, tensor_input_dtypes
class _IncrementStep(torch.autograd.Function):
"""This class is used to manage the global execution step, e.g.
@ -55,8 +47,9 @@ class _IncrementStep(torch.autograd.Function):
ctx.current_step = run_ctx.global_states.execution_step
ctx.run_ctx = run_ctx
if ctx.current_step >= 0:
print(f"{'='*6} Completed forward pass for STEP {ctx.current_step} {'='*6}")
# Uncomment the following line for debugging purposes.
# if ctx.current_step >= 0:
# print(f"{'='*6} Completed forward pass for STEP {ctx.current_step} {'='*6}")
if ORT_NO_INCREASE_GLOBAL_STEP[0] is False:
ctx.run_ctx.global_states.execution_step += 1
@ -191,7 +184,7 @@ class SubscriberManager:
next_module_index: list of int, carrying a global unique module index that can be used next.
"""
module_index = next_module_index[0]
module.id = module_index # STAGE3WARN: needed by DeepSpeed
module.id = module_index # STAGE3WARN#1: needed by DeepSpeed
self._run_ctx.global_states.module_index_to_depth[module_index] = depth
self._run_ctx.global_states.module_to_module_index[module] = module_index
@ -217,7 +210,7 @@ class SubscriberManager:
next_module_index: list of int, carrying a global unique module index that can be used next.
"""
module_index = next_module_index[0]
module.id = module_index # STAGE3WARN: needed by DeepSpeed
module.id = module_index # STAGE3WARN#2: needed by DeepSpeed
self._run_ctx.global_states.module_index_to_depth[module_index] = depth
self._run_ctx.global_states.module_to_module_index[module] = module_index

View file

@ -23,25 +23,37 @@ from onnxruntime.training.utils import (
from ._subscriber_base import RuntimeStates, SubscriberBase
# Used to monkey patch the original function
# Adapted from https://github.com/microsoft/DeepSpeed/blob/e8318634b4313eaad89842cf4322e1762d34ced3/deepspeed/runtime/zero/parameter_offload.py#L333
def _setup_zero_stage3_ort_compatible_hooks(self):
self.hierarchy = 0
def _get_ort_compatible_zero_stage3_hook_function(debug, stats_output_dir, stats_overwrite):
"""Create ort compatible hook function for DeepSpeed ZeRO stage3.
from onnxruntime.training.utils.hooks import SubscriberManager, ZeROOffloadSubscriber
from onnxruntime.training.utils.hooks._zero_offload_subscriber import _zero_offload_one_time_initializer
Args:
debug: whether to enable convergence debugging.
stats_output_dir: the directory to store convergence stats.
stats_overwrite: whether to overwrite the stats file if it already exists.
"""
# Each DeepSpeed engine has a separate subscriber manager.
self._offload_subscriber_manager = SubscriberManager()
self._offload_subscriber_manager.subscribe(
self.module, [ZeROOffloadSubscriber(self, _zero_offload_one_time_initializer)]
)
self.forward_hooks.extend(self._offload_subscriber_manager._pre_forward_hooks)
self.forward_hooks.extend(self._offload_subscriber_manager._post_forward_hooks)
# Used to monkey patch the original function
# Adapted from https://github.com/microsoft/DeepSpeed/blob/e8318634b4313eaad89842cf4322e1762d34ced3/deepspeed/runtime/zero/parameter_offload.py#L333
def _setup_zero_stage3_ort_compatible_hooks(self):
self.hierarchy = 0
# Add top module to stack trace
global FWD_MODULE_STACK # noqa: PLW0602
FWD_MODULE_STACK.append(self.module)
from onnxruntime.training.utils.hooks import StatisticsSubscriber, SubscriberManager, ZeROOffloadSubscriber
from onnxruntime.training.utils.hooks._zero_offload_subscriber import _zero_offload_one_time_initializer
subscribers = [ZeROOffloadSubscriber(self, _zero_offload_one_time_initializer)]
if debug is True:
subscribers.append(StatisticsSubscriber(output_dir=stats_output_dir, override_output_dir=stats_overwrite))
# Each DeepSpeed engine has a separate subscriber manager.
self._offload_subscriber_manager = SubscriberManager()
self._offload_subscriber_manager.subscribe(self.module, subscribers)
self.forward_hooks.extend(self._offload_subscriber_manager._pre_forward_hooks)
self.forward_hooks.extend(self._offload_subscriber_manager._post_forward_hooks)
# Add top module to stack trace
global FWD_MODULE_STACK # noqa: PLW0602
FWD_MODULE_STACK.append(self.module)
return _setup_zero_stage3_ort_compatible_hooks
# Adapted from https://github.com/microsoft/DeepSpeed/blob/e8318634b4313eaad89842cf4322e1762d34ced3/deepspeed/runtime/zero/linear.py#L104
@ -86,14 +98,16 @@ try:
_zero_offload_one_time_initializer.collect_code(DeepSpeedZeRoOffload.setup_zero_stage3_hooks)
# This is the function to enable ORT ZeRO offload.
def configure_ort_compatible_zero_stage3():
def configure_ort_compatible_zero_stage3(debug=False, stats_output_dir="./", stats_overwrite=False):
"""Configure ZeRO stage3 to be ORT compatible.
This function will overwrite the original DeepSpeed ZeRO stage3 hooks to make it ORT compatible.
"""
# Only done once no matter how many times this function is called for different modules.
DeepSpeedZeRoOffload.setup_zero_stage3_hooks = _setup_zero_stage3_ort_compatible_hooks
DeepSpeedZeRoOffload.setup_zero_stage3_hooks = _get_ort_compatible_zero_stage3_hook_function(
debug, stats_output_dir, stats_overwrite
)
from deepspeed.runtime.zero.linear import zero3_linear_wrap
@ -103,7 +117,7 @@ try:
except ImportError as e:
warnings.warn(f"DeepSpeed import error {e}")
def configure_ort_compatible_zero_stage3():
def configure_ort_compatible_zero_stage3(debug=False, stats_output_dir=None, stats_overwrite=False):
raise RuntimeError("DeepSpeed is not installed, cannot configure ORT compatible ZeRO stage3.")
@ -115,13 +129,13 @@ def _get_params_for_current_module(module: torch.nn.Module) -> List[torch.nn.par
"""
from deepspeed.runtime.zero.partitioned_param_coordinator import iter_params
# Retrieve the parameters that are not available for this module.
# Retrieve all parameters for this module.
partitioned_params = [param for param in iter_params(module)]
return partitioned_params
def _get_all_offloaded_params(module: torch.nn.Module) -> Dict[str, torch.nn.parameter.Parameter]:
def _get_all_zero_stage3_params(module: torch.nn.Module) -> Dict[str, torch.nn.parameter.Parameter]:
"""Retrieve all the parameters that are offloaded."""
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
@ -134,16 +148,13 @@ def _get_all_offloaded_params(module: torch.nn.Module) -> Dict[str, torch.nn.par
class ORTZeROOffloadPreForwardFunction(torch.autograd.Function):
"""This function is a common bridge to call original PyTorch's
pre_forward_function and post_backward_function.
"""
"""This function is a common bridge to call original PyTorch's pre_forward_function"""
@staticmethod
def forward(
ctx,
module,
pre_forward_with_kwargs_function,
post_backward_function,
args_schema,
kwargs_schema,
args_tensor_count,
@ -155,7 +166,6 @@ class ORTZeROOffloadPreForwardFunction(torch.autograd.Function):
ctx: context object
module: the module to be called
pre_forward_with_kwargs_function: the function to be called before forward (PyTorch's pre_forward_function)
post_backward_function: the function to be called after backward (PyTorch's post_backward_function)
args_schema: the schema of the args, used to reconstruct the args in original form in
PyTorch's pre_forward_function's inputs.
kwargs_schema: the schema of the kwargs, used to reconstruct the kwargs in original form in
@ -168,6 +178,17 @@ class ORTZeROOffloadPreForwardFunction(torch.autograd.Function):
args_tensors = tensor_list[:args_tensor_count]
kwargs_tensors = tensor_list[args_tensor_count : args_tensor_count + kwargs_tensor_count]
# For PyTorch runs, the sizes are all 0; they do not need a gradient because
# param.detach().requires_grad_(False) is called.
# For ORT runs, however, the sizes are all [1], as they are the outputs of the weight retrieval function.
# So we keep track of the shapes, dtypes and devices of the passed-in tensors, then generate the grads in backward.
# For both PyTorch and ORT runs, these grads are not important because they are no longer param grads;
# they are only used for completing the full backward propagation.
passed_in_param_tensors = tensor_list[args_tensor_count + kwargs_tensor_count :]
ctx.shapes = [p.shape for p in passed_in_param_tensors]
ctx.dtypes = [p.dtype for p in passed_in_param_tensors]
ctx.devices = [p.device for p in passed_in_param_tensors]
args = unflatten_data_using_schema(args_tensors, args_schema)
kwargs = unflatten_data_using_schema(kwargs_tensors, kwargs_schema)
@ -179,6 +200,8 @@ class ORTZeROOffloadPreForwardFunction(torch.autograd.Function):
partitioned_params = _get_params_for_current_module(module)
ctx.partitioned_params = partitioned_params
assert len(partitioned_params) == len(passed_in_param_tensors)
f_ret = pre_forward_with_kwargs_function(module, args, kwargs)
if f_ret is None:
@ -188,7 +211,6 @@ class ORTZeROOffloadPreForwardFunction(torch.autograd.Function):
updated_args, updated_kwargs = f_ret
ctx.module = module
ctx.post_backward_function = post_backward_function
updated_args_tensors, _ = extract_data_and_schema(updated_args)
updated_kwargs_tensors, _ = extract_data_and_schema(updated_kwargs)
@ -203,17 +225,32 @@ class ORTZeROOffloadPreForwardFunction(torch.autograd.Function):
@staticmethod
def backward(ctx, *grads):
updated_grads = grads
if ctx.post_backward_function is not None:
ret = ctx.post_backward_function(ctx.module, grads)
if ret is not None:
updated_grads = ret
# TODO(pengwa) Update grad for partitioned parameters.
input_count = len(updated_grads) - len(ctx.partitioned_params)
zeros = [torch.zeros(0, dtype=p.dtype, device=p.device) for p in ctx.partitioned_params]
zero_grads = updated_grads[:input_count] + tuple(zeros)
param_start_offset = input_count
return (None, None, None, None, None, None, None, *zero_grads)
# Only need to accumulate grad explicitly for ORT run (e.g. ctx.shapes[0] == (1,));
# In the PyTorch run, the accumulation happens automatically.
need_manual_grad_acc = len(ctx.shapes) > 0 and ctx.shapes[0] == (1,)
if need_manual_grad_acc:
for param_index, p in enumerate(ctx.partitioned_params):
g = updated_grads[param_index + param_start_offset]
if g is None:
raise RuntimeError(f"param {p} has no grad, this should not happen.")
# Param gradient accumulation is triggered here, along with the attached hooks, done by PyTorch.
assert p.shape == g.shape, f"param_index: {param_index} - param shape {p.shape} != grad shape {g.shape}"
p.backward(g)
# At this point, the **real** param grads are already updated; the following grads are only used for
# completing the full backward propagation and will not affect parameter updates.
passed_in_param_grad = [
torch.zeros(shape, dtype=dtype, device=device)
for shape, dtype, device in zip(ctx.shapes, ctx.dtypes, ctx.devices)
]
zero_grads = updated_grads[:input_count] + tuple(passed_in_param_grad)
return (None, None, None, None, None, None, *zero_grads)
@staticmethod
def infer_shape(
@ -258,14 +295,14 @@ class ORTZeROOffloadPostForwardFunction(torch.autograd.Function):
module: the module to be called
post_forward_function: the function to be called after forward (PyTorch's post_forward_function)
pre_backward_function: the function to be called before backward (PyTorch's pre_backward_function)
output_schema: the schema of the output, used to reconstruct the output in original form in
output_schema: the schema of the output, used to reconstruct the output in its original form in
PyTorch's post_forward_function's inputs.
output_tensors: the list of tensors.
"""
outputs = unflatten_data_using_schema(output_tensors, output_schema)
# STAGE3WARN: _post_forward_module_hook's second argument `input` is not used, so we just pass a None here.
# STAGE3WARN#3: _post_forward_module_hook's second argument `input` is not used, so we just pass a None here.
updated_outputs = post_forward_function(module, None, outputs)
if updated_outputs is None:
@ -341,12 +378,20 @@ class ZeROOffloadSubscriber(SubscriberBase):
input and output for torch.autograd.Function, so we do flatten and unflatten here.
"""
## Handle `_post_backward_module_hook`
args_tensors, args_schema = extract_data_and_schema(args)
# Put `_post_backward_module_hook` first because in backward, it is responsible for unloading parameters;
# we want ORTZeROOffloadPreForwardFunction's backward to still be able to access the full-sized parameters.
_post_backward_module_hook = self._functions.get("_post_backward_module_hook")
# STAGE3WARN#4: most of the logic in _post_backward_module_hook can be traced correctly, so we don't need to
# wrap it with PythonOp. The parts that cannot be traced are handled in STAGE3WARN#5.
updated_args = _post_backward_module_hook(module, args)
## Handle `_pre_forward_module_hook`
args_tensors, args_schema = extract_data_and_schema(updated_args)
kwargs_tensors, kwargs_schema = extract_data_and_schema(kwargs)
partitioned_params = _get_params_for_current_module(module)
_pre_forward_module_hook = self._functions.get("_pre_forward_module_hook")
args_tensor_count = len(args_tensors)
@ -358,18 +403,29 @@ class ZeROOffloadSubscriber(SubscriberBase):
if rets is not None:
updated_args = rets
# STAGE3WARN: Moved from _post_backward_module_hook to make sure ORT run will trigger every iteration.
# STAGE3WARN#5: Moved from _post_backward_module_hook to make sure ORT run will trigger every iteration.
module.ds_grads_remaining = 0
return updated_args, updated_kwargs
all_tensors = args_tensors + kwargs_tensors + partitioned_params
# Need to pass the parameters as inputs to let the exporter trace the related weights for the
# current ORTZeROOffloadPreForwardFunction.
partitioned_params = _get_params_for_current_module(module)
# Don't require grad for the passed-in parameters; otherwise they will be treated as leaf nodes, and in backward
# the returned 0-sized grads would not match the input shape metadata of the params' gradient accumulator
# functions, so the PyTorch run would fail during backward.
# This will not harm parameter gradient building in either ORT or PyTorch: the weights are used by the
# computation anyway, so their gradients will be built. This hook only references the parameters, but won't
# generate a gradient path for them.
detached_partitioned_params = [p.detach().requires_grad_(False) for p in partitioned_params]
all_tensors = args_tensors + kwargs_tensors + detached_partitioned_params
self._check_all_tensor(all_tensors, module, "pre_forward_module_apply_impl input check")
rets = ORTZeROOffloadPreForwardFunction.apply(
module,
_wrap_pre_forward_module_hook,
None,
args_schema,
kwargs_schema,
args_tensor_count,
@ -385,11 +441,6 @@ class ZeROOffloadSubscriber(SubscriberBase):
updated_args = unflatten_data_using_schema(updated_args_tensors, args_schema)
updated_kwargs = unflatten_data_using_schema(updated_kwargs_tensors, kwargs_schema)
_post_backward_module_hook = self._functions.get("_post_backward_module_hook")
# STAGE3WARN: Other part of _post_backward_module_hook can be traced correctly so we don't need to
# wrap with PythonOp.
updated_args = _post_backward_module_hook(module, updated_args)
return updated_args, updated_kwargs
def post_forward_module_apply_impl(
@ -411,7 +462,7 @@ class ZeROOffloadSubscriber(SubscriberBase):
_post_forward_module_hook = self._functions.get("_post_forward_module_hook")
def _wrap_post_forward_module_hook(module, input, outputs):
# STAGE3WARN: _post_forward_module_hook applied this for each tensor output, so we do a simple wrap here.
# STAGE3WARN#6: _post_forward_module_hook applied this for each tensor output, so we do a simple wrap here.
from deepspeed.runtime.zero.partition_parameters import is_zero_param
updated_outputs = _post_forward_module_hook(module, input, outputs)
@ -438,8 +489,8 @@ class ZeROOffloadSubscriber(SubscriberBase):
updated_outputs = unflatten_data_using_schema(updated_outputs_tensors, outputs_schema)
_pre_backward_module_hook = self._functions.get("_pre_backward_module_hook")
# STAGE3WARN: _pre_backward_module_hook's second argument `input` is not used, so we just pass a None here.
# STAGE3WARN: part of the original _pre_backward_module_hook can be traced correctly so we moved them into
# STAGE3WARN#7: _pre_backward_module_hook's second argument `input` is not used, so we just pass a None here.
# STAGE3WARN#8: part of the original _pre_backward_module_hook can be traced correctly so we moved them into
# _wrap_post_forward_module_hook above.
updated_outputs = _pre_backward_module_hook(module, None, updated_outputs)

View file

@ -33,6 +33,8 @@ _CAST_PYTORCH_TO_ONNX = {
# Reverse lookup: torch dtype -> ONNX tensor data type. Only the (onnx_dtype, torch_dtype)
# value pairs of _CAST_PYTORCH_TO_ONNX are needed here; its keys are ignored.
_DTYPE_TO_ONNX = {torch_dt: onnx_dt for onnx_dt, torch_dt in _CAST_PYTORCH_TO_ONNX.values()}
# Inverse of _DTYPE_TO_ONNX: ONNX tensor data type -> torch dtype.
_ONNX_TO_DTYPE = {onnx_dt: torch_dt for torch_dt, onnx_dt in _DTYPE_TO_ONNX.items()}
def pytorch_dtype_to_onnx(dtype_or_scalar_type: Union[torch.dtype, str]) -> torch.onnx.TensorProtoDataType:
"""Converts a pytorch dtype or scalar type string to an onnx dtype."""
@ -45,3 +47,10 @@ def pytorch_dtype_to_onnx(dtype_or_scalar_type: Union[torch.dtype, str]) -> torc
if dtype not in _DTYPE_TO_ONNX:
raise RuntimeError(f"Unsupported dtype {dtype}")
return _DTYPE_TO_ONNX[dtype]
def onnx_dtype_to_pytorch(dtype: torch.onnx.TensorProtoDataType) -> torch.dtype:
    """Look up the PyTorch dtype registered for an ONNX tensor data type.

    Args:
        dtype: the ONNX tensor data type to translate.

    Returns:
        The torch dtype stored for `dtype` in the module-level `_ONNX_TO_DTYPE` table.

    Raises:
        RuntimeError: if `dtype` has no entry in `_ONNX_TO_DTYPE`.
    """
    # Happy path first: a known ONNX type maps directly to its torch dtype.
    if dtype in _ONNX_TO_DTYPE:
        return _ONNX_TO_DTYPE[dtype]

    raise RuntimeError(f"Unsupported dtype {dtype}")

View file

@ -153,8 +153,11 @@ void PythonOpBase::RunForward(OpKernelContext* context,
inplace_ != 0,
kernel_invoke_id_);
ORT_ENFORCE(1 + returned_ortvalues.size() == static_cast<size_t>(context->OutputCount()),
"Output count mismatch for PythonOp run");
const size_t returned_output_count = 1 + returned_ortvalues.size();
const size_t kernel_output_count = static_cast<size_t>(context->OutputCount());
ORT_ENFORCE(returned_output_count == kernel_output_count, "Output count mismatch for PythonOp run, ",
"returned_output_count: ", returned_output_count, ", expected kernel_output_count: ",
kernel_output_count);
}
void PythonOpBase::SetOutputs(OpKernelContext* context, void* diff_ctx, std::vector<OrtValue>& returned_args) const {