[ONNX] Refactor perfect/nearest match criteria to allow optional inputs and disallow mismatch attributes (#106478)

Fix #106057, except **Attribute dtype mismatch. E.g., alpha of aten.add.Tensor. -> Attribute: alpha INT vs FLOAT**.

Summarized the change
* Fill in defaults of attribute when `param_schema` is applied. This relaxes the matching on default attributes.
* Fill in None to optional input when `param_schema` is applied.
* Keep extra kwargs in attributes to make matching strictly.
* Allow input to be None when its dtype is `optional[INPUT]`

The change comes with the guarantee from torchlib that an attribute would never be None. For example, if `memory_format` is needed, the function should specify it like this:
```python
@torch_op("aten::clone")
def aten_clone(
    self: TTensor, memory_format: str = ""  # pylint: disable=unused-argument
) -> TTensor:
    """clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor"""

    return op.Identity(self)
```

Prior to this PR, opSchema matching didn't strictly guard the number of inputs/attributes to allow nearest match, which introduced the bug of dispatching `aten::div.Tensor` to `aten::div.default`, disregarding the fact that `aten::div.Tensor` has an extra attribute `rounding_mode`. This PR fixes the issue with the new logic for perfect/nearest match. Particularly, it strictly restricts the qualification for being a nearest match candidate.

For each ONNX variants, we check these step by step:
1. Check if the number of inputs in the function signature is the same as the number of provided inputs.
2. Check if the attribute names in the function signature are the same set as the provided attribute names.

If either of the above two criteria fails to meet, the ONNX variant is not a perfect match, nor a nearest match candidate (match_score=None)

3. Check if input dtype matches
4. Check if attribute dtype matches

If 3 and 4 are met, then this is a perfect match, otherwise, it's still considered a candidate of nearest match with a matching score.

## Case Study

### Optional Input
The dispatcher recognizes optional inputs. However, the input can't be ignored. None must be provided.
```python
# Perfect match is found
inputs = (Tensor([2, 3]), None)
aten_op(X: TTensor, Y: Optional[INT64]):
    ...
```
Real Case: aten::convolution
NOTE: There are no (and will be no) optional attributes in torchlib.

### Different attributes
If an attribute is provided with value, it's a must to match the attribute in function signature.
```python
# Not perfect match, nor nearest match
inputs = (Tensor([2, 3]),)
attributes = {"a":1, "b":2}
aten_op(X: TTensor, a: int):
    ...
```
Real Case: aten::div and aten::div.Tensor_mode

### Default attribute
Default attribute will fill in the value into inputs/attributes
```python
# Perfect match is found
inputs = (Tensor([2, 3]),)
attributes = {}
aten_op(X: TTensor, a: int = 3):
    ...
```
Real case: aten::clone

### Ignore attribute with None value
The attributes with None value will be ignored in matching.
```python
# Perfect match
inputs = (Tensor([2, 3]),)
attributes = {"a": None}
aten_op(X: TTensor):
    ...

# Not perfect match, but eligible for nearest match
inputs = (Tensor([2, 3]),)
attributes = {"a": None}
aten_op(X: TTensor, a: int = 3):
    ...
```
Real case: aten::div and aten::div.Tensor_mode

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106478
Approved by: https://github.com/thiagocrepaldi, https://github.com/BowenBao
This commit is contained in:
AllenTiTaiWang 2023-08-10 00:37:54 +00:00 committed by PyTorch MergeBot
parent 4c1d8ab272
commit e93a90bdd5
10 changed files with 325 additions and 128 deletions

View file

@ -31,7 +31,7 @@ pip_install \
pip_install onnx-weekly==1.15.0.dev20230717
# TODO: change this when onnx-script is on testPypi
pip_install onnxscript-preview==0.1.0.dev20230801 --no-deps
pip_install onnxscript-preview==0.1.0.dev20230809 --no-deps
# Cache the transformers model to be used later by ONNX tests. We need to run the transformers
# package to download the model. By default, the model is cached at ~/.cache/huggingface/hub/

View file

@ -220,9 +220,6 @@ class TestDispatcher(common_utils.TestCase):
)
# Non-registered op
internal_opname_class_unsupported = registration.OpName.from_name_parts(
namespace="aten", op_name="made_up_node", overload=None
)
unsupported_op_node = torch.fx.Node(
graph=torch.fx.Graph(),
name="aten::made_up_node",
@ -246,7 +243,7 @@ class TestDispatcher(common_utils.TestCase):
name="aten::add.Tensor",
op="call_function",
target=torch.ops.aten.add.Tensor, # type: ignore[attr-defined]
args=(torch.tensor(3), torch.tensor(4)),
args=(torch.tensor(3.0), torch.tensor(4.0)),
kwargs={},
),
name="nearest_match",
@ -257,7 +254,7 @@ class TestDispatcher(common_utils.TestCase):
name="aten::add.Tensor",
op="call_function",
target=torch.ops.aten.add.Tensor, # type: ignore[attr-defined]
args=(torch.tensor(3), torch.tensor(4)),
args=(torch.tensor(3.0), torch.tensor(4.0)),
kwargs={"alpha": 1},
),
name="perfect_match_with_kwargs",
@ -270,11 +267,15 @@ class TestDispatcher(common_utils.TestCase):
custom_domain = onnxscript.values.Opset(domain="custom", version=1)
@onnxscript.script(custom_domain)
def test_custom_op(x: TCustomFloat, y: TCustomFloat) -> TCustomFloat:
def test_custom_op(
x: TCustomFloat, y: TCustomFloat, alpha: int = 1
) -> TCustomFloat:
return op.Add(x, y)
@onnxscript.script(custom_domain)
def test_default_op(x: TCustomFloat, y: TCustomFloat) -> TCustomFloat:
def test_default_op(
x: TCustomFloat, y: TCustomFloat, alpha: int = 1
) -> TCustomFloat:
return op.Add(x, y)
op_full_name = "test::test_op"
@ -306,7 +307,65 @@ class TestDispatcher(common_utils.TestCase):
name="aten::add.Tensor",
op="call_function",
target=torch.ops.aten.add.Tensor, # type: ignore[attr-defined]
args=(torch.tensor(3), torch.tensor(4)),
args=(torch.tensor(3.0), torch.tensor(4.0)),
kwargs={"attr": None},
),
name="perfect_match_with_ignoring_none_attribute",
),
common_utils.subtest(
torch.fx.Node(
graph=torch.fx.Graph(),
name="aten::add.Tensor",
op="call_function",
target=torch.ops.aten.add.Tensor, # type: ignore[attr-defined]
args=(torch.tensor(3.0), torch.tensor(4.0)),
kwargs={"unrelated": None},
),
name="perfect_match_with_ignoring_unrelated_none_attribute",
),
],
)
def test_find_the_perfect_or_nearest_match_onnxfunction_ignores_attribute_with_none(
self, node
):
custom_domain = onnxscript.values.Opset(domain="custom", version=1)
@onnxscript.script(custom_domain)
def test_op_attribute(
x: TCustomFloat, y: TCustomFloat, attr: int
) -> TCustomFloat:
return op.Add(x, y)
@onnxscript.script(custom_domain)
def test_op(x: TCustomFloat, y: TCustomFloat) -> TCustomFloat:
return op.Add(x, y)
op_full_name = "test::test_op"
function_overloads = [
registration.ONNXFunction(test_op_attribute, op_full_name=op_full_name),
registration.ONNXFunction(test_op, op_full_name=op_full_name),
]
symbolic_fn = self.dispatcher._find_the_perfect_or_nearest_match_onnxfunction(
node,
function_overloads,
node.args,
node.kwargs,
self.diagnostic_context,
)
self.assertEqual(symbolic_fn, test_op)
@common_utils.parametrize(
"node",
[
common_utils.subtest(
torch.fx.Node(
graph=torch.fx.Graph(),
name="aten::add.Tensor",
op="call_function",
target=torch.ops.aten.add.Tensor, # type: ignore[attr-defined]
args=(torch.tensor(3.0), torch.tensor(4.0)),
kwargs={},
),
name="nearest_match",
@ -317,7 +376,7 @@ class TestDispatcher(common_utils.TestCase):
name="aten::add.Tensor",
op="call_function",
target=torch.ops.aten.add.Tensor, # type: ignore[attr-defined]
args=(torch.tensor(3), torch.tensor(4)),
args=(torch.tensor(3.0), torch.tensor(4.0)),
kwargs={"alpha": 1},
),
name="perfect_match_with_kwargs",
@ -330,15 +389,21 @@ class TestDispatcher(common_utils.TestCase):
custom_domain = onnxscript.values.Opset(domain="custom", version=1)
@onnxscript.script(custom_domain)
def test_second_custom_op(x: TCustomFloat, y: TCustomFloat) -> TCustomFloat:
def test_second_custom_op(
x: TCustomFloat, y: TCustomFloat, alpha: int = 1
) -> TCustomFloat:
return op.Add(x, y)
@onnxscript.script(custom_domain)
def test_third_custom_op(x: TCustomFloat, y: TCustomFloat) -> TCustomFloat:
def test_third_custom_op(
x: TCustomFloat, y: TCustomFloat, alpha: int = 1
) -> TCustomFloat:
return op.Add(x, y)
@onnxscript.script(custom_domain)
def test_first_custom_op(x: TCustomFloat, y: TCustomFloat) -> TCustomFloat:
def test_first_custom_op(
x: TCustomFloat, y: TCustomFloat, alpha: int = 1
) -> TCustomFloat:
return op.Add(x, y)
op_full_name = "aten::add"
@ -458,24 +523,6 @@ class TestOpSchemaWrapper(common_utils.TestCase):
([torch.randn(3, 4), torch.tensor(3)], {}, ops.core.aten_new_full, 2),
name="match_2_inputs",
),
common_utils.subtest(
(
[torch.randn(3, 4), torch.tensor(3)],
{"dtype": 2}, # at this point, dtype should be converted to int
ops.core.aten_new_full,
1,
),
name="match_2_inputs_and_mismatch_1_kwarg",
),
common_utils.subtest(
(
[torch.randn(3, 4), torch.tensor(3)],
{},
ops.core.aten_new_full_dtype,
1,
),
name="match_2_input_and_mismatch_1_kwargs_optional",
),
common_utils.subtest(
(
[torch.randn(3, 4), torch.tensor(3)],

View file

@ -354,6 +354,10 @@ EXPECTED_SKIPS_OR_FAILS: Tuple[onnx_test_common.DecorateMeta, ...] = (
dtypes=(torch.uint8, torch.int8, torch.int16,),
reason=onnx_test_common.reason_onnx_script_does_not_support("Add", "int8, int16"),
),
xfail(
"nn.functional.adaptive_avg_pool1d",
reason=onnx_test_common.reason_onnx_script_does_not_support("aten::div.Tensor_mode needs type promotion"),
),
xfail(
"nn.functional.adaptive_avg_pool2d",
reason=onnx_test_common.reason_onnx_script_does_not_support("RecursionError: \
@ -488,9 +492,10 @@ SKIP_XFAIL_SUBTESTS: tuple[onnx_test_common.DecorateMeta, ...] = (
matcher=lambda sample: sample.input[0].equal(torch.tensor([])),
reason="core dump - cat does not support zero-dim tensors yet",
),
skip(
xfail(
"div",
matcher=lambda sample: sample.kwargs.get("rounding_mode") is not None,
matcher=lambda sample: sample.kwargs.get("rounding_mode") is not None
and sample.input.dtype in onnx_test_common.INT_TYPES,
reason="rounding_mode is not yet supported",
),
xfail(

View file

@ -45,6 +45,7 @@ def assert_has_diagnostics(
)
@common_utils.instantiate_parametrized_tests
class TestFxToOnnx(pytorch_test_common.ExportTestCase):
def setUp(self):
super().setUp()
@ -82,7 +83,20 @@ class TestFxToOnnx(pytorch_test_common.ExportTestCase):
self.assertNotIsInstance(tensor_x, fake_tensor.FakeTensor)
self.assertNotIsInstance(tensor_y, fake_tensor.FakeTensor)
def test_mnist_exported_with_no_warnings_on_get_attr_node_in_op_level_debug(self):
@common_utils.parametrize(
"diagnostic_rule",
[
common_utils.subtest(
diagnostics.rules.find_opschema_matched_symbolic_function,
name="optional_inputs",
),
common_utils.subtest(
diagnostics.rules.op_level_debugging,
name="get_attr_node_in_op_level_debug",
),
],
)
def test_mnist_exported_with_no_warnings(self, diagnostic_rule):
class MNISTModel(nn.Module):
def __init__(self):
super().__init__()
@ -109,12 +123,9 @@ class TestFxToOnnx(pytorch_test_common.ExportTestCase):
MNISTModel(), tensor_x, export_options=ExportOptions(op_level_debug=True)
)
# NOTE: This additional test makes sure that op level debug supports `get_attr`
# fx.Node, also known as weight in PyTorch. aten.convolution.default is one of
# the nodes that has weight attribute.
assert_has_diagnostics(
export_output.diagnostic_context,
diagnostics.rules.op_level_debugging,
diagnostic_rule,
diagnostics.levels.NONE,
expected_node="aten.convolution.default",
)
@ -144,7 +155,9 @@ class TestFxToOnnx(pytorch_test_common.ExportTestCase):
return torch.sum(values)
x = torch.arange(1.0, 6.0, requires_grad=True)
_ = dynamo_export(TopKModel(), x, export_options=self.export_options)
export_output = dynamo_export(
TopKModel(), x, export_options=self.export_options
)
def test_unsupported_indices_fake_tensor_generated_with_op_level_debug(self):
class EmbedModelWithoutPaddingIdx(torch.nn.Module):
@ -238,6 +251,20 @@ class TestFxToOnnx(pytorch_test_common.ExportTestCase):
expected_node="aten.add.Tensor",
)
def test_aten_clone_does_not_raise_warning_of_lack_of_memory_format(self):
class CustomModule(torch.nn.Module):
def forward(self, input):
return torch.ops.aten.clone(input, memory_format=torch.preserve_format)
x = torch.tensor(3)
export_output = dynamo_export(CustomModule(), x)
assert_has_diagnostics(
export_output.diagnostic_context,
diagnostics.rules.find_opschema_matched_symbolic_function,
diagnostics.levels.NONE,
expected_node="aten.clone.default",
)
def test_dynamo_export_retains_readable_parameter_and_buffer_names(self):
class SubModule(torch.nn.Module):
def __init__(self):

View file

@ -657,7 +657,6 @@ class TestFxToOnnxWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
*fake_args,
export_options=export_options,
)
onnx_model = export_output.model_proto
onnx_test_common.assert_dynamic_shapes(export_output, self.dynamic_shapes)

View file

@ -124,6 +124,11 @@ def _bool(obj: bool) -> str:
return str(obj)
@_format_argument.register
def _str(obj: str) -> str:
return obj
@_format_argument.register
def _registration_onnx_function(obj: registration.ONNXFunction) -> str:
# TODO: Compact display of `param_schema`.

View file

@ -597,7 +597,6 @@ class FxOnnxInterpreter:
# Dispatch to ONNX op through OpShema. The input argument dtypes are compared to
# function signature in OpSchema, and find the best matched overload.
# TODO(titaiwang): diagnostic rules.
symbolic_fn = onnxfunction_dispatcher.dispatch(
node=node,
onnx_args=onnx_args,

View file

@ -31,6 +31,7 @@ if TYPE_CHECKING:
from torch.onnx import OnnxRegistry
# For beartype
from onnxscript.function_libs.torch_lib import ( # type: ignore[import]
graph_building as onnxscript_graph_building,
@ -67,7 +68,7 @@ def _find_operator_overloads_in_onnx_registry_disagnostic_message_formatter(
class OnnxFunctionDispatcher:
"""A dispatcher that finds the best ONNX Function for ATen operators.
"""A dispatcher that finds the best ONNX Function for ATen/Custom operators.
It uses the `torch.ops` name to find the function. If not found, it falls back to default.
Otherwise, the best match is found among all function overloads. An exact match has
@ -85,6 +86,9 @@ class OnnxFunctionDispatcher:
the potential wrongly annotated dtypes and attributes matching, we use
nearest match to find the best function once the aten name is targeted.
3. Tie-breaker: If there are multiple nearest matches, we will select the one with
the highest matching score.
NOTE: The nearest match `doesn't guarantee` a correct match, and a warning message is logged.
"""
@ -219,8 +223,7 @@ class OnnxFunctionDispatcher:
Raises:
RuntimeError: If there are no overloaded functions available for the given FX node.
"""
# TODO(justinchuby): Cache the OnnxSchemaChecker so we don't need to run the init logic everytime
overload_match_ranking: Dict[registration.ONNXFunction, int] = {}
overload_match_ranking: Dict[registration.ONNXFunction, Optional[int]] = {}
diagnostic = diagnostic_context.inflight_diagnostic()
# Iterate the overloaded functions in reverse order to prioritize the custom ones
@ -228,23 +231,42 @@ class OnnxFunctionDispatcher:
for symbolic_function in reversed(default_and_custom_functions):
function_opschema = _OnnxSchemaChecker(symbolic_function.onnx_function)
# NOTE: 1. If the perfect match is found, return the function
if function_opschema.perfect_match_inputs(
diagnostic, onnx_args, onnx_kwargs
):
# If the perfect match is found, return the function
return symbolic_function.onnx_function
# Record the match score for the nearest match if it's not the perfect match
overload_match_ranking[symbolic_function] = function_opschema.match_score
# NOTE: If the perfect match is not found, find the nearest match
# NOTE: 2. If there is no perfect match, find the nearest match among the nearest match candidates
# If there is no nearest match, raise an error
overload_match_ranking = {
k: v for k, v in overload_match_ranking.items() if v is not None
}
if not overload_match_ranking:
# If there are no overloaded functions available for the given FX node, raise an
# unsupported error
op_full_name = self._get_aten_name(
node, diagnostic_context
).qualified_name()
diagnostic = diagnostics.UnsupportedFxNodeDiagnostic(
diagnostics.rules.no_symbolic_function_for_call_function,
diagnostics.levels.ERROR,
f"Cannot find any perfect/nearest match of symbolic function for {op_full_name},"
f"which should be registered under {node.target}.",
unsupported_fx_node=node,
)
diagnostic_context.log(diagnostic)
raise diagnostics.RuntimeErrorWithDiagnostic(diagnostic)
diagnostic.with_additional_message(
"### Exact match is not found!\n"
"Cannot find a perfect match of symbolic overload, "
"a nearest match is found. Please check the ONNX output carefully. \n",
)
diagnostic.level = diagnostics.levels.WARNING
# NOTE: Tie breaker: if there are multiple nearest matches, we will choose the one
# NOTE: 3. Tie breaker: if there are multiple nearest matches, we will choose the one
# that is custom first. If there are multiple custom ones, we will choose the one
# that is added lastly in the list.
symbolic_function_list: List[registration.ONNXFunction] = sorted(
@ -355,7 +377,7 @@ class OnnxFunctionDispatcher:
node=node, diagnostic_context=diagnostic_context
)
# NOTE: If the ATen/Custom operators are not registered, the group will be None.
# If the ATen/Custom operators are not registered, the group will be None.
# And non-registered ATen/Custom operators will trigger an error in the next step.
function_group: Optional[List[registration.ONNXFunction]] = None
@ -374,10 +396,6 @@ class OnnxFunctionDispatcher:
overload=None,
)
if function_group is not None:
# NOTE: Currently, most of torchlib functions are not registered with overload
# in ONNX registry. So we will only log a warning in SARIF if we can't find the overload
# to avoid spammy warnings in printout.
# TODO: https://github.com/microsoft/onnxscript/issues/828
op_full_name = internal_opname.qualified_name()
diagnostic = diagnostic_context.inflight_diagnostic()
diagnostic.with_additional_message(
@ -387,15 +405,13 @@ class OnnxFunctionDispatcher:
)
diagnostic.level = diagnostics.levels.WARNING
# NOTE: If the ATen/Custom operators are not registered, the group will be None.
if function_group is not None:
# If the input has complex dtype, we will only dispatch to the complex functions.
# NOTE: If the input has complex dtype, we will only dispatch to the complex functions.
function_group = self._filter_or_keep_complex(
node, function_group, diagnostic_context
)
return function_group # type: ignore[return-value]
# If we can't find the function group, raise error.
op_full_name = internal_opname.qualified_name()
diagnostic = diagnostics.UnsupportedFxNodeDiagnostic(
diagnostics.rules.no_symbolic_function_for_call_function,
@ -431,38 +447,84 @@ class _OnnxSchemaChecker:
It provides methods to check for input compatibility based on the OpSchema. It also
provides a matching score to indicate how well the OpSchema matches the input and
kwargs types.
kwargs types. A function will be evaluated as perfect match, nearest match eligible,
or no match.
There are three types of ONNX overloads in torchlib:
Here are some common examples in categories:
1. Different types: Caused by the difference between the ONNX spec and PyTorch.The
matching system finds the correct one.
1. [NOTE: Perfect match]: The number of inputs and attributes are exactly the same as
the OpSchema. The types of inputs and attributes are exactly the same as the
OpSchema.
```python
@torch_op("aten::mul")
def aten_mul(self: TReal, other: TReal) -> TReal:
inputs = (Tensor[2, 3], Tensor[2, 3])
attributes = {"alpha": 1.0}
@torch_op("aten::op")
def aten_op(self: TReal, other: TReal, alpha: float = 1) -> TReal:
...
@torch_op("aten::mul")
def aten_mul_bool(self: BOOL, other: BOOL) -> BOOL:
...
```
```
Result: Perfect match.
2. Optional dtype: dtype could be "unprovided". The difference from 2 is that dtype
would not be None.
2. [NOTE: Optional input]: The dispatcher recognizes optional inputs. However,
the input can't be ignored. None must be provided.
```python
@torch_op("aten::new_full")
def aten_new_full(self: TTensor, size: INT64, fill_value: TTensor) -> TTensor:
...
inputs = (Tensor([2, 3]), None)
attributes = {}
@torch_op("aten::new_full")
def aten_new_full_dtype(self: TTensor, size: INT64, fill_value: TTensor, dtype: int) -> TTensor:
aten_op(X: TTensor, Y: Optional[INT64]):
...
```
Result: Perfect match.
Real example: `aten::convolution`.
Depends on dtype is provided or not, matching system will dispatch the ATen op to
the correct one.
3. [NOTE: Different attributes]: If an attribute is provided with value, it's
a must to match the attribute in function signature.
```python
inputs = (Tensor([2, 3]),)
attributes = {"a":1, "b":2}
aten_op(X: TTensor, a: int):
...
```
Result: No match.
Real example: `aten::div` vs `aten::div.Tensor_mode`.
4. [NOTE: Default attributes]: Default attribute will fill in the value into
inputs/attributes.
```python
inputs = (Tensor([2, 3]),)
attributes = {}
aten_op(X: TTensor, a: int = 3):
...
```
Result: Perfect match.
Real example: `aten::clone`
5. [NOTE: Ignore attribute with None value]: The attributes with None value
will be ignored in matching.
```python
inputs = (Tensor([2, 3]),)
attributes = {"a": None}
aten_op(X: TTensor):
...
```
Result: Perfect match.
```python
inputs = (Tensor([2, 3]),)
attributes = {"a": None}
aten_op(X: TTensor, a: int = 3):
...
```
Result: Nearest match eligible.
Real example: `aten::div` vs `aten::div.Tensor_mode`.
Attributes:
onnxfunction: The OnnxFunction.
@ -496,12 +558,15 @@ class _OnnxSchemaChecker:
for constraint in self.op_schema.type_constraints
}
self.attributes = self.op_schema.attributes
self._matching_score: int = 0
self._matching_score: Optional[int] = None
@property
def match_score(self) -> int:
def match_score(self) -> Optional[int]:
"""The matching score of the OnnxSchemaChecker .
If this remains None, it means the matching score has not been calculated,
and it's not a nearest match candidate.
Returns:
The matching score of the OnnxSchemaChecker .
"""
@ -522,6 +587,13 @@ class _OnnxSchemaChecker:
constraints and the number of inputs matches the number of inputs in the
OpSchema.
Checking steps:
1. The function signature matches the inputs number, and attribute names.
2. The input/attribute types are all in the type constraints.
A function should at least pass the first step to be eligible for the
nearest matching.
Args:
diagnostic: The diagnostic to use for logging detailed info.
args: The input arguments organized in PyTorch inputs way.
@ -540,30 +612,54 @@ class _OnnxSchemaChecker:
self.param_schema,
args,
kwargs,
fill_defaults=False, # NOTE: We don't want to change inputs
fill_defaults=True, # fill defaults for optional arguments to match
)
# TODO(titaiwang): Currently the functions in torchlib are manully annotated,
# so there are quite a few functions that wrongly annotated or strctly annotated.
# The matching system relax the match while we fix them in the future.
self._record_matching_score(function_inputs, function_attributes)
diagnostic.with_additional_message("### Checking perfect match...\n")
diagnostic.with_additional_message(
f"{diagnostics.format_argument(self.onnxfunction)}"
)
diagnostic.with_additional_message(f"match score: {self.match_score}\n")
# NOTE: 1. Check if the input number and attribute names match the
# OpSchema. If it's not, we know the function is not eligible to be a perfect
# match, nor a nearest match.
# We use is_perfect_match to postpone the return value to the end
# of the function, as we want to log all the mismatch info.
is_perfect_match = True
if len(function_inputs) != len(self.op_schema.inputs):
diagnostic.with_additional_message(
f"#### Failed: input number mismatch! \n"
f"Actual {len(function_inputs)} vs expected {len(self.op_schema.inputs)}\n"
)
diagnostic.with_additional_message(
"The function is not a nearest match candidate.\n"
)
is_perfect_match = False
if set(function_attributes) != set(self.attributes):
diagnostic.with_additional_message(
f"#### Failed: attribute mismatch! \n"
f"Actual {set(function_attributes)} vs\n"
f"expected {set(self.attributes)}\n"
)
diagnostic.with_additional_message(
"The function is not a nearest match candidate.\n"
)
is_perfect_match = False
# If it's already not a perfect match, we can return False directly. Further
# checking is only for the functions that are eligible for nearest match.
if not is_perfect_match:
return False
# NOTE: 2. The dtypes of inputs and attributes should be in the
# type constraints of the OpSchema. If they are not, we know the function is not
# eligible to be a perfect match, but can be a nearest match candidate.
for schema_input, torch_input in zip(self.op_schema.inputs, function_inputs):
torch_input_compatible_types = _find_onnx_data_type(torch_input)
allowed_types = self.type_constraints[schema_input.type_str]
if not allowed_types.intersection(torch_input_compatible_types):
if not allowed_types.intersection(torch_input_compatible_types) and not any(
fx_type_utils.is_optional_onnx_dtype_str(onnx_type_str)
for onnx_type_str in allowed_types
):
# If torch_input_compatible_types isn't in allowed_types
# of this input defined in the OpSchema, we know the function
# and the input are not compatible
@ -572,18 +668,8 @@ class _OnnxSchemaChecker:
f"Actual {torch_input_compatible_types} vs\n"
f"expected {allowed_types}\n"
)
return False
# Check attributes keys are the same
if set(function_attributes) != set(self.attributes):
# If the attributes of the OpSchema and the attributes don't match,
# we know the function and the input are not compatible
diagnostic.with_additional_message(
f"#### Failed: attribute mismatch! \n"
f"Actual {set(function_attributes)} vs\n"
f"expected {set(self.attributes)}\n"
)
return False
# Check attribute dtypes
is_perfect_match = False
for attribute_name, attribute in function_attributes.items():
if not self._match_onnx_attribute_type(attribute_name, attribute):
# If the attribute type of the OpSchema and the attribute type don't match,
@ -593,8 +679,12 @@ class _OnnxSchemaChecker:
f"Actual {type(attribute)} vs\n"
f"expected {self.attributes[attribute_name].type}\n"
)
return False
return True
is_perfect_match = False
# NOTE: This is still a candidate for nearest match, as it only mismatches attributes on dtype.
self._record_matching_score(function_inputs, function_attributes)
diagnostic.with_additional_message(f"match score: {self.match_score}\n")
return is_perfect_match
@_beartype.beartype
def _match_onnx_attribute_type(
@ -632,13 +722,16 @@ class _OnnxSchemaChecker:
):
"""Calculate the inputs matching score of the OpSchema requirements to find the nearest match.
Only the functions which have the same number of inputs and attributes as the
OpSchema are eligible to be a nearest match candidate. Thus, we don't need to
check the length of inputs and attributes here, and only check the types of
inputs and attributes.
How the matching score is calculated:
1. score += 1 if one input type is in the type constraints.
2. score -= 1 if one kwarg is not symmetrically the same.
score += 1 if one input/attribute type is in the type constraints.
Limitations:
1. An Overload is punished if it doesn't have `default` attributes.
2. None/NoeType/[] could result in zero matches, and the same score of overloads,
None/NoneType/[] could result in zero matches, and the same score of overloads,
which will be recorded in SARIF.
Args:
@ -648,7 +741,7 @@ class _OnnxSchemaChecker:
Returns:
True if the inputs match the requirements, False otherwise.
"""
self._matching_score = 0
# If they have different length of arguments, the score would be lower to those
# functions which have the same length of arguments.
for schema_input, torch_input in zip(self.op_schema.inputs, inputs):
@ -661,11 +754,6 @@ class _OnnxSchemaChecker:
self._matching_score += 1
# NOTE: The penalty is applied to those functions which have different attributes.
for attribute_name, attribute_proto in self.attributes.items():
if attribute_name not in attributes:
# If the attribute of the OpSchema and the attribute don't match,
# we know the function and the input are not compatible
self._matching_score -= 1
continue
attribute = attributes[attribute_name]
attribute_onnx_type = fx_type_utils.from_python_type_to_onnx_attribute_type(
type(attribute)
@ -674,10 +762,6 @@ class _OnnxSchemaChecker:
# If the attribute type of the OpSchema and the attribute type don't match,
# we know the function and the input are not compatible
self._matching_score -= 1
# If there is any unexpected attribute in attributes, we know the function
# and the input are not compatible
extra_attrbute_counts = set(attributes).difference(set(self.attributes))
self._matching_score -= len(extra_attrbute_counts)
# NOTE: Referenced from onnxscript internal function.
# Importing this function makes the code less robust, as it is not a public API.
@ -693,6 +777,12 @@ class _OnnxSchemaChecker:
) -> Tuple[List[Any], Dict[str, Any]]:
"""Separate Python args and kwargs into ONNX inputs and attributes.
Extra_kwargs are ignored if their values are None. For example, if the
OpSchema has an attribute "rounding_mode" and the caller provides
"rounding_mode=None", the attribute "rounding_mode" will not be included
in the returned attributes when the OnnxFunction signature doesn't have
"rounding_mode" as an attribute.
Args:
param_schemas: The parameter schemas of an Op or a OnnxFunction.
args: The Python positional arguments supplied by the caller.
@ -711,9 +801,13 @@ class _OnnxSchemaChecker:
# args, kwargs and param_schemas should be all in order
# user may not specify all inputs or attributes
# TODO: avoid circular dependency
import onnx
onnx_inputs: List[Any] = []
onnx_attributes: Dict[str, Any] = dict()
# NOTE: We need to copy kwargs because we will mutate it
copy_kwargs = kwargs.copy()
for i, param in enumerate(param_schemas):
if param.is_variadic_input:
# Exhaust all remaining args
@ -725,16 +819,32 @@ class _OnnxSchemaChecker:
onnx_inputs.append(args[i])
else:
onnx_attributes[param.name] = args[i]
elif param.name in kwargs:
elif param.name in copy_kwargs:
if param.is_input:
onnx_inputs.append(kwargs[param.name])
# Move the input from kwargs to inputs
onnx_inputs.append(copy_kwargs[param.name])
copy_kwargs.pop(param.name)
else:
onnx_attributes[param.name] = kwargs[param.name]
elif param.is_attribute and param.default is not object():
onnx_attributes[param.name] = copy_kwargs[param.name]
elif (
param.is_attribute
and self.attributes[param.name].default_value.type
!= onnx.AttributeProto.UNDEFINED # type: ignore[attr-defined]
):
# User did not provide the attribute
if fill_defaults:
onnx_attributes[param.name] = param.default
# optional input
elif param.is_input:
# TODO: support optional input default in onnx-script?
if fill_defaults:
onnx_inputs.append(None)
# NOTE: Pick up extra kwargs if it's not None. None is not expected
# as an attribute value in torchlib.
for k, v in copy_kwargs.items():
if k not in onnx_attributes and v is not None:
onnx_attributes[k] = v
return onnx_inputs, onnx_attributes

View file

@ -58,7 +58,6 @@ def validate_op_between_ort_torch(
There are three signs can be found:
1. Blue: Pass
2. Yellow: Bypass
3. Red: Fail
Args:
node (torch.fx.Node): The validated fx.node
@ -116,7 +115,6 @@ def validate_op_between_ort_torch(
symbolic_fn.param_schemas(),
torch_args,
torch_kwargs,
fill_defaults=False,
allow_extra_kwargs=True,
)
# NOTE: Apply kwargs preprocessing AFTER they are split
@ -310,16 +308,17 @@ def _convert_torch_args_to_onnxfunction_args(
param_schemas: Sequence[onnxscript.values.ParamSchema],
args: List[fx_type_utils.Argument],
kwargs: Dict[str, fx_type_utils.Argument],
fill_defaults: bool = False,
allow_extra_kwargs: bool = False,
) -> Tuple[List[Any], Dict[str, Any],]:
"""Convert Python args and kwargs to OnnxFunction acceptable with matching ONNX ParamSchema.
NOTE: This is different from the param_schema separating in dispatcher, since at this point
we are already sure that the args and kwargs are in order and matched.
Args:
param_schemas: The parameter schemas of an Op or a OnnxFunction.
args: The Python positional arguments supplied by the caller.
kwargs: The Python keyword arguments supplied by the caller.
fill_defaults: Whether to fill the default values for attributes.
allow_extra_kwargs: Whether to allow extra keyword arguments.
When set to True, extra/unknown arguments will be ignored.
@ -359,10 +358,6 @@ def _convert_torch_args_to_onnxfunction_args(
tagged_kwargs[param.name] = _convert_tensor_to_numpy(kwargs[param.name])
else:
tagged_kwargs[param.name] = kwargs[param.name]
elif param.default is not object():
# User did not provide the input/attribute
if fill_defaults:
tagged_kwargs[param.name] = param.default
elif param.required:
raise TypeError(f"Required input/attribute '{param}' was not provided")

View file

@ -47,6 +47,10 @@ def from_sym_value_to_torch_dtype(sym_value: SYM_VALUE_TYPE) -> torch.dtype:
return _SYM_TYPE_TO_TORCH_DTYPE[type(sym_value)]
def is_optional_onnx_dtype_str(onnx_type_str: str) -> bool:
return onnx_type_str in _OPTIONAL_ONNX_DTYPE_STR
def from_torch_dtype_to_onnx_dtype_str(dtype: Union[torch.dtype, type]) -> Set[str]:
return _TORCH_DTYPE_TO_COMPATIBLE_ONNX_TYPE_STRINGS[dtype]
@ -113,6 +117,12 @@ _TORCH_DTYPE_TO_COMPATIBLE_ONNX_TYPE_STRINGS: Dict[
torch.complex128: {"tensor(double)"},
}
_OPTIONAL_ONNX_DTYPE_STR: Set[str] = {
f"optional({value})"
for value_set in _TORCH_DTYPE_TO_COMPATIBLE_ONNX_TYPE_STRINGS.values()
for value in value_set
}
_PYTHON_TYPE_TO_TORCH_DTYPE = {
bool: torch.bool,
int: torch.int64,