mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-22 02:30:26 +00:00
Model post process for zero stage3 training (#17187)
### Model post process for zero stage3 training

This is the last change needed to make single-GPU and multi-GPU runs pass.

Design details: https://microsoft.sharepoint.com/:p:/t/ONNX2/EfNfJ43necpIoPI6x5M2zvYBVbfjoPQmG4Boc_F7-tHm1w?e=ekQwA6&nav=eyJzSWQiOjMxNiwiY0lkIjoxMDE1Nzg3NDZ9

`PyTorch` runs with ZeROOffloadSubscriber:

```
model = prepare_model(...)
from onnxruntime.training.utils.hooks import configure_ort_compatible_zero_stage3
configure_ort_compatible_zero_stage3()
```

`ORTModule` runs with ZeROOffloadSubscriber:

```
os.environ['ORTMODULE_ENABLE_ZERO_STAGE3'] = '1'
from onnxruntime.training.ortmodule import ORTModule
model = ORTModule(self.model)
```

It will be fairly easy to debug convergence issues if both ORT and PyTorch can run the same offload path.

### Motivation and Context

<!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
parent
498b60d8a4
commit
6b7bce5ec9
13 changed files with 618 additions and 170 deletions
|
|
@ -28,7 +28,8 @@ class PythonOpShapeInferStore:
|
|||
|
||||
@classmethod
|
||||
def register(cls, kclass: torch.autograd.Function) -> None:
|
||||
"""Register a shape inference function for a torch.autograd.Function if there is staticmethod "infer_shape" defined.
|
||||
"""Register a shape inference function for a torch.autograd.Function if there is staticmethod
|
||||
"infer_shape" defined.
|
||||
|
||||
The signature of the shape inference function should be:
|
||||
@staticmethod
|
||||
|
|
@ -51,6 +52,11 @@ class PythonOpShapeInferStore:
|
|||
if hasattr(kclass, "infer_shape") and kclass_name not in cls._CLASS_MAP:
|
||||
cls._CLASS_MAP[kclass_name] = kclass.infer_shape
|
||||
|
||||
@classmethod
def register_func(cls, name: str, func: Callable) -> None:
    """Store ``func`` as the shape inference function for the autograd function called ``name``.

    The entry is written unconditionally, so a function previously stored under
    the same name is replaced.

    Args:
        name: Key under which the function is stored in the class-level map.
        func: The shape inference callable to store.
    """
    registry = cls._CLASS_MAP
    registry[name] = func
|
||||
|
||||
@classmethod
def get_shape_infer(cls, name: str) -> Optional[Callable]:
    """Look up the shape inference function stored under ``name``.

    Args:
        name: Key of the autograd function in the class-level map.

    Returns:
        The stored callable, or ``None`` when nothing is stored under ``name``.
    """
    registry = cls._CLASS_MAP
    return registry.get(name)
|
||||
|
|
@ -228,9 +234,9 @@ def _export_pt_1_10(g, n, *args, **kwargs):
|
|||
input_float_tuples.extend(list(arg))
|
||||
continue
|
||||
|
||||
is_inspect_activation = (
|
||||
func_full_qual_name == "onnxruntime.training.utils.hooks._subscriber_manager._InspectActivation"
|
||||
)
|
||||
from onnxruntime.training.utils.hooks._statistics_subscriber import _InspectActivation
|
||||
|
||||
is_inspect_activation = func_full_qual_name == get_fully_qualified_class_name(_InspectActivation)
|
||||
if is_inspect_activation and isinstance(arg, str):
|
||||
# _InspectActivation is a special case where the first argument is a string
|
||||
# that is used to determine the activation name to be inspected.
|
||||
|
|
@ -307,14 +313,7 @@ def _export_pt_1_10(g, n, *args, **kwargs):
|
|||
_export = wrap_custom_export_function(_export_pt_1_10)
|
||||
|
||||
|
||||
def _post_process_after_export(exported_model: ModelProto, enable_custom_autograd_function: bool) -> ModelProto:
    """Run the post-export passes that are enabled for the exported model.

    Args:
        exported_model: The model produced by the exporter.
        enable_custom_autograd_function: When true, the model is passed through
            ``_post_process_enabling_autograd_function``.

    Returns:
        The result of the autograd-function pass when it is enabled, otherwise
        ``exported_model`` itself.
    """
    if not enable_custom_autograd_function:
        return exported_model

    return _post_process_enabling_autograd_function(exported_model)
|
||||
|
||||
|
||||
def _post_process_enabling_autograd_function(exported_model: ModelProto) -> ModelProto:
|
||||
def post_process_enabling_autograd_function(exported_model: ModelProto) -> ModelProto:
|
||||
# Loop all PythonOp, append "_ctx" as the first output.
|
||||
index = 0
|
||||
for node in exported_model.graph.node:
|
||||
|
|
@ -330,8 +329,7 @@ def _post_process_enabling_autograd_function(exported_model: ModelProto) -> Mode
|
|||
op_name_prefix = kclass_name
|
||||
break
|
||||
|
||||
if not node.name:
|
||||
node.name = f"{op_name_prefix}_id_{index}"
|
||||
index += 1
|
||||
node.name = f"{op_name_prefix}_id_{index}"
|
||||
index += 1
|
||||
|
||||
return exported_model
|
||||
|
|
|
|||
|
|
@ -376,6 +376,16 @@ def call_python_backward_function(
|
|||
result = backward_function(*wrapped_args)
|
||||
|
||||
# Extract results as DLPack tensor list.
|
||||
if isinstance(result, torch.Tensor):
|
||||
result = [result]
|
||||
elif isinstance(result, (tuple, list)):
|
||||
result = list(result)
|
||||
else:
|
||||
raise wrap_exception(
|
||||
ORTModuleIOError,
|
||||
TypeError(f"ORTModule does not support the following model output type {type(result)}."),
|
||||
)
|
||||
|
||||
wrapped_returned_args = wrap_all_outputs(result)
|
||||
|
||||
torch_interop_utils.unregister_grad_fn(id(ctx))
|
||||
|
|
|
|||
|
|
@ -19,11 +19,10 @@ from torch.utils.cpp_extension import ROCM_HOME
|
|||
import onnxruntime
|
||||
from onnxruntime.capi import _pybind_state as C
|
||||
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference
|
||||
from onnxruntime.training.utils import ORTModelInputOutputSchemaType
|
||||
from onnxruntime.training.utils import ORTModelInputOutputSchemaType, onnx_dtype_to_pytorch
|
||||
from onnxruntime.training.utils.hooks import configure_ort_compatible_zero_stage3
|
||||
|
||||
from . import _are_deterministic_algorithms_enabled, _io, _logger, _onnx_models, _utils
|
||||
from ._custom_autograd_function_exporter import _post_process_after_export
|
||||
from ._fallback import (
|
||||
ORTModuleDeviceException,
|
||||
ORTModuleONNXModelException,
|
||||
|
|
@ -141,9 +140,14 @@ class GraphExecutionManager(GraphExecutionInterface):
|
|||
|
||||
register_triton_op_executor()
|
||||
|
||||
self._zero_stage3_param_map = {}
|
||||
if self._runtime_options.enable_zero_stage3_support:
|
||||
# Cannot toggle feature enabling/disabling after the first time enabled.
|
||||
configure_ort_compatible_zero_stage3()
|
||||
from onnxruntime.training.utils.hooks._zero_offload_subscriber import _get_all_zero_stage3_params
|
||||
|
||||
self._zero_stage3_param_map = _get_all_zero_stage3_params(self._flattened_module)
|
||||
|
||||
configure_ort_compatible_zero_stage3(debug=False, stats_output_dir="ort_output", stats_overwrite=True)
|
||||
|
||||
def _get_torch_gpu_allocator_function_addresses(self):
|
||||
if self._runtime_options.use_external_gpu_allocator and torch.cuda.is_available():
|
||||
|
|
@ -345,7 +349,8 @@ class GraphExecutionManager(GraphExecutionInterface):
|
|||
)
|
||||
if os.path.exists(cache_dir) and os.path.isfile(filename):
|
||||
self._logger.info(
|
||||
f"Cached model detected! Cached model will be used to save export and initialization time. If you want the model to be re-exported then DELETE {filename}."
|
||||
f"Cached model detected! Cached model will be used to save export and initialization time."
|
||||
f"If you want the model to be re-exported then DELETE {filename}."
|
||||
)
|
||||
exported_model = onnx.load(filename)
|
||||
return exported_model
|
||||
|
|
@ -409,9 +414,24 @@ class GraphExecutionManager(GraphExecutionInterface):
|
|||
)
|
||||
exported_model = onnx.load_model_from_string(f.getvalue())
|
||||
|
||||
exported_model = _post_process_after_export(
|
||||
exported_model, self._runtime_options.enable_custom_autograd_function
|
||||
)
|
||||
if self._runtime_options.enable_custom_autograd_function:
|
||||
from ._custom_autograd_function_exporter import post_process_enabling_autograd_function
|
||||
|
||||
exported_model = post_process_enabling_autograd_function(exported_model)
|
||||
|
||||
if self._runtime_options.enable_zero_stage3_support:
|
||||
from ._zero_stage3_compatibility import post_processing_enable_zero_stage3_compat
|
||||
|
||||
exported_model = post_processing_enable_zero_stage3_compat(
|
||||
exported_model,
|
||||
self._zero_stage3_param_map,
|
||||
[name for name, _ in self._flattened_module.named_parameters()],
|
||||
)
|
||||
|
||||
# Do not append the pull weight trigger name to the input names as shown below; otherwise the later check (
|
||||
# https://github.com/microsoft/onnxruntime/blob/068300d97eb25e5b52324e7af54a45ed1fa6a4c3/orttraining/orttraining/python/training/ortmodule/_training_manager.py#L466C18-L466C18)
|
||||
# finds an input info mismatch and re-initializes the graph builder.
|
||||
# self._input_info.require_grad_names.append(STAGE3_PULL_WEIGHT_TRIGGER_NAME)
|
||||
|
||||
# Cache model for future runs
|
||||
if cache_dir:
|
||||
|
|
@ -477,7 +497,14 @@ class GraphExecutionManager(GraphExecutionInterface):
|
|||
grad_builder_config = C.OrtModuleGraphBuilderConfiguration()
|
||||
grad_builder_config.initializer_names = initializer_names
|
||||
grad_builder_config.initializer_names_to_train = initializer_names_to_train
|
||||
grad_builder_config.input_names_require_grad = self._input_info.require_grad_names
|
||||
|
||||
input_names_require_grad = self._input_info.require_grad_names
|
||||
if self._runtime_options.enable_zero_stage3_support:
|
||||
from ._zero_stage3_compatibility import STAGE3_PULL_WEIGHT_TRIGGER_NAME
|
||||
|
||||
# Add stage3 pull weight trigger name to require_grad_names, so that it will be included in the gradient graph.
|
||||
input_names_require_grad.append(STAGE3_PULL_WEIGHT_TRIGGER_NAME)
|
||||
grad_builder_config.input_names_require_grad = input_names_require_grad
|
||||
grad_builder_config.build_gradient_graph = self._export_mode == torch.onnx.TrainingMode.TRAINING
|
||||
grad_builder_config.enable_caching = self._runtime_options.enable_grad_acc_optimization
|
||||
grad_builder_config.loglevel = _logger.ortmodule_loglevel_to_onnxruntime_c_loglevel(
|
||||
|
|
@ -553,6 +580,9 @@ class GraphExecutionManager(GraphExecutionInterface):
|
|||
inputs, kwargs
|
||||
)
|
||||
|
||||
if self._runtime_options.enable_zero_stage3_support:
|
||||
self._append_pull_weight_trigger_as_input(kwargs, detected_device)
|
||||
|
||||
_, embed_sparsity_results, label_sparsity_results = _io._combine_input_buffers_initializers(
|
||||
self._graph_initializers,
|
||||
self._graph_builder.get_graph_info().user_input_names,
|
||||
|
|
@ -562,6 +592,7 @@ class GraphExecutionManager(GraphExecutionInterface):
|
|||
kwargs,
|
||||
detected_device,
|
||||
self._runtime_inspector,
|
||||
self._zero_stage3_param_map,
|
||||
)
|
||||
|
||||
# Enable sparsity-based optimization when applicable.
|
||||
|
|
@ -587,6 +618,21 @@ class GraphExecutionManager(GraphExecutionInterface):
|
|||
if self._runtime_options.print_memory_stat:
|
||||
self._runtime_inspector.enable_memory_inspector(self._original_module)
|
||||
|
||||
def _append_pull_weight_trigger_as_input(self, kwargs: Dict, device: torch.device):
    """Add the ZeRO stage3 pull-weight trigger tensor to ``kwargs``.

    The trigger is a zero-filled tensor of shape
    ``STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE`` whose dtype is the PyTorch
    equivalent of ``STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE``. It is created on
    ``device`` with ``requires_grad`` switched on and stored under the key
    ``STAGE3_PULL_WEIGHT_TRIGGER_NAME``.

    Args:
        kwargs: Keyword-argument mapping to extend; it is modified in place.
        device: Device on which the trigger tensor is allocated.

    Returns:
        The same ``kwargs`` mapping that was passed in.
    """
    from ._zero_stage3_compatibility import (
        STAGE3_PULL_WEIGHT_TRIGGER_NAME,
        STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE,
        STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE,
    )

    trigger_dtype = onnx_dtype_to_pytorch(STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE)
    trigger = torch.zeros(STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE, dtype=trigger_dtype, device=device)
    # requires_grad_() flips the flag in place and returns the same tensor.
    trigger.requires_grad_()

    kwargs[STAGE3_PULL_WEIGHT_TRIGGER_NAME] = trigger
    return kwargs
|
||||
|
||||
def _log_feature_stats(self):
|
||||
if get_rank() != 0:
|
||||
return
|
||||
|
|
|
|||
|
|
@ -159,6 +159,9 @@ class InferenceManager(GraphExecutionManager):
|
|||
# Assert that the input and model device match
|
||||
_utils._check_same_device(self._device, "Input argument to forward", *inputs)
|
||||
|
||||
if self._runtime_options.enable_zero_stage3_support:
|
||||
self._append_pull_weight_trigger_as_input(kwargs, self._device)
|
||||
|
||||
prepared_input_list, _, _ = _io._combine_input_buffers_initializers(
|
||||
self._graph_initializers,
|
||||
self._graph_info.user_input_names,
|
||||
|
|
@ -168,6 +171,7 @@ class InferenceManager(GraphExecutionManager):
|
|||
kwargs,
|
||||
self._device,
|
||||
self._runtime_inspector,
|
||||
self._zero_stage3_param_map,
|
||||
)
|
||||
|
||||
user_outputs, _ = InferenceManager.execution_session_run_forward(
|
||||
|
|
|
|||
|
|
@ -168,6 +168,7 @@ def _combine_input_buffers_initializers(
|
|||
kwargs: Mapping[str, ORTModelInputOutputType],
|
||||
device: torch.device,
|
||||
rt_inspector: RuntimeInspector,
|
||||
zero_stage3_offload_param_map: Optional[Dict[str, torch.nn.parameter.Parameter]],
|
||||
):
|
||||
"""Creates forward `*inputs` list from user input and PyTorch initializers
|
||||
|
||||
|
|
@ -254,7 +255,12 @@ def _combine_input_buffers_initializers(
|
|||
)
|
||||
|
||||
# params is a list of all initializers known to the onnx graph
|
||||
result.extend(params)
|
||||
if zero_stage3_offload_param_map:
|
||||
for p in params:
|
||||
if p not in zero_stage3_offload_param_map.values():
|
||||
result.append(p)
|
||||
else:
|
||||
result.extend(params)
|
||||
|
||||
return result, embed_sparsity_results, label_sparsity_results
|
||||
|
||||
|
|
|
|||
|
|
@ -311,6 +311,9 @@ class TrainingManager(GraphExecutionManager):
|
|||
|
||||
self._gradient_accumulation_manager.maybe_update_cache_before_run()
|
||||
|
||||
if self._runtime_options.enable_zero_stage3_support:
|
||||
self._append_pull_weight_trigger_as_input(kwargs, self._device)
|
||||
|
||||
prepared_input_list, _, _ = _io._combine_input_buffers_initializers(
|
||||
self._graph_initializers,
|
||||
self._graph_info.user_input_names,
|
||||
|
|
@ -320,6 +323,7 @@ class TrainingManager(GraphExecutionManager):
|
|||
kwargs,
|
||||
self._device,
|
||||
self._runtime_inspector,
|
||||
self._zero_stage3_param_map,
|
||||
)
|
||||
|
||||
outputs = unflatten_user_output(
|
||||
|
|
|
|||
|
|
@ -0,0 +1,312 @@
|
|||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from onnx import ModelProto, NodeProto, TensorProto, ValueInfoProto, helper
|
||||
|
||||
from onnxruntime.capi._pybind_state import register_torch_autograd_function
|
||||
from onnxruntime.training.utils import pytorch_dtype_to_onnx
|
||||
|
||||
from ._custom_autograd_function_exporter import PythonOpShapeInferStore
|
||||
from ._utils import get_fully_qualified_class_name
|
||||
|
||||
STAGE3_PULL_WEIGHT_TRIGGER_NAME = "pull_weight_trigger"
|
||||
STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE = TensorProto.FLOAT
|
||||
STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE = [1]
|
||||
|
||||
|
||||
def post_processing_enable_zero_stage3_compat(
    exported_model: ModelProto,
    zero_stage3_named_params: Dict[str, torch.nn.parameter.Parameter],
    all_param_names: List[str],
) -> ModelProto:
    """Rewire an exported model so ZeRO stage3 parameters are produced inside the graph.

    The pass edits ``exported_model`` in place:

    1. Shape inference functions for the DeepSpeed PythonOps and the weight
       retrieval autograd function are registered.
    2. For every graph input that is named after an offloaded parameter and has
       consumers, the ``ORTZeROOffloadPreForwardFunction`` PythonOp reading it is
       switched to read ``pull_<param name>`` instead. The output of that node
       sitting at the same distance from the end as the replaced input is renamed
       to the parameter name, so the remaining consumers read it from there.
       Other PythonOp consumers get their rank/dtype attributes refreshed from
       the parameter's ``ds_shape`` and ``dtype``.
    3. Graph inputs named after offloaded parameters are removed.
    4. The trigger graph input is inserted in front of the first remaining
       parameter input, and the weight-pull PythonOp becomes the first node.

    Args:
        exported_model (ModelProto): The exported model.
        zero_stage3_named_params (Dict[str, torch.nn.parameter.Parameter]): The offload named parameters.
        all_param_names (List[str]): All parameter names.

    Returns:
        ModelProto: The same ``exported_model`` object after modification.

    Raises:
        RuntimeError: If an offloaded parameter with consumers is not read by any
            ``ORTZeROOffloadPreForwardFunction`` PythonOp.
        AssertionError: If more than one such PythonOp reads the parameter, or the
            parameter does not appear exactly once among that node's inputs.
    """

    # Register symbolic shape inference functions for PythonOp used in DeepSpeed ZeRO stage3.
    _register_symbolic_shape_infer_functions()

    # Build the autograd function behind the weight-pull PythonOp from zero_stage3_named_params.
    weight_retrieval_func_name = _create_weight_retrieval_function(zero_stage3_named_params)

    graph = exported_model.graph

    # Tensor name -> nodes reading it, each node listed once.
    consumers_by_input = {}
    for node in graph.node:
        for tensor_name in node.input:
            readers = consumers_by_input.setdefault(tensor_name, [])
            if node not in readers:
                readers.append(node)

    def _pull_trigger_name_of(param_name: str) -> str:
        return f"pull_{param_name}"

    def _func_name_of(python_op: NodeProto) -> Optional[str]:
        for attribute in python_op.attribute:
            if attribute.name != "func_name":
                continue
            raw_name = attribute.s
            return raw_name.decode("utf-8") if isinstance(raw_name, bytes) else raw_name
        return None

    trigger_rank = len(STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE)

    # Create the weight retrieving PythonOp together with the graph input feeding it.
    new_input, weight_pull_node = _create_weight_retrieval_pythonop(
        zero_stage3_named_params,
        weight_retrieval_func_name,
        STAGE3_PULL_WEIGHT_TRIGGER_NAME,
        [_pull_trigger_name_of(param_name) for param_name in zero_stage3_named_params],
        STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE,
        STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE,
    )

    from onnxruntime.training.utils.hooks._zero_offload_subscriber import ORTZeROOffloadPreForwardFunction

    pre_forward_func_name = get_fully_qualified_class_name(ORTZeROOffloadPreForwardFunction)

    # Connect weight consumers to use the full-sized parameter output of ORTZeROOffloadPreForwardFunction.
    for graph_input in graph.input:
        param_name = graph_input.name
        if param_name not in zero_stage3_named_params or param_name not in consumers_by_input:
            continue

        readers = consumers_by_input[param_name]

        pre_forward_node = None
        for reader in readers:
            if reader.op_type == "PythonOp" and _func_name_of(reader) == pre_forward_func_name:
                assert (
                    pre_forward_node is None
                ), "Multiple ORTZeROOffloadPreForwardFunction nodes found, it should not happen"
                pre_forward_node = reader

        if pre_forward_node is None:
            raise RuntimeError(f"Fail to find ORTZeROOffloadPreForwardFunction for partitioned param: {param_name}")

        index_offset_on_python_op_input = [
            position for position, name in enumerate(pre_forward_node.input) if name == param_name
        ]
        assert (
            len(index_offset_on_python_op_input) == 1
        ), f"index_offset_on_python_op_input length is not 1: {index_offset_on_python_op_input}"
        input_position = index_offset_on_python_op_input[0]

        # Negative offset of the parameter counted from the end of the input list; the same offset
        # is applied to the output list below.
        position_from_end = input_position - len(pre_forward_node.input)

        pull_name = _pull_trigger_name_of(param_name)
        pre_forward_node.input[input_position] = pull_name
        _update_python_op_input_related_attributes(
            pre_forward_node,
            pull_name,
            trigger_rank,  # new rank
            STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE,  # new data type
        )

        pre_forward_node.output[position_from_end + len(pre_forward_node.output)] = param_name

        # Other PythonOp consumers now read the full-sized parameter produced by the pre-forward
        # node, whose rank might differ from the original one, so refresh their attributes too.
        full_param = zero_stage3_named_params[param_name]
        for reader in readers:
            if reader == pre_forward_node or reader.op_type != "PythonOp":
                continue
            _update_python_op_input_related_attributes(
                reader,
                param_name,
                len(full_param.ds_shape),  # new rank
                pytorch_dtype_to_onnx(full_param.dtype),  # new data type
            )

    # Drop the graph inputs of offloaded parameters; collect first so the list is not edited while iterated.
    stale_inputs = [candidate for candidate in graph.input if candidate.name in zero_stage3_named_params]
    for stale_input in stale_inputs:
        graph.input.remove(stale_input)

    # Place the weight pull trigger before all parameter inputs (at the end when there is none).
    insert_at = next(
        (position for position, candidate in enumerate(graph.input) if candidate.name in all_param_names),
        len(graph.input),
    )
    graph.input.insert(insert_at, new_input)
    graph.node.insert(0, weight_pull_node)

    return exported_model
|
||||
|
||||
|
||||
def _create_weight_retrieval_function(
    zero_stage3_named_params: Optional[Dict[str, torch.nn.parameter.Parameter]]
) -> str:
    """Define and register the autograd function used by the weight retrieving PythonOp.

    The function class closes over ``zero_stage3_named_params``. It is registered
    with ``register_torch_autograd_function`` under its fully qualified name and
    with ``PythonOpShapeInferStore`` for shape inference.

    Args:
        zero_stage3_named_params: The offload named parameters; one output is
            produced per entry.

    Returns:
        The fully qualified name under which the function was registered.
    """

    class WeightRetrievalFunction(torch.autograd.Function):
        @staticmethod
        def forward(ctx, weight_in_trigger):
            offload_params = list(zero_stage3_named_params.values())

            # Keep what backward needs to rebuild a tensor shaped like the trigger.
            ctx.params = offload_params
            ctx.dtype = weight_in_trigger.dtype
            ctx.device = weight_in_trigger.device
            ctx.shape = weight_in_trigger.shape

            # One zero tensor shaped like the trigger, referenced once per offloaded parameter.
            placeholder = torch.zeros(ctx.shape, device=ctx.device, dtype=ctx.dtype)
            return tuple(placeholder for _ in offload_params)

        @staticmethod
        def backward(ctx, *grad_outputs):
            # Incoming gradients are not used; the trigger receives a zero gradient.
            return torch.zeros(ctx.shape, device=ctx.device, dtype=ctx.dtype)

        @staticmethod
        def infer_shape(
            node: NodeProto,
            tensor_input_shapes: List[Optional[List[Union[int, str]]]],
            tensor_input_dtypes: List[torch.onnx.TensorProtoDataType],
        ) -> Tuple[List[Optional[List[Union[int, str]]]], List[torch.onnx.TensorProtoDataType]]:
            # Every output repeats the shape and dtype of the single trigger input.
            output_count = len(zero_stage3_named_params)
            trigger_shape = tensor_input_shapes[0]
            trigger_dtype = tensor_input_dtypes[0]
            output_shapes = [trigger_shape for _ in range(output_count)]
            output_dtypes = [trigger_dtype for _ in range(output_count)]
            return output_shapes, output_dtypes

    qualified_name = get_fully_qualified_class_name(WeightRetrievalFunction)
    register_torch_autograd_function(qualified_name, WeightRetrievalFunction)
    PythonOpShapeInferStore.register(WeightRetrievalFunction)

    return qualified_name
|
||||
|
||||
|
||||
def _register_symbolic_shape_infer_functions():
    """Register shape inference functions for the PythonOps used in DeepSpeed ZeRO stage3.

    Three DeepSpeed autograd functions are registered in ``PythonOpShapeInferStore``
    by their fully qualified names:

    * ``PreBackwardFunction`` and ``PostBackwardFunction``: output shapes and
      dtypes are the input shapes and dtypes, unchanged.
    * ``LinearFunctionForZeroStage3``: a single output whose shape is the first
      input's shape with the last dimension replaced by the second-to-last
      dimension of the second input, and whose dtype is the first input's dtype.
    """

    def _simple_pass_through_infer_shape(
        node: NodeProto,
        tensor_input_shapes: List[Optional[List[Union[int, str]]]],
        tensor_input_dtypes: List[torch.onnx.TensorProtoDataType],
    ) -> Tuple[List[Optional[List[Union[int, str]]]], List[torch.onnx.TensorProtoDataType]]:
        return tensor_input_shapes, tensor_input_dtypes

    PythonOpShapeInferStore.register_func(
        "deepspeed.runtime.zero.parameter_offload.PreBackwardFunction", _simple_pass_through_infer_shape
    )
    PythonOpShapeInferStore.register_func(
        "deepspeed.runtime.zero.parameter_offload.PostBackwardFunction", _simple_pass_through_infer_shape
    )

    def _linear_infer_shape(
        node: NodeProto,
        tensor_input_shapes: List[Optional[List[Union[int, str]]]],
        tensor_input_dtypes: List[torch.onnx.TensorProtoDataType],
    ) -> Tuple[List[Optional[List[Union[int, str]]]], List[torch.onnx.TensorProtoDataType]]:
        # output = input.matmul(weight.t())
        input_shape = tensor_input_shapes[0]  # input
        weight_shape = tensor_input_shapes[1]  # weight

        # Work on a copy: assigning into tensor_input_shapes[0] directly would overwrite the
        # caller's record of the input shape with the output shape.
        output_shape = list(input_shape)
        output_shape[-1] = weight_shape[-2]
        return [output_shape], [tensor_input_dtypes[0]]

    PythonOpShapeInferStore.register_func(
        "deepspeed.runtime.zero.linear.LinearFunctionForZeroStage3", _linear_infer_shape
    )
|
||||
|
||||
|
||||
def _create_weight_retrieval_pythonop(
    zero_stage3_named_params: Optional[Dict[str, torch.nn.parameter.Parameter]],
    func_full_qual_name: str,
    input_name: str,
    output_names: List[str],
    STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE,
    STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE: List[int],
) -> Tuple[ValueInfoProto, NodeProto]:
    """Build the weight retrieving PythonOp node and the graph input that feeds it.

    Args:
        zero_stage3_named_params: The offload named parameters; only their count is
            used here (``None`` counts as zero).
        func_full_qual_name: Value of the node's ``func_name`` attribute.
        input_name: Name of the trigger tensor, used for both the value info and
            the node's single input.
        output_names: Names of the node outputs that follow the leading
            ``pull_weight_trigger_ctx`` output.
        STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE: Element type used for the trigger
            input and recorded for every output.
        STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE: Shape of the trigger input; its
            length is the rank recorded for the input and every output.

    Returns:
        A ``(value_info, node)`` pair: the trigger's tensor value info and the
        ``com.microsoft`` PythonOp node named ``pull_weight_trigger``.
    """
    trigger_dtype = STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE
    trigger_shape = STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE
    trigger_rank = len(trigger_shape)
    param_count = len(zero_stage3_named_params) if zero_stage3_named_params is not None else 0

    new_input = helper.make_tensor_value_info(input_name, trigger_dtype, trigger_shape)

    weight_pull_node = helper.make_node(
        "PythonOp",
        [input_name],
        ["pull_weight_trigger_ctx", *output_names],
        "pull_weight_trigger",  # node name
        "PythonOp for weight retrieving.",
        "com.microsoft",
        comment="",
        inplace=0,
        input_convention="d",
        input_tensor_ranks=[trigger_rank],
        input_tensor_types=[trigger_dtype],
        # Each per-parameter output is recorded with the trigger's rank and dtype.
        output_tensor_ranks=[trigger_rank] * param_count,
        output_tensor_types=[trigger_dtype] * param_count,
        training_mode=1,
        func_name=func_full_qual_name,
    )

    return new_input, weight_pull_node
|
||||
|
||||
|
||||
def _update_python_op_input_related_attributes(node: NodeProto, input_name: str, new_rank: int, new_dtype: int):
    """Refresh a PythonOp's ``input_tensor_ranks`` / ``input_tensor_types`` for one input.

    Every position in ``node.input`` equal to ``input_name`` gets ``new_rank`` and
    ``new_dtype`` written at the same position of the two attributes. Both
    attributes are then removed from the node and re-appended, so they end up as
    the node's last two attributes (ranks first, then types).

    Args:
        node (NodeProto): The PythonOp node.
        input_name (str): The input name to be updated.
        new_rank (int): The new rank of the input, to be used in input_tensor_ranks.
        new_dtype (int): The new data type of the input, to be used in input_tensor_types.

    Raises:
        AssertionError: If the node lacks either attribute.
    """
    rank_attr = None
    dtype_attr = None
    for attribute in node.attribute:
        if attribute.name == "input_tensor_ranks":
            rank_attr = attribute
        elif attribute.name == "input_tensor_types":
            dtype_attr = attribute

    assert rank_attr is not None, "input_tensor_ranks is None"
    assert dtype_attr is not None, "input_tensor_dtypes is None"

    ranks = rank_attr.ints
    dtypes = dtype_attr.ints
    for position, name in enumerate(node.input):
        if name != input_name:
            continue
        ranks[position] = new_rank
        dtypes[position] = new_dtype

    node.attribute.remove(rank_attr)
    node.attribute.remove(dtype_attr)
    node.attribute.extend(
        [
            helper.make_attribute("input_tensor_ranks", ranks),
            helper.make_attribute("input_tensor_types", dtypes),
        ]
    )
|
||||
|
|
@ -9,7 +9,7 @@ from onnxruntime.training.utils.torch_io_helper import (
|
|||
extract_data_and_schema,
|
||||
unflatten_data_using_schema,
|
||||
)
|
||||
from onnxruntime.training.utils.torch_type_map import pytorch_dtype_to_onnx
|
||||
from onnxruntime.training.utils.torch_type_map import onnx_dtype_to_pytorch, pytorch_dtype_to_onnx
|
||||
|
||||
__all__ = [
|
||||
"PrimitiveType",
|
||||
|
|
@ -18,4 +18,5 @@ __all__ = [
|
|||
"extract_data_and_schema",
|
||||
"unflatten_data_using_schema",
|
||||
"pytorch_dtype_to_onnx",
|
||||
"onnx_dtype_to_pytorch",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@
|
|||
import os
|
||||
import shutil
|
||||
import warnings
|
||||
from io import TextIOWrapper
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
|
|
@ -178,87 +179,97 @@ class StatisticsSubscriber(SubscriberBase):
|
|||
order_file_path = step_path / "order.txt"
|
||||
tensor_file_path = step_path / output_file_name
|
||||
|
||||
# This is to try the best effort to align the count of numbers per line for easier comparison in diff views,
|
||||
# though it does not always guarantee to do this way.
|
||||
torch.set_printoptions(precision=6, linewidth=128)
|
||||
|
||||
tensor_shape = tensor.shape
|
||||
tensor_dtype = tensor.dtype
|
||||
flatten_array = tensor.flatten().view(-1)
|
||||
|
||||
if self._run_on_cpu:
|
||||
flatten_array = flatten_array.to("cpu")
|
||||
|
||||
if self._run_on_cpu:
|
||||
num_nan = torch.isnan(flatten_array).sum()
|
||||
num_inf = torch.isinf(flatten_array).sum()
|
||||
num_neg = (flatten_array < 0).sum()
|
||||
num_pos = (flatten_array > 0).sum()
|
||||
num_zero = (flatten_array == 0).sum()
|
||||
min_value = flatten_array.min()
|
||||
max_value = flatten_array.max()
|
||||
mean_value = flatten_array.mean()
|
||||
std_value = flatten_array.std()
|
||||
else:
|
||||
# Split the calculation for each bucket, then do another round of calculation on the bucket results.
|
||||
# This can at the best effort reduce the peak memory impact.
|
||||
bucket_size = self._bucket_size
|
||||
element_count = flatten_array.numel()
|
||||
ceil_bucket_count = (element_count + bucket_size - 1) // (bucket_size)
|
||||
nan_buckets = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device)
|
||||
inf_buckets = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device)
|
||||
neg_buckets = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device)
|
||||
pos_buckets = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device)
|
||||
zero_buckets = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device)
|
||||
min_buckets = torch.zeros(ceil_bucket_count, dtype=flatten_array.dtype, device=flatten_array.device)
|
||||
max_buckets = torch.zeros(ceil_bucket_count, dtype=flatten_array.dtype, device=flatten_array.device)
|
||||
mean_buckets = torch.zeros(ceil_bucket_count, dtype=flatten_array.dtype, device=flatten_array.device)
|
||||
std_buckets = torch.zeros(ceil_bucket_count, dtype=flatten_array.dtype, device=flatten_array.device)
|
||||
|
||||
# Summary for each bucket
|
||||
element_count_per_bucket = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device)
|
||||
for i in range(ceil_bucket_count):
|
||||
end = min((i + 1) * bucket_size, element_count)
|
||||
bucket = flatten_array[i * bucket_size : end]
|
||||
element_count_per_bucket[i] = bucket.numel()
|
||||
|
||||
nan_buckets[i] = torch.isnan(bucket).sum()
|
||||
inf_buckets[i] = torch.isinf(bucket).sum()
|
||||
neg_buckets[i] = (bucket < 0).sum()
|
||||
pos_buckets[i] = (bucket > 0).sum()
|
||||
zero_buckets[i] = (bucket == 0).sum()
|
||||
min_buckets[i] = bucket.min()
|
||||
max_buckets[i] = bucket.max()
|
||||
mean_buckets[i] = bucket.sum()
|
||||
std_buckets[i] = bucket.std()
|
||||
|
||||
# Reduction across all buckets
|
||||
num_nan = nan_buckets.sum()
|
||||
num_inf = inf_buckets.sum()
|
||||
num_neg = neg_buckets.sum()
|
||||
num_pos = pos_buckets.sum()
|
||||
num_zero = zero_buckets.sum()
|
||||
min_value = min_buckets.min()
|
||||
max_value = max_buckets.max()
|
||||
mean_value = float(mean_buckets.sum()) / float(element_count)
|
||||
# Here we refer to
|
||||
# https://math.stackexchange.com/questions/2971315/how-do-i-combine-standard-deviations-of-two-groups
|
||||
# to calculate the combined standard deviation of all buckets.
|
||||
s = (element_count_per_bucket - 1) * (std_buckets**2) + element_count_per_bucket * (
|
||||
(mean_buckets - mean_value) ** 2
|
||||
)
|
||||
std_value = torch.sqrt(s.sum() / (element_count - 1))
|
||||
|
||||
with order_file_path.open(mode="a", encoding="utf-8") as f:
|
||||
f.write(f"{output_file_name}\n")
|
||||
|
||||
with tensor_file_path.open(mode="w", encoding="utf-8") as f:
|
||||
f.write(
|
||||
f"{'>'*max(0, depth) + display_name} shape: {tensor_shape} dtype: {tensor_dtype} size: {flatten_array.size()} \n"
|
||||
f"min: {min_value} max: {max_value}, mean: {mean_value}, "
|
||||
f"std: {std_value} \n"
|
||||
f"nan: {num_nan}, inf: {num_inf}\n"
|
||||
)
|
||||
f.write(f"samples(top 128): {flatten_array[:128]}\n")
|
||||
f.write(f"neg: {num_neg}, pos: {num_pos}, zero: {num_zero},\n")
|
||||
f.write(f"{'='*16}\n")
|
||||
_summarize_tensor(display_name, tensor, f, depth, self._run_on_cpu, self._bucket_size)
|
||||
|
||||
|
||||
def _summarize_tensor(
|
||||
display_name: str,
|
||||
tensor: torch.Tensor,
|
||||
f: TextIOWrapper,
|
||||
depth: int = 0,
|
||||
run_on_cpu: bool = False,
|
||||
bucket_size: int = 1024 * 1024 * 1024 // 2,
|
||||
):
|
||||
# This is to try the best effort to align the count of numbers per line for easier comparison in diff views,
|
||||
# though it does not always guarantee to do this way.
|
||||
torch.set_printoptions(precision=6, linewidth=128)
|
||||
|
||||
tensor_shape = tensor.shape
|
||||
tensor_dtype = tensor.dtype
|
||||
flatten_array = tensor.flatten().view(-1)
|
||||
|
||||
if run_on_cpu:
|
||||
flatten_array = flatten_array.to("cpu")
|
||||
|
||||
if run_on_cpu:
|
||||
num_nan = torch.isnan(flatten_array).sum()
|
||||
num_inf = torch.isinf(flatten_array).sum()
|
||||
num_neg = (flatten_array < 0).sum()
|
||||
num_pos = (flatten_array > 0).sum()
|
||||
num_zero = (flatten_array == 0).sum()
|
||||
min_value = flatten_array.min()
|
||||
max_value = flatten_array.max()
|
||||
mean_value = flatten_array.mean()
|
||||
std_value = flatten_array.std()
|
||||
else:
|
||||
# Split the calculation for each bucket, then do another round of calculation on the bucket results.
|
||||
# This can at the best effort reduce the peak memory impact.
|
||||
element_count = flatten_array.numel()
|
||||
ceil_bucket_count = (element_count + bucket_size - 1) // (bucket_size)
|
||||
nan_buckets = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device)
|
||||
inf_buckets = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device)
|
||||
neg_buckets = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device)
|
||||
pos_buckets = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device)
|
||||
zero_buckets = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device)
|
||||
min_buckets = torch.zeros(ceil_bucket_count, dtype=flatten_array.dtype, device=flatten_array.device)
|
||||
max_buckets = torch.zeros(ceil_bucket_count, dtype=flatten_array.dtype, device=flatten_array.device)
|
||||
mean_buckets = torch.zeros(ceil_bucket_count, dtype=flatten_array.dtype, device=flatten_array.device)
|
||||
std_buckets = torch.zeros(ceil_bucket_count, dtype=flatten_array.dtype, device=flatten_array.device)
|
||||
|
||||
# Summary for each bucket
|
||||
element_count_per_bucket = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device)
|
||||
for i in range(ceil_bucket_count):
|
||||
end = min((i + 1) * bucket_size, element_count)
|
||||
bucket = flatten_array[i * bucket_size : end]
|
||||
element_count_per_bucket[i] = bucket.numel()
|
||||
|
||||
nan_buckets[i] = torch.isnan(bucket).sum()
|
||||
inf_buckets[i] = torch.isinf(bucket).sum()
|
||||
neg_buckets[i] = (bucket < 0).sum()
|
||||
pos_buckets[i] = (bucket > 0).sum()
|
||||
zero_buckets[i] = (bucket == 0).sum()
|
||||
min_buckets[i] = bucket.min()
|
||||
max_buckets[i] = bucket.max()
|
||||
mean_buckets[i] = bucket.sum()
|
||||
std_buckets[i] = bucket.std()
|
||||
|
||||
# Reduction across all buckets
|
||||
num_nan = nan_buckets.sum()
|
||||
num_inf = inf_buckets.sum()
|
||||
num_neg = neg_buckets.sum()
|
||||
num_pos = pos_buckets.sum()
|
||||
num_zero = zero_buckets.sum()
|
||||
min_value = min_buckets.min()
|
||||
max_value = max_buckets.max()
|
||||
mean_value = float(mean_buckets.sum()) / float(element_count)
|
||||
# Here we refer to
|
||||
# https://math.stackexchange.com/questions/2971315/how-do-i-combine-standard-deviations-of-two-groups
|
||||
# to calculate the combined standard deviation of all buckets.
|
||||
s = (element_count_per_bucket - 1) * (std_buckets**2) + element_count_per_bucket * (
|
||||
(mean_buckets - mean_value) ** 2
|
||||
)
|
||||
std_value = torch.sqrt(s.sum() / (element_count - 1))
|
||||
|
||||
f.write(
|
||||
f"{'>'*max(0, depth) + display_name} shape: {tensor_shape} dtype: {tensor_dtype} size: {flatten_array.size()} \n"
|
||||
f"min: {min_value} max: {max_value}, mean: {mean_value}, "
|
||||
f"std: {std_value} \n"
|
||||
f"nan: {num_nan}, inf: {num_inf}\n"
|
||||
)
|
||||
f.write(f"samples(top 128): {flatten_array[:128]}\n")
|
||||
f.write(f"neg: {num_neg}, pos: {num_pos}, zero: {num_zero},\n")
|
||||
f.write(f"{'='*16}\n")
|
||||
|
|
|
|||
|
|
@ -29,14 +29,6 @@ def no_increase_global_step():
|
|||
finally:
|
||||
ORT_NO_INCREASE_GLOBAL_STEP[0] = False
|
||||
|
||||
@staticmethod
|
||||
def infer_shape(
|
||||
node: onnx.NodeProto,
|
||||
tensor_input_shapes: List[Optional[List[Union[int, str]]]],
|
||||
tensor_input_dtypes: List[torch.onnx.TensorProtoDataType],
|
||||
) -> Tuple[List[Optional[List[Union[int, str]]]], List[torch.onnx.TensorProtoDataType]]:
|
||||
return tensor_input_shapes, tensor_input_dtypes
|
||||
|
||||
|
||||
class _IncrementStep(torch.autograd.Function):
|
||||
"""This class is used to manage the global execution step, e.g.
|
||||
|
|
@ -55,8 +47,9 @@ class _IncrementStep(torch.autograd.Function):
|
|||
ctx.current_step = run_ctx.global_states.execution_step
|
||||
ctx.run_ctx = run_ctx
|
||||
|
||||
if ctx.current_step >= 0:
|
||||
print(f"{'='*6} Completed forward pass for STEP {ctx.current_step} {'='*6}")
|
||||
# Uncomment the following line for debugging purposes.
|
||||
# if ctx.current_step >= 0:
|
||||
# print(f"{'='*6} Completed forward pass for STEP {ctx.current_step} {'='*6}")
|
||||
|
||||
if ORT_NO_INCREASE_GLOBAL_STEP[0] is False:
|
||||
ctx.run_ctx.global_states.execution_step += 1
|
||||
|
|
@ -191,7 +184,7 @@ class SubscriberManager:
|
|||
next_module_index: list of int, carrying a global unique module index that can be used next.
|
||||
"""
|
||||
module_index = next_module_index[0]
|
||||
module.id = module_index # STAGE3WARN: needed by DeepSpeed
|
||||
module.id = module_index # STAGE3WARN#1: needed by DeepSpeed
|
||||
self._run_ctx.global_states.module_index_to_depth[module_index] = depth
|
||||
self._run_ctx.global_states.module_to_module_index[module] = module_index
|
||||
|
||||
|
|
@ -217,7 +210,7 @@ class SubscriberManager:
|
|||
next_module_index: list of int, carrying a global unique module index that can be used next.
|
||||
"""
|
||||
module_index = next_module_index[0]
|
||||
module.id = module_index # STAGE3WARN: needed by DeepSpeed
|
||||
module.id = module_index # STAGE3WARN#2: needed by DeepSpeed
|
||||
self._run_ctx.global_states.module_index_to_depth[module_index] = depth
|
||||
self._run_ctx.global_states.module_to_module_index[module] = module_index
|
||||
|
||||
|
|
|
|||
|
|
@ -23,25 +23,37 @@ from onnxruntime.training.utils import (
|
|||
from ._subscriber_base import RuntimeStates, SubscriberBase
|
||||
|
||||
|
||||
# Used to monkey patch the original function
|
||||
# Adapted from https://github.com/microsoft/DeepSpeed/blob/e8318634b4313eaad89842cf4322e1762d34ced3/deepspeed/runtime/zero/parameter_offload.py#L333
|
||||
def _setup_zero_stage3_ort_compatible_hooks(self):
|
||||
self.hierarchy = 0
|
||||
def _get_ort_compatible_zero_stage3_hook_function(debug, stats_output_dir, stats_overwrite):
|
||||
"""Create ort compatible hook function for DeepSpeed ZeRO stage3.
|
||||
|
||||
from onnxruntime.training.utils.hooks import SubscriberManager, ZeROOffloadSubscriber
|
||||
from onnxruntime.training.utils.hooks._zero_offload_subscriber import _zero_offload_one_time_initializer
|
||||
Args:
|
||||
debug: whether to enable convergence debugging.
|
||||
stats_output_dir: the directory to store convergence stats.
|
||||
stats_overwrite: whether to overwrite the stats file if it already exists.
|
||||
"""
|
||||
|
||||
# Each DeepSpeed engine has a separate subscriber manager.
|
||||
self._offload_subscriber_manager = SubscriberManager()
|
||||
self._offload_subscriber_manager.subscribe(
|
||||
self.module, [ZeROOffloadSubscriber(self, _zero_offload_one_time_initializer)]
|
||||
)
|
||||
self.forward_hooks.extend(self._offload_subscriber_manager._pre_forward_hooks)
|
||||
self.forward_hooks.extend(self._offload_subscriber_manager._post_forward_hooks)
|
||||
# Used to monkey patch the original function
|
||||
# Adapted from https://github.com/microsoft/DeepSpeed/blob/e8318634b4313eaad89842cf4322e1762d34ced3/deepspeed/runtime/zero/parameter_offload.py#L333
|
||||
def _setup_zero_stage3_ort_compatible_hooks(self):
|
||||
self.hierarchy = 0
|
||||
|
||||
# Add top module to stack trace
|
||||
global FWD_MODULE_STACK # noqa: PLW0602
|
||||
FWD_MODULE_STACK.append(self.module)
|
||||
from onnxruntime.training.utils.hooks import StatisticsSubscriber, SubscriberManager, ZeROOffloadSubscriber
|
||||
from onnxruntime.training.utils.hooks._zero_offload_subscriber import _zero_offload_one_time_initializer
|
||||
|
||||
subscribers = [ZeROOffloadSubscriber(self, _zero_offload_one_time_initializer)]
|
||||
if debug is True:
|
||||
subscribers.append(StatisticsSubscriber(output_dir=stats_output_dir, override_output_dir=stats_overwrite))
|
||||
# Each DeepSpeed engine has a separate subscriber manager.
|
||||
self._offload_subscriber_manager = SubscriberManager()
|
||||
self._offload_subscriber_manager.subscribe(self.module, subscribers)
|
||||
self.forward_hooks.extend(self._offload_subscriber_manager._pre_forward_hooks)
|
||||
self.forward_hooks.extend(self._offload_subscriber_manager._post_forward_hooks)
|
||||
|
||||
# Add top module to stack trace
|
||||
global FWD_MODULE_STACK # noqa: PLW0602
|
||||
FWD_MODULE_STACK.append(self.module)
|
||||
|
||||
return _setup_zero_stage3_ort_compatible_hooks
|
||||
|
||||
|
||||
# Adapted from https://github.com/microsoft/DeepSpeed/blob/e8318634b4313eaad89842cf4322e1762d34ced3/deepspeed/runtime/zero/linear.py#L104
|
||||
|
|
@ -86,14 +98,16 @@ try:
|
|||
_zero_offload_one_time_initializer.collect_code(DeepSpeedZeRoOffload.setup_zero_stage3_hooks)
|
||||
|
||||
# This is the function to enable ORT ZeRO offload.
|
||||
def configure_ort_compatible_zero_stage3():
|
||||
def configure_ort_compatible_zero_stage3(debug=False, stats_output_dir="./", stats_overwrite=False):
|
||||
"""Configure ZeRO stage3 to be ORT compatible.
|
||||
|
||||
This function will overwrite the original DeepSpeed ZeRO stage3 hooks to make it ORT compatible.
|
||||
"""
|
||||
|
||||
# Only done once no matter how many times this function is called for different modules.
|
||||
DeepSpeedZeRoOffload.setup_zero_stage3_hooks = _setup_zero_stage3_ort_compatible_hooks
|
||||
DeepSpeedZeRoOffload.setup_zero_stage3_hooks = _get_ort_compatible_zero_stage3_hook_function(
|
||||
debug, stats_output_dir, stats_overwrite
|
||||
)
|
||||
|
||||
from deepspeed.runtime.zero.linear import zero3_linear_wrap
|
||||
|
||||
|
|
@ -103,7 +117,7 @@ try:
|
|||
except ImportError as e:
|
||||
warnings.warn(f"DeepSpeed import error {e}")
|
||||
|
||||
def configure_ort_compatible_zero_stage3():
|
||||
def configure_ort_compatible_zero_stage3(debug=False, stats_output_dir="./", stats_overwrite=False):
    """Fallback used when DeepSpeed failed to import: always raises.

    The parameters and their defaults mirror the DeepSpeed-backed implementation so that both variants
    expose the same signature; none of them is used here.

    Raises:
        RuntimeError: always, because ZeRO stage3 cannot be configured without DeepSpeed.
    """
    raise RuntimeError("DeepSpeed is not installed, cannot configure ORT compatible ZeRO stage3.")
|
||||
|
||||
|
||||
|
|
@ -115,13 +129,13 @@ def _get_params_for_current_module(module: torch.nn.Module) -> List[torch.nn.par
|
|||
"""
|
||||
from deepspeed.runtime.zero.partitioned_param_coordinator import iter_params
|
||||
|
||||
# Retrive the parameters that are not available for this module.
|
||||
# Retrieve all parameters for this module.
|
||||
partitioned_params = [param for param in iter_params(module)]
|
||||
|
||||
return partitioned_params
|
||||
|
||||
|
||||
def _get_all_offloaded_params(module: torch.nn.Module) -> Dict[str, torch.nn.parameter.Parameter]:
|
||||
def _get_all_zero_stage3_params(module: torch.nn.Module) -> Dict[str, torch.nn.parameter.Parameter]:
|
||||
"""Retrieve all the parameters that are offloaded."""
|
||||
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
|
||||
|
||||
|
|
@ -134,16 +148,13 @@ def _get_all_offloaded_params(module: torch.nn.Module) -> Dict[str, torch.nn.par
|
|||
|
||||
|
||||
class ORTZeROOffloadPreForwardFunction(torch.autograd.Function):
|
||||
"""This function is a common bridge to call original PyTorch's
|
||||
pre_forward_function and post_backward_function.
|
||||
"""
|
||||
"""This function is a common bridge to call original PyTorch's pre_forward_function"""
|
||||
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx,
|
||||
module,
|
||||
pre_forward_with_kwargs_function,
|
||||
post_backward_function,
|
||||
args_schema,
|
||||
kwargs_schema,
|
||||
args_tensor_count,
|
||||
|
|
@ -155,7 +166,6 @@ class ORTZeROOffloadPreForwardFunction(torch.autograd.Function):
|
|||
ctx: context object
|
||||
module: the module to be called
|
||||
pre_forward_with_kwargs_function: the function to be called before forward (PyTorch's pre_forward_function)
|
||||
post_backward_function: the function to be called after backward (PyTorch's post_backward_function)
|
||||
args_schema: the schema of the args, used to reconstruct the args in original form in
|
||||
PyTorch's pre_forward_function's inputs.
|
||||
kwargs_schema: the schema of the kwargs, used to reconstruct the kwargs in original form in
|
||||
|
|
@ -168,6 +178,17 @@ class ORTZeROOffloadPreForwardFunction(torch.autograd.Function):
|
|||
args_tensors = tensor_list[:args_tensor_count]
|
||||
kwargs_tensors = tensor_list[args_tensor_count : args_tensor_count + kwargs_tensor_count]
|
||||
|
||||
# For PyTorch runs, the sizes are all 0, it does not need a gradient because
|
||||
# param._detach().requires_grad_(False) is called.
|
||||
# But for ORT runs, the sizes are all [1], as output of weight retrieval function.
|
||||
# So we keep track of the shapes and dtypes of the passed-in tensors, then generate the grads in backward.
|
||||
# While for both PyTorch and ORT runs, the grad is not important because they are not param grads
|
||||
# anymore, they are only used for completing the full backward propagation.
|
||||
passed_in_param_tensors = tensor_list[args_tensor_count + kwargs_tensor_count :]
|
||||
ctx.shapes = [p.shape for p in passed_in_param_tensors]
|
||||
ctx.dtypes = [p.dtype for p in passed_in_param_tensors]
|
||||
ctx.devices = [p.device for p in passed_in_param_tensors]
|
||||
|
||||
args = unflatten_data_using_schema(args_tensors, args_schema)
|
||||
kwargs = unflatten_data_using_schema(kwargs_tensors, kwargs_schema)
|
||||
|
||||
|
|
@ -179,6 +200,8 @@ class ORTZeROOffloadPreForwardFunction(torch.autograd.Function):
|
|||
partitioned_params = _get_params_for_current_module(module)
|
||||
ctx.partitioned_params = partitioned_params
|
||||
|
||||
assert len(partitioned_params) == len(passed_in_param_tensors)
|
||||
|
||||
f_ret = pre_forward_with_kwargs_function(module, args, kwargs)
|
||||
|
||||
if f_ret is None:
|
||||
|
|
@ -188,7 +211,6 @@ class ORTZeROOffloadPreForwardFunction(torch.autograd.Function):
|
|||
updated_args, updated_kwargs = f_ret
|
||||
|
||||
ctx.module = module
|
||||
ctx.post_backward_function = post_backward_function
|
||||
|
||||
updated_args_tensors, _ = extract_data_and_schema(updated_args)
|
||||
updated_kwargs_tensors, _ = extract_data_and_schema(updated_kwargs)
|
||||
|
|
@ -203,17 +225,32 @@ class ORTZeROOffloadPreForwardFunction(torch.autograd.Function):
|
|||
@staticmethod
|
||||
def backward(ctx, *grads):
|
||||
updated_grads = grads
|
||||
if ctx.post_backward_function is not None:
|
||||
ret = ctx.post_backward_function(ctx.module, grads)
|
||||
if ret is not None:
|
||||
updated_grads = ret
|
||||
|
||||
# TODO(pengwa) Update grad for partitioned parameters.
|
||||
input_count = len(updated_grads) - len(ctx.partitioned_params)
|
||||
zeros = [torch.zeros(0, dtype=p.dtype, device=p.device) for p in ctx.partitioned_params]
|
||||
zero_grads = updated_grads[:input_count] + tuple(zeros)
|
||||
param_start_offset = input_count
|
||||
|
||||
return (None, None, None, None, None, None, None, *zero_grads)
|
||||
# Only need to accumulate grad explicitly for ORT run (e.g. ctx.shapes[0] == (1,));
|
||||
# In the PyTorch run, the accumulation happens automatically.
|
||||
need_manual_grad_acc = len(ctx.shapes) > 0 and ctx.shapes[0] == (1,)
|
||||
if need_manual_grad_acc:
|
||||
for param_index, p in enumerate(ctx.partitioned_params):
|
||||
g = updated_grads[param_index + param_start_offset]
|
||||
if g is None:
|
||||
raise RuntimeError(f"param {p} has no grad, this should not happen.")
|
||||
# Param gradient accumulation is triggered here, along with the attached hooks, done by PyTorch.
|
||||
assert p.shape == g.shape, f"param_index: {param_index} - param shape {p.shape} != grad shape {g.shape}"
|
||||
p.backward(g)
|
||||
|
||||
# At this point, the **real** param grads are already updated, the following grads are only used for
|
||||
# completing the full backward propagation, will not affect parameter updates.
|
||||
passed_in_param_grad = [
|
||||
torch.zeros(shape, dtype=dtype, device=device)
|
||||
for shape, dtype, device in zip(ctx.shapes, ctx.dtypes, ctx.devices)
|
||||
]
|
||||
|
||||
zero_grads = updated_grads[:input_count] + tuple(passed_in_param_grad)
|
||||
|
||||
return (None, None, None, None, None, None, *zero_grads)
|
||||
|
||||
@staticmethod
|
||||
def infer_shape(
|
||||
|
|
@ -258,14 +295,14 @@ class ORTZeROOffloadPostForwardFunction(torch.autograd.Function):
|
|||
module: the module to be called
|
||||
post_forward_function: the function to be called after forward (PyTorch's post_forward_function)
|
||||
pre_backward_function: the function to be called before backward (PyTorch's pre_backward_function)
|
||||
output_schema: the schema of the output, used to reconstruct the output in original form in
|
||||
output_schema: the schema of the output, used to reconstruct the output in its original form in
|
||||
PyTorch's post_forward_function's inputs.
|
||||
output_tensors: the list of tensors.
|
||||
|
||||
"""
|
||||
outputs = unflatten_data_using_schema(output_tensors, output_schema)
|
||||
|
||||
# STAGE3WARN: _post_forward_module_hook's second argument `input is not used, so we just pass a None here.
|
||||
# STAGE3WARN#3: _post_forward_module_hook's second argument `input` is not used, so we just pass a None here.
|
||||
updated_outputs = post_forward_function(module, None, outputs)
|
||||
|
||||
if updated_outputs is None:
|
||||
|
|
@ -341,12 +378,20 @@ class ZeROOffloadSubscriber(SubscriberBase):
|
|||
input and output for torch.autograd.Function, so we do flatten and unflatten here.
|
||||
|
||||
"""
|
||||
## Handle `_post_backward_module_hook`
|
||||
|
||||
args_tensors, args_schema = extract_data_and_schema(args)
|
||||
# Put `_post_backward_module_hook` first because in backward, it is responsible for unloading parameters,
|
||||
# and we want ORTZeROOffloadPreForwardFunction's backward to still be able to access the full-sized parameters.
|
||||
_post_backward_module_hook = self._functions.get("_post_backward_module_hook")
|
||||
# STAGE3WARN#4: most logic in _post_backward_module_hook can be traced correctly so we don't need to
|
||||
# wrap with PythonOp. For those cannot be traced, we handle them in STAGE3WARN#5.
|
||||
updated_args = _post_backward_module_hook(module, args)
|
||||
|
||||
## Handle `_pre_forward_module_hook`
|
||||
|
||||
args_tensors, args_schema = extract_data_and_schema(updated_args)
|
||||
kwargs_tensors, kwargs_schema = extract_data_and_schema(kwargs)
|
||||
|
||||
partitioned_params = _get_params_for_current_module(module)
|
||||
|
||||
_pre_forward_module_hook = self._functions.get("_pre_forward_module_hook")
|
||||
|
||||
args_tensor_count = len(args_tensors)
|
||||
|
|
@ -358,18 +403,29 @@ class ZeROOffloadSubscriber(SubscriberBase):
|
|||
if rets is not None:
|
||||
updated_args = rets
|
||||
|
||||
# STAGE3WARN: Moved from _post_backward_module_hook to make sure ORT run will trigger every iteration.
|
||||
# STAGE3WARN#5: Moved from _post_backward_module_hook to make sure ORT run will trigger every iteration.
|
||||
module.ds_grads_remaining = 0
|
||||
|
||||
return updated_args, updated_kwargs
|
||||
|
||||
all_tensors = args_tensors + kwargs_tensors + partitioned_params
|
||||
# Need to pass the parameters as input to let the exporter trace the related weights for
|
||||
# current ORTZeROOffloadPreForwardFunction
|
||||
partitioned_params = _get_params_for_current_module(module)
|
||||
# Don't require grad for passed-in parameter, otherwise it will be treated as a leaf node, in backward
|
||||
# returned 0-sized grad did not match the param's gradient accumulator function's input shape metadata,
|
||||
# PyTorch run will fail during backward.
|
||||
# This will not harm parameter gradient build either in ORT or PyTorch, imagine the weights are used by
|
||||
# computation anyway, so the gradient will be built. This hook only references the parameter, but won't
|
||||
# generate a gradient path for it.
|
||||
detached_partitioned_params = [p.detach().requires_grad_(False) for p in partitioned_params]
|
||||
|
||||
all_tensors = args_tensors + kwargs_tensors + detached_partitioned_params
|
||||
|
||||
self._check_all_tensor(all_tensors, module, "pre_forward_module_apply_impl input check")
|
||||
|
||||
rets = ORTZeROOffloadPreForwardFunction.apply(
|
||||
module,
|
||||
_wrap_pre_forward_module_hook,
|
||||
None,
|
||||
args_schema,
|
||||
kwargs_schema,
|
||||
args_tensor_count,
|
||||
|
|
@ -385,11 +441,6 @@ class ZeROOffloadSubscriber(SubscriberBase):
|
|||
updated_args = unflatten_data_using_schema(updated_args_tensors, args_schema)
|
||||
updated_kwargs = unflatten_data_using_schema(updated_kwargs_tensors, kwargs_schema)
|
||||
|
||||
_post_backward_module_hook = self._functions.get("_post_backward_module_hook")
|
||||
# STAGE3WARN: Other part of _post_backward_module_hook can be traced correctly so we don't need to
|
||||
# wrap with PythonOp.
|
||||
updated_args = _post_backward_module_hook(module, updated_args)
|
||||
|
||||
return updated_args, updated_kwargs
|
||||
|
||||
def post_forward_module_apply_impl(
|
||||
|
|
@ -411,7 +462,7 @@ class ZeROOffloadSubscriber(SubscriberBase):
|
|||
_post_forward_module_hook = self._functions.get("_post_forward_module_hook")
|
||||
|
||||
def _wrap_post_forward_module_hook(module, input, outputs):
|
||||
# STAGE3WARN: _post_forward_module_hook applied this for each tensor output, so we do a simple wrap here.
|
||||
# STAGE3WARN#6: _post_forward_module_hook applied this for each tensor output, so we do a simple wrap here.
|
||||
from deepspeed.runtime.zero.partition_parameters import is_zero_param
|
||||
|
||||
updated_outputs = _post_forward_module_hook(module, input, outputs)
|
||||
|
|
@ -438,8 +489,8 @@ class ZeROOffloadSubscriber(SubscriberBase):
|
|||
updated_outputs = unflatten_data_using_schema(updated_outputs_tensors, outputs_schema)
|
||||
|
||||
_pre_backward_module_hook = self._functions.get("_pre_backward_module_hook")
|
||||
# STAGE3WARN: _pre_backward_module_hook's second argument `input is not used, so we just pass a None here.
|
||||
# STAGE3WARN: part of the original _pre_backward_module_hook can be traced correctly so we moved them into
|
||||
# STAGE3WARN#7: _pre_backward_module_hook's second argument `input` is not used, so we just pass a None here.
|
||||
# STAGE3WARN#8: part of the original _pre_backward_module_hook can be traced correctly so we moved them into
|
||||
# _wrap_post_forward_module_hook above.
|
||||
updated_outputs = _pre_backward_module_hook(module, None, updated_outputs)
|
||||
|
||||
|
|
|
|||
|
|
@ -33,6 +33,8 @@ _CAST_PYTORCH_TO_ONNX = {
|
|||
|
||||
_DTYPE_TO_ONNX = {torch_dtype: onnx_dtype for k, (onnx_dtype, torch_dtype) in _CAST_PYTORCH_TO_ONNX.items()}
|
||||
|
||||
_ONNX_TO_DTYPE = {onnx_dtype: torch_dtype for torch_dtype, onnx_dtype in _DTYPE_TO_ONNX.items()}
|
||||
|
||||
|
||||
def pytorch_dtype_to_onnx(dtype_or_scalar_type: Union[torch.dtype, str]) -> torch.onnx.TensorProtoDataType:
|
||||
"""Converts a pytorch dtype or scalar type string to an onnx dtype."""
|
||||
|
|
@ -45,3 +47,10 @@ def pytorch_dtype_to_onnx(dtype_or_scalar_type: Union[torch.dtype, str]) -> torc
|
|||
if dtype not in _DTYPE_TO_ONNX:
|
||||
raise RuntimeError(f"Unsupported dtype {dtype}")
|
||||
return _DTYPE_TO_ONNX[dtype]
|
||||
|
||||
|
||||
def onnx_dtype_to_pytorch(dtype: torch.onnx.TensorProtoDataType) -> torch.dtype:
    """Map an ONNX tensor element type to the corresponding PyTorch dtype.

    Raises:
        RuntimeError: if ``dtype`` has no entry in the ONNX-to-PyTorch lookup table.
    """
    if dtype in _ONNX_TO_DTYPE:
        return _ONNX_TO_DTYPE[dtype]

    raise RuntimeError(f"Unsupported dtype {dtype}")
|
||||
|
|
|
|||
|
|
@ -153,8 +153,11 @@ void PythonOpBase::RunForward(OpKernelContext* context,
|
|||
inplace_ != 0,
|
||||
kernel_invoke_id_);
|
||||
|
||||
ORT_ENFORCE(1 + returned_ortvalues.size() == static_cast<size_t>(context->OutputCount()),
|
||||
"Output count mismatch for PythonOp run");
|
||||
const size_t returned_output_count = 1 + returned_ortvalues.size();
|
||||
const size_t kernel_output_count = static_cast<size_t>(context->OutputCount());
|
||||
ORT_ENFORCE(returned_output_count == kernel_output_count, "Output count mismatch for PythonOp run, ",
|
||||
"returned_output_count: ", returned_output_count, ", expected kernel_output_count: ",
|
||||
kernel_output_count);
|
||||
}
|
||||
|
||||
void PythonOpBase::SetOutputs(OpKernelContext* context, void* diff_ctx, std::vector<OrtValue>& returned_args) const {
|
||||
|
|
|
|||
Loading…
Reference in a new issue