Allow defining customized PythonOp shape inferer (#17093)

### Allow defining customized PythonOp shape inferer

For `torch.autograd.Function`, we converted it to PythonOp in MSDomain,
there are two places to do shape inferencing for it:

1. in SymbolicShapeInfer, there is one. 
2. in PythonOp op definition. 

For common PythonOp, since we don't know the relation ship between
inputs and outputs, so we only infer the rank from output ranks, and
generate symbolic dimensions for each dim. While this will introduce
many meaningless symbolic dimensions, sometimes blocking our graph
transformers to do op fusion.

This PR provide a way to define custom shape inferencing for
`torch.autograd.Function` we defined, to propagate the original
dimensions across the PythonOp at the best efforts.

But the 2rd one is not covered yet, we could refine that later. Fixing
1st one is enough for ORTModule training/evaluation.

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
pengwa 2023-08-14 09:13:32 +08:00 committed by GitHub
parent 9204cd7392
commit cd7b3f54da
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 195 additions and 24 deletions

View file

@ -2385,31 +2385,35 @@ class SymbolicShapeInference:
output_tensor_ranks = get_attribute(node, "output_tensor_ranks")
assert output_tensor_ranks
# set the context output separately.
# The first output is autograd's context.
from onnxruntime.training.ortmodule._custom_autograd_function_exporter import PythonOpShapeInferStore
func_name = get_attribute(node, "func_name").decode()
shape_inferer = PythonOpShapeInferStore.get_shape_infer(func_name)
# Set the context output separately.
# The first output is torch.autograd.Function''s context.
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(helper.make_tensor_value_info(node.output[0], onnx.TensorProto.INT64, []))
# TODO(pengwa): allow custom PythonOp shape inference.
if get_attribute(node, "func_name").decode() in [
"onnxruntime.training.utils.hooks._subscriber_manager._InspectActivation",
"onnxruntime.training.utils.hooks._subscriber_manager._IncrementStep",
]:
# PythonOp with func_name being "_InspectActivation" or "_IncrementStep" will behave exactly same as a normal
# PythonOp when execution. The only difference is that
# 1). those ops having same number of tensor inputs and tensor outputs;
# 2). and the i-th output tensor's shape is same as i-th input tensor's shape.
# Be noted, the count of custom autograd function might be bigger than output count, because there might
# be other non-tensor constant inputs (string, object, int, tuple, etc). But we did not make those constant
# inputs as ONNX op's input, instead they are stored in the attributes.
assert len(node.output) == len(node.input) + 1 # The output contains one extra context info.
for input_index in range(len(node.output) - 1):
# Process the i-th tensor outputs.
vi = self.known_vi_[node.output[input_index + 1]]
if shape_inferer is not None:
input_shapes = []
input_dtypes = []
for input_index in range(len(node.input)):
shape = self._get_shape(node, input_index)
output_dtype = self.known_vi_[node.input[input_index]].type.tensor_type.elem_type
vi.CopyFrom(helper.make_tensor_value_info(node.output[input_index + 1], output_dtype, shape))
input_shapes.append(shape)
input_dtype = self.known_vi_[node.input[input_index]].type.tensor_type.elem_type
input_dtypes.append(input_dtype)
output_shapes, output_dtypes = shape_inferer(node, input_shapes, input_dtypes)
assert len(output_shapes) == len(output_dtypes) == (len(node.output) - 1)
for i in range(len(node.output) - 1):
output_index = i + 1
vi = self.known_vi_[node.output[output_index]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[output_index], output_dtypes[i], output_shapes[i])
)
else:
# Outputs after autograd's context are tensors.
# General shape inference for PythonOp.
# Outputs after torch.autograd.Function's context are tensors.
# We assume their ranks are fixed for different model inputs.
for i in range(len(node.output) - 1):
# Process the i-th tensor outputs.

View file

