diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py index 7ea015030c..f1ae93cfc1 100755 --- a/onnxruntime/python/tools/symbolic_shape_infer.py +++ b/onnxruntime/python/tools/symbolic_shape_infer.py @@ -2385,31 +2385,35 @@ class SymbolicShapeInference: output_tensor_ranks = get_attribute(node, "output_tensor_ranks") assert output_tensor_ranks - # set the context output separately. - # The first output is autograd's context. + from onnxruntime.training.ortmodule._custom_autograd_function_exporter import PythonOpShapeInferStore + + func_name = get_attribute(node, "func_name").decode() + shape_inferer = PythonOpShapeInferStore.get_shape_infer(func_name) + + # Set the context output separately. + # The first output is torch.autograd.Function''s context. vi = self.known_vi_[node.output[0]] vi.CopyFrom(helper.make_tensor_value_info(node.output[0], onnx.TensorProto.INT64, [])) - # TODO(pengwa): allow custom PythonOp shape inference. - if get_attribute(node, "func_name").decode() in [ - "onnxruntime.training.utils.hooks._subscriber_manager._InspectActivation", - "onnxruntime.training.utils.hooks._subscriber_manager._IncrementStep", - ]: - # PythonOp with func_name being "_InspectActivation" or "_IncrementStep" will behave exactly same as a normal - # PythonOp when execution. The only difference is that - # 1). those ops having same number of tensor inputs and tensor outputs; - # 2). and the i-th output tensor's shape is same as i-th input tensor's shape. - # Be noted, the count of custom autograd function might be bigger than output count, because there might - # be other non-tensor constant inputs (string, object, int, tuple, etc). But we did not make those constant - # inputs as ONNX op's input, instead they are stored in the attributes. - assert len(node.output) == len(node.input) + 1 # The output contains one extra context info. - for input_index in range(len(node.output) - 1): - # Process the i-th tensor outputs. - vi = self.known_vi_[node.output[input_index + 1]] + + if shape_inferer is not None: + input_shapes = [] + input_dtypes = [] + for input_index in range(len(node.input)): shape = self._get_shape(node, input_index) - output_dtype = self.known_vi_[node.input[input_index]].type.tensor_type.elem_type - vi.CopyFrom(helper.make_tensor_value_info(node.output[input_index + 1], output_dtype, shape)) + input_shapes.append(shape) + input_dtype = self.known_vi_[node.input[input_index]].type.tensor_type.elem_type + input_dtypes.append(input_dtype) + output_shapes, output_dtypes = shape_inferer(node, input_shapes, input_dtypes) + assert len(output_shapes) == len(output_dtypes) == (len(node.output) - 1) + for i in range(len(node.output) - 1): + output_index = i + 1 + vi = self.known_vi_[node.output[output_index]] + vi.CopyFrom( + helper.make_tensor_value_info(node.output[output_index], output_dtypes[i], output_shapes[i]) + ) else: - # Outputs after autograd's context are tensors. + # General shape inference for PythonOp. + # Outputs after torch.autograd.Function's context are tensors. # We assume their ranks are fixed for different model inputs. for i in range(len(node.output) - 1): # Process the i-th tensor outputs. diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py index 78ac2ffcc5..2d05c04a7b 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py @@ -4,6 +4,7 @@ # -------------------------------------------------------------------------- import sys +from typing import Callable, ClassVar, Dict, Optional import onnx import torch @@ -18,8 +19,44 @@ from ._custom_op_symbolic_registry import pytorch_type_to_onnx, wrap_custom_expo from ._fallback import ORTModuleONNXModelException, wrap_exception from ._utils import get_fully_qualified_class_name, get_runtime_pytorch_version + +class PythonOpShapeInferStore: + """A class to store shape inference functions for torch.autograd.Function.""" + + _CLASS_MAP: ClassVar[Dict[str, Callable]] = {} + + @classmethod + def register(cls, kclass: torch.autograd.Function) -> None: + """Register a shape inference function for a torch.autograd.Function if there is staticmethod "infer_shape" defined. + + The signature of the shape inference function should be: + @staticmethod + def infer_shape( + node: onnx.NodeProto, + tensor_input_shapes: List[Optional[List[Union[int, str]]]], + tensor_input_dtypes: List[torch.onnx.TensorProtoDataType], + ) -> Tuple[List[Optional[List[Union[int, str]]]], List[torch.onnx.TensorProtoDataType]]: + tensor_output_shapes = [] + tensor_output_dtypes = [] + ... + return tensor_output_shapes, tensor_output_dtypes + + The tensor_input_shapes and tensor_input_dtypes are lists of shapes and dtypes of the input tensors. + The tensor_output_shapes and tensor_output_dtypes are lists of shapes and dtypes of the output tensors. + Be noted: we only pass in tensor inputs, and return tensor outputs, non-tensor inputs/outputs are ignored. + + """ + kclass_name = get_fully_qualified_class_name(kclass) + if hasattr(kclass, "infer_shape") and kclass_name not in cls._CLASS_MAP: + cls._CLASS_MAP[kclass_name] = kclass.infer_shape + + @classmethod + def get_shape_infer(cls, name: str) -> Optional[Callable]: + return cls._CLASS_MAP.get(name, None) + + """ -Defines a list of names of torch.torch.autograd.Function, for checkpoint activation purposes. +Defines a list of names of torch.autograd.Function, for checkpoint activation purposes. Note: If CheckpointFunction is exported as PythonOp, the checkpoint-ed computation @@ -220,6 +257,9 @@ def _export_pt_1_10(g, n, *args, **kwargs): # Register function with class names. register_torch_autograd_function(func_full_qual_name, func_class) + + # Register shape inference function. + PythonOpShapeInferStore.register(func_class) return returned_args except Exception as e: sys.stdout.flush() @@ -235,7 +275,7 @@ def _post_process_after_export( ) -> onnx.ModelProto: """Post process the exported model.""" if enable_custom_autograd_function: - return _post_process_enabling_autograd_function(exported_model) + exported_model = _post_process_enabling_autograd_function(exported_model) return exported_model diff --git a/orttraining/orttraining/python/training/ortmodule/options.py b/orttraining/orttraining/python/training/ortmodule/options.py index 5ba3769da4..f8d3a6e779 100644 --- a/orttraining/orttraining/python/training/ortmodule/options.py +++ b/orttraining/orttraining/python/training/ortmodule/options.py @@ -158,6 +158,8 @@ class DebugOptions: return [ # [W shape_type_inference.cpp:1974] Warning: The shape inference of com.microsoft::PythonOp type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. (function UpdateReliable) "type is missing, so it may result in wrong shape inference", + # diagnostics [WARNING] - None + "[WARNING] - None", ] return None diff --git a/orttraining/orttraining/python/training/utils/hooks/_subscriber_manager.py b/orttraining/orttraining/python/training/utils/hooks/_subscriber_manager.py index 82e560e1fd..72208b7228 100644 --- a/orttraining/orttraining/python/training/utils/hooks/_subscriber_manager.py +++ b/orttraining/orttraining/python/training/utils/hooks/_subscriber_manager.py @@ -4,8 +4,9 @@ # -------------------------------------------------------------------------- from collections import abc -from typing import Callable, List, Optional, Union +from typing import Callable, List, Optional, Tuple, Union +import onnx import torch from onnxruntime.training.ortmodule import ORTModule @@ -111,6 +112,14 @@ class _InspectActivation(torch.autograd.Function): return None, None, None, grad_output.detach() if grad_output is not None else None + @staticmethod + def infer_shape( + node: onnx.NodeProto, + tensor_input_shapes: List[Optional[List[Union[int, str]]]], + tensor_input_dtypes: List[torch.onnx.TensorProtoDataType], + ) -> Tuple[List[Optional[List[Union[int, str]]]], List[torch.onnx.TensorProtoDataType]]: + return tensor_input_shapes, tensor_input_dtypes + class _IncrementStep(torch.autograd.Function): """ @@ -163,6 +172,14 @@ class _IncrementStep(torch.autograd.Function): return None, grad_output.detach() if isinstance(grad_output, torch.Tensor) else grad_output + @staticmethod + def infer_shape( + node: onnx.NodeProto, + tensor_input_shapes: List[Optional[List[Union[int, str]]]], + tensor_input_dtypes: List[torch.onnx.TensorProtoDataType], + ) -> Tuple[List[Optional[List[Union[int, str]]]], List[torch.onnx.TensorProtoDataType]]: + return tensor_input_shapes, tensor_input_dtypes + class SubscriberManager: """ diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py index 12c1cda772..5f33854836 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py @@ -5,6 +5,7 @@ # pylint: disable=C0103 # pylint: disable=W0212 +import onnx import pytest import torch @@ -1380,3 +1381,110 @@ def test_duplicate_named_functions(): assert triggered[0] assert triggered[1] + + +def test_customized_shape_inference(): + def _check_pythonop_shape(model): + graph = model._torch_module._execution_manager._training_manager._onnx_models.optimized_model.graph + found_pythonop = False + python_op_input = [] + python_op_output = [] + for node in graph.node: + if node.op_type == "PythonOp": + found_pythonop = True + python_op_input = node.input + python_op_output = node.output + break + + assert found_pythonop, "PythonOp should be found in the optimized_model" + + input_shapes = [None] + input_dtypes = [None] + + output_shapes = [None, None] + output_dtypes = [None, None] + + def _find_shape_and_dtype(value_infos): + for value_info in value_infos: + if value_info.name == python_op_input[0]: + input_shapes[0] = value_info.type.tensor_type.shape + input_dtypes[0] = value_info.type.tensor_type.elem_type + + if value_info.name == python_op_output[0]: + output_shapes[0] = value_info.type.tensor_type.shape + output_dtypes[0] = value_info.type.tensor_type.elem_type + + if value_info.name == python_op_output[1]: + output_shapes[1] = value_info.type.tensor_type.shape + output_dtypes[1] = value_info.type.tensor_type.elem_type + + _find_shape_and_dtype(graph.input) + _find_shape_and_dtype(graph.value_info) + + assert all(s is not None for s in input_shapes), "PythonOp input shape should be found in the optimized_model" + assert ( + all(d is not None for d in input_dtypes) is not None + ), "PythonOp input dtype should be found in the optimized_model" + + assert all(s is not None for s in output_shapes), "PythonOp output shape should be found in the optimized_model" + assert ( + all(d is not None for d in output_dtypes) is not None + ), "PythonOp output dtype should be found in the optimized_model" + + def _compare_shape(shape1, shape2): + if len(shape1.dim) != len(shape2.dim): + return False + + for dim1, dim2 in zip(shape1.dim, shape2.dim): + if dim1.HasField("dim_value") and dim1.HasField("dim_value") and dim1.dim_value == dim2.dim_value: + continue + + if dim1.HasField("dim_param") and dim1.HasField("dim_param") and dim1.dim_param == dim2.dim_param: + continue + + return False + + return True + + assert output_dtypes[0] == onnx.TensorProto.INT64 + assert len(output_shapes[0].dim) == 0 + assert _compare_shape(input_shapes[0], output_shapes[1]) + assert input_dtypes[0] == output_dtypes[1] + + class CustomShapeInferFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))) + + @staticmethod + def backward(ctx, grad_output): + x = ctx.saved_tensors + tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) + ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out) + return ff * grad_output + + @staticmethod + def infer_shape(node: onnx.NodeProto, tensor_input_shapes, tensor_input_dtypes): + return [tensor_input_shapes[0]], [tensor_input_dtypes[0]] + + class TestModel(torch.nn.Module): + def __init__(self, output_size): + super().__init__() + self.custom_fn = CustomShapeInferFunction.apply + self.bias = Parameter(torch.empty(output_size, dtype=torch.float)) + + with torch.no_grad(): + self.bias.uniform_() + + def forward(self, model_input): + # model_input did not require_grad + out = self.custom_fn(model_input) + return out + self.bias + + output_size = 1024 + ortmodule = ORTModule( + TestModel(output_size), + ).train() + _ = ortmodule(torch.randn(output_size, dtype=torch.float)) + _check_pythonop_shape(ortmodule)