@ -4,6 +4,7 @@
# --------------------------------------------------------------------------
import sys
from typing import Callable, ClassVar, Dict, Optional
import onnx
import torch
@ -18,8 +19,44 @@ from ._custom_op_symbolic_registry import pytorch_type_to_onnx, wrap_custom_expo
from ._fallback import ORTModuleONNXModelException, wrap_exception
from ._utils import get_fully_qualified_class_name, get_runtime_pytorch_version
class PythonOpShapeInferStore:
"""A class to store shape inference functions for torch.autograd.Function."""
_CLASS_MAP: ClassVar[Dict[str, Callable]] = {}
@classmethod
def register(cls, kclass: torch.autograd.Function) -> None:
"""Register a shape inference function for a torch.autograd.Function if there is staticmethod "infer_shape" defined.
The signature of the shape inference function should be:
@staticmethod
def infer_shape(
node: onnx.NodeProto,
tensor_input_shapes: List[Optional[List[Union[int, str]]]],
tensor_input_dtypes: List[torch.onnx.TensorProtoDataType],
) -> Tuple[List[Optional[List[Union[int, str]]]], List[torch.onnx.TensorProtoDataType]]:
tensor_output_shapes = []
tensor_output_dtypes = []
...
return tensor_output_shapes, tensor_output_dtypes
The tensor_input_shapes and tensor_input_dtypes are lists of shapes and dtypes of the input tensors.
The tensor_output_shapes and tensor_output_dtypes are lists of shapes and dtypes of the output tensors.
Be noted: we only pass in tensor inputs, and return tensor outputs, non-tensor inputs/outputs are ignored.
"""
kclass_name = get_fully_qualified_class_name(kclass)
if hasattr(kclass, "infer_shape") and kclass_name not in cls._CLASS_MAP:
cls._CLASS_MAP[kclass_name] = kclass.infer_shape
@classmethod
def get_shape_infer(cls, name: str) -> Optional[Callable]:
return cls._CLASS_MAP.get(name, None)
"""
Defines a list of names of torch.torch.autograd.Function, for checkpoint activation purposes.
Defines a list of names of torch.autograd.Function, for checkpoint activation purposes.
Note:
If CheckpointFunction is exported as PythonOp, the checkpoint-ed computation
@ -220,6 +257,9 @@ def _export_pt_1_10(g, n, *args, **kwargs):
# Register function with class names.
register_torch_autograd_function(func_full_qual_name, func_class)
# Register shape inference function.
PythonOpShapeInferStore.register(func_class)
return returned_args
except Exception as e:
sys.stdout.flush()
@ -235,7 +275,7 @@ def _post_process_after_export(
) -> onnx.ModelProto:
"""Post process the exported model."""
if enable_custom_autograd_function:
return _post_process_enabling_autograd_function(exported_model)
exported_model = _post_process_enabling_autograd_function(exported_model)
return exported_model

View file

@ -158,6 +158,8 @@ class DebugOptions:
return [
# [W shape_type_inference.cpp:1974] Warning: The shape inference of com.microsoft::PythonOp type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. (function UpdateReliable)
"type is missing, so it may result in wrong shape inference",
# diagnostics [WARNING] - None
"[WARNING] - None",
]
return None

View file

@ -4,8 +4,9 @@
# --------------------------------------------------------------------------
from collections import abc
from typing import Callable, List, Optional, Union
from typing import Callable, List, Optional, Tuple, Union
import onnx
import torch
from onnxruntime.training.ortmodule import ORTModule
@ -111,6 +112,14 @@ class _InspectActivation(torch.autograd.Function):
return None, None, None, grad_output.detach() if grad_output is not None else None
@staticmethod
def infer_shape(
node: onnx.NodeProto,
tensor_input_shapes: List[Optional[List[Union[int, str]]]],
tensor_input_dtypes: List[torch.onnx.TensorProtoDataType],
) -> Tuple[List[Optional[List[Union[int, str]]]], List[torch.onnx.TensorProtoDataType]]:
return tensor_input_shapes, tensor_input_dtypes
class _IncrementStep(torch.autograd.Function):
"""
@ -163,6 +172,14 @@ class _IncrementStep(torch.autograd.Function):
return None, grad_output.detach() if isinstance(grad_output, torch.Tensor) else grad_output
@staticmethod
def infer_shape(
node: onnx.NodeProto,
tensor_input_shapes: List[Optional[List[Union[int, str]]]],
tensor_input_dtypes: List[torch.onnx.TensorProtoDataType],
) -> Tuple[List[Optional[List[Union[int, str]]]], List[torch.onnx.TensorProtoDataType]]:
return tensor_input_shapes, tensor_input_dtypes
class SubscriberManager:
"""

View file

@ -5,6 +5,7 @@
# pylint: disable=C0103
# pylint: disable=W0212
import onnx
import pytest
import torch
@ -1380,3 +1381,110 @@ def test_duplicate_named_functions():
assert triggered[0]
assert triggered[1]
def test_customized_shape_inference():
def _check_pythonop_shape(model):
graph = model._torch_module._execution_manager._training_manager._onnx_models.optimized_model.graph
found_pythonop = False
python_op_input = []
python_op_output = []
for node in graph.node:
if node.op_type == "PythonOp":
found_pythonop = True
python_op_input = node.input
python_op_output = node.output
break
assert found_pythonop, "PythonOp should be found in the optimized_model"
input_shapes = [None]
input_dtypes = [None]
output_shapes = [None, None]
output_dtypes = [None, None]
def _find_shape_and_dtype(value_infos):
for value_info in value_infos:
if value_info.name == python_op_input[0]:
input_shapes[0] = value_info.type.tensor_type.shape
input_dtypes[0] = value_info.type.tensor_type.elem_type
if value_info.name == python_op_output[0]:
output_shapes[0] = value_info.type.tensor_type.shape
output_dtypes[0] = value_info.type.tensor_type.elem_type
if value_info.name == python_op_output[1]:
output_shapes[1] = value_info.type.tensor_type.shape
output_dtypes[1] = value_info.type.tensor_type.elem_type
_find_shape_and_dtype(graph.input)
_find_shape_and_dtype(graph.value_info)
assert all(s is not None for s in input_shapes), "PythonOp input shape should be found in the optimized_model"
assert (
all(d is not None for d in input_dtypes) is not None
), "PythonOp input dtype should be found in the optimized_model"
assert all(s is not None for s in output_shapes), "PythonOp output shape should be found in the optimized_model"
assert (
all(d is not None for d in output_dtypes) is not None
), "PythonOp output dtype should be found in the optimized_model"
def _compare_shape(shape1, shape2):
if len(shape1.dim) != len(shape2.dim):
return False
for dim1, dim2 in zip(shape1.dim, shape2.dim):
if dim1.HasField("dim_value") and dim1.HasField("dim_value") and dim1.dim_value == dim2.dim_value:
continue
if dim1.HasField("dim_param") and dim1.HasField("dim_param") and dim1.dim_param == dim2.dim_param:
continue
return False
return True
assert output_dtypes[0] == onnx.TensorProto.INT64
assert len(output_shapes[0].dim) == 0
assert _compare_shape(input_shapes[0], output_shapes[1])
assert input_dtypes[0] == output_dtypes[1]
class CustomShapeInferFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
@staticmethod
def backward(ctx, grad_output):
x = ctx.saved_tensors
tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
return ff * grad_output
@staticmethod
def infer_shape(node: onnx.NodeProto, tensor_input_shapes, tensor_input_dtypes):
return [tensor_input_shapes[0]], [tensor_input_dtypes[0]]
class TestModel(torch.nn.Module):
def __init__(self, output_size):
super().__init__()
self.custom_fn = CustomShapeInferFunction.apply
self.bias = Parameter(torch.empty(output_size, dtype=torch.float))
with torch.no_grad():
self.bias.uniform_()
def forward(self, model_input):
# model_input did not require_grad
out = self.custom_fn(model_input)
return out + self.bias
output_size = 1024
ortmodule = ORTModule(
TestModel(output_size),
).train()
_ = ortmodule(torch.randn(output_size, dtype=torch.float))
_check_pythonop_shape(ortmodule)