diff --git a/.ci/docker/common/install_onnx.sh b/.ci/docker/common/install_onnx.sh index 9ab35f198b0..096dd092ad7 100644 --- a/.ci/docker/common/install_onnx.sh +++ b/.ci/docker/common/install_onnx.sh @@ -31,7 +31,7 @@ pip_install \ pip_install onnx-weekly==1.15.0.dev20230717 # TODO: change this when onnx-script is on testPypi -pip_install onnxscript-preview==0.1.0.dev20230801 --no-deps +pip_install onnxscript-preview==0.1.0.dev20230809 --no-deps # Cache the transformers model to be used later by ONNX tests. We need to run the transformers # package to download the model. By default, the model is cached at ~/.cache/huggingface/hub/ diff --git a/test/onnx/dynamo/test_registry_dispatcher.py b/test/onnx/dynamo/test_registry_dispatcher.py index 52e6fbdc4fd..208615bef2c 100644 --- a/test/onnx/dynamo/test_registry_dispatcher.py +++ b/test/onnx/dynamo/test_registry_dispatcher.py @@ -220,9 +220,6 @@ class TestDispatcher(common_utils.TestCase): ) # Non-registered op - internal_opname_class_unsupported = registration.OpName.from_name_parts( - namespace="aten", op_name="made_up_node", overload=None - ) unsupported_op_node = torch.fx.Node( graph=torch.fx.Graph(), name="aten::made_up_node", @@ -246,7 +243,7 @@ class TestDispatcher(common_utils.TestCase): name="aten::add.Tensor", op="call_function", target=torch.ops.aten.add.Tensor, # type: ignore[attr-defined] - args=(torch.tensor(3), torch.tensor(4)), + args=(torch.tensor(3.0), torch.tensor(4.0)), kwargs={}, ), name="nearest_match", @@ -257,7 +254,7 @@ class TestDispatcher(common_utils.TestCase): name="aten::add.Tensor", op="call_function", target=torch.ops.aten.add.Tensor, # type: ignore[attr-defined] - args=(torch.tensor(3), torch.tensor(4)), + args=(torch.tensor(3.0), torch.tensor(4.0)), kwargs={"alpha": 1}, ), name="perfect_match_with_kwargs", @@ -270,11 +267,15 @@ class TestDispatcher(common_utils.TestCase): custom_domain = onnxscript.values.Opset(domain="custom", version=1) @onnxscript.script(custom_domain) - def test_custom_op(x: TCustomFloat, y: TCustomFloat) -> TCustomFloat: + def test_custom_op( + x: TCustomFloat, y: TCustomFloat, alpha: int = 1 + ) -> TCustomFloat: return op.Add(x, y) @onnxscript.script(custom_domain) - def test_default_op(x: TCustomFloat, y: TCustomFloat) -> TCustomFloat: + def test_default_op( + x: TCustomFloat, y: TCustomFloat, alpha: int = 1 + ) -> TCustomFloat: return op.Add(x, y) op_full_name = "test::test_op" @@ -306,7 +307,65 @@ class TestDispatcher(common_utils.TestCase): name="aten::add.Tensor", op="call_function", target=torch.ops.aten.add.Tensor, # type: ignore[attr-defined] - args=(torch.tensor(3), torch.tensor(4)), + args=(torch.tensor(3.0), torch.tensor(4.0)), + kwargs={"attr": None}, + ), + name="perfect_match_with_ignoring_none_attribute", + ), + common_utils.subtest( + torch.fx.Node( + graph=torch.fx.Graph(), + name="aten::add.Tensor", + op="call_function", + target=torch.ops.aten.add.Tensor, # type: ignore[attr-defined] + args=(torch.tensor(3.0), torch.tensor(4.0)), + kwargs={"unrelated": None}, + ), + name="perfect_match_with_ignoring_unrelated_none_attribute", + ), + ], + ) + def test_find_the_perfect_or_nearest_match_onnxfunction_ignores_attribute_with_none( + self, node + ): + custom_domain = onnxscript.values.Opset(domain="custom", version=1) + + @onnxscript.script(custom_domain) + def test_op_attribute( + x: TCustomFloat, y: TCustomFloat, attr: int + ) -> TCustomFloat: + return op.Add(x, y) + + @onnxscript.script(custom_domain) + def test_op(x: TCustomFloat, y: TCustomFloat) -> TCustomFloat: + return op.Add(x, y) + + op_full_name = "test::test_op" + + function_overloads = [ + registration.ONNXFunction(test_op_attribute, op_full_name=op_full_name), + registration.ONNXFunction(test_op, op_full_name=op_full_name), + ] + + symbolic_fn = self.dispatcher._find_the_perfect_or_nearest_match_onnxfunction( + node, + function_overloads, + node.args, + node.kwargs, + self.diagnostic_context, + ) + self.assertEqual(symbolic_fn, test_op) + + @common_utils.parametrize( + "node", + [ + common_utils.subtest( + torch.fx.Node( + graph=torch.fx.Graph(), + name="aten::add.Tensor", + op="call_function", + target=torch.ops.aten.add.Tensor, # type: ignore[attr-defined] + args=(torch.tensor(3.0), torch.tensor(4.0)), kwargs={}, ), name="nearest_match", @@ -317,7 +376,7 @@ class TestDispatcher(common_utils.TestCase): name="aten::add.Tensor", op="call_function", target=torch.ops.aten.add.Tensor, # type: ignore[attr-defined] - args=(torch.tensor(3), torch.tensor(4)), + args=(torch.tensor(3.0), torch.tensor(4.0)), kwargs={"alpha": 1}, ), name="perfect_match_with_kwargs", @@ -330,15 +389,21 @@ class TestDispatcher(common_utils.TestCase): custom_domain = onnxscript.values.Opset(domain="custom", version=1) @onnxscript.script(custom_domain) - def test_second_custom_op(x: TCustomFloat, y: TCustomFloat) -> TCustomFloat: + def test_second_custom_op( + x: TCustomFloat, y: TCustomFloat, alpha: int = 1 + ) -> TCustomFloat: return op.Add(x, y) @onnxscript.script(custom_domain) - def test_third_custom_op(x: TCustomFloat, y: TCustomFloat) -> TCustomFloat: + def test_third_custom_op( + x: TCustomFloat, y: TCustomFloat, alpha: int = 1 + ) -> TCustomFloat: return op.Add(x, y) @onnxscript.script(custom_domain) - def test_first_custom_op(x: TCustomFloat, y: TCustomFloat) -> TCustomFloat: + def test_first_custom_op( + x: TCustomFloat, y: TCustomFloat, alpha: int = 1 + ) -> TCustomFloat: return op.Add(x, y) op_full_name = "aten::add" @@ -458,24 +523,6 @@ class TestOpSchemaWrapper(common_utils.TestCase): ([torch.randn(3, 4), torch.tensor(3)], {}, ops.core.aten_new_full, 2), name="match_2_inputs", ), - common_utils.subtest( - ( - [torch.randn(3, 4), torch.tensor(3)], - {"dtype": 2}, # at this point, dtype should be converted to int - ops.core.aten_new_full, - 1, - ), - name="match_2_inputs_and_mismatch_1_kwarg", - ), - common_utils.subtest( - ( - [torch.randn(3, 4), torch.tensor(3)], - {}, - ops.core.aten_new_full_dtype, - 1, - ), - name="match_2_input_and_mismatch_1_kwargs_optional", - ), common_utils.subtest( ( [torch.randn(3, 4), torch.tensor(3)], diff --git a/test/onnx/test_fx_op_consistency.py b/test/onnx/test_fx_op_consistency.py index 9a4306528d9..6a02a08fca6 100644 --- a/test/onnx/test_fx_op_consistency.py +++ b/test/onnx/test_fx_op_consistency.py @@ -354,6 +354,10 @@ EXPECTED_SKIPS_OR_FAILS: Tuple[onnx_test_common.DecorateMeta, ...] = ( dtypes=(torch.uint8, torch.int8, torch.int16,), reason=onnx_test_common.reason_onnx_script_does_not_support("Add", "int8, int16"), ), + xfail( + "nn.functional.adaptive_avg_pool1d", + reason=onnx_test_common.reason_onnx_script_does_not_support("aten::div.Tensor_mode needs type promotion"), + ), xfail( "nn.functional.adaptive_avg_pool2d", reason=onnx_test_common.reason_onnx_script_does_not_support("RecursionError: \ @@ -488,9 +492,10 @@ SKIP_XFAIL_SUBTESTS: tuple[onnx_test_common.DecorateMeta, ...] = ( matcher=lambda sample: sample.input[0].equal(torch.tensor([])), reason="core dump - cat does not support zero-dim tensors yet", ), - skip( + xfail( "div", - matcher=lambda sample: sample.kwargs.get("rounding_mode") is not None, + matcher=lambda sample: sample.kwargs.get("rounding_mode") is not None + and sample.input.dtype in onnx_test_common.INT_TYPES, reason="rounding_mode is not yet supported", ), xfail( diff --git a/test/onnx/test_fx_to_onnx.py b/test/onnx/test_fx_to_onnx.py index b8825ef2883..eede300e925 100644 --- a/test/onnx/test_fx_to_onnx.py +++ b/test/onnx/test_fx_to_onnx.py @@ -45,6 +45,7 @@ def assert_has_diagnostics( ) +@common_utils.instantiate_parametrized_tests class TestFxToOnnx(pytorch_test_common.ExportTestCase): def setUp(self): super().setUp() @@ -82,7 +83,20 @@ class TestFxToOnnx(pytorch_test_common.ExportTestCase): self.assertNotIsInstance(tensor_x, fake_tensor.FakeTensor) self.assertNotIsInstance(tensor_y, fake_tensor.FakeTensor) - def test_mnist_exported_with_no_warnings_on_get_attr_node_in_op_level_debug(self): + @common_utils.parametrize( + "diagnostic_rule", + [ + common_utils.subtest( + diagnostics.rules.find_opschema_matched_symbolic_function, + name="optional_inputs", + ), + common_utils.subtest( + diagnostics.rules.op_level_debugging, + name="get_attr_node_in_op_level_debug", + ), + ], + ) + def test_mnist_exported_with_no_warnings(self, diagnostic_rule): class MNISTModel(nn.Module): def __init__(self): super().__init__() @@ -109,12 +123,9 @@ class TestFxToOnnx(pytorch_test_common.ExportTestCase): MNISTModel(), tensor_x, export_options=ExportOptions(op_level_debug=True) ) - # NOTE: This additional test makes sure that op level debug supports `get_attr` - # fx.Node, also known as weight in PyTorch. aten.convolution.default is one of - # the nodes that has weight attribute. assert_has_diagnostics( export_output.diagnostic_context, - diagnostics.rules.op_level_debugging, + diagnostic_rule, diagnostics.levels.NONE, expected_node="aten.convolution.default", ) @@ -144,7 +155,9 @@ class TestFxToOnnx(pytorch_test_common.ExportTestCase): return torch.sum(values) x = torch.arange(1.0, 6.0, requires_grad=True) - _ = dynamo_export(TopKModel(), x, export_options=self.export_options) + export_output = dynamo_export( + TopKModel(), x, export_options=self.export_options + ) def test_unsupported_indices_fake_tensor_generated_with_op_level_debug(self): class EmbedModelWithoutPaddingIdx(torch.nn.Module): @@ -238,6 +251,20 @@ class TestFxToOnnx(pytorch_test_common.ExportTestCase): expected_node="aten.add.Tensor", ) + def test_aten_clone_does_not_raise_warning_of_lack_of_memory_format(self): + class CustomModule(torch.nn.Module): + def forward(self, input): + return torch.ops.aten.clone(input, memory_format=torch.preserve_format) + + x = torch.tensor(3) + export_output = dynamo_export(CustomModule(), x) + assert_has_diagnostics( + export_output.diagnostic_context, + diagnostics.rules.find_opschema_matched_symbolic_function, + diagnostics.levels.NONE, + expected_node="aten.clone.default", + ) + def test_dynamo_export_retains_readable_parameter_and_buffer_names(self): class SubModule(torch.nn.Module): def __init__(self): diff --git a/test/onnx/test_fx_to_onnx_with_onnxruntime.py b/test/onnx/test_fx_to_onnx_with_onnxruntime.py index 08fd6392dc2..a7b6ecebb8d 100644 --- a/test/onnx/test_fx_to_onnx_with_onnxruntime.py +++ b/test/onnx/test_fx_to_onnx_with_onnxruntime.py @@ -657,7 +657,6 @@ class TestFxToOnnxWithOnnxRuntime(onnx_test_common._TestONNXRuntime): *fake_args, export_options=export_options, ) - onnx_model = export_output.model_proto onnx_test_common.assert_dynamic_shapes(export_output, self.dynamic_shapes) diff --git a/torch/onnx/_internal/fx/diagnostics.py b/torch/onnx/_internal/fx/diagnostics.py index a8dd2e78835..f09cf590639 100644 --- a/torch/onnx/_internal/fx/diagnostics.py +++ b/torch/onnx/_internal/fx/diagnostics.py @@ -124,6 +124,11 @@ def _bool(obj: bool) -> str: return str(obj) +@_format_argument.register +def _str(obj: str) -> str: + return obj + + @_format_argument.register def _registration_onnx_function(obj: registration.ONNXFunction) -> str: # TODO: Compact display of `param_schema`. diff --git a/torch/onnx/_internal/fx/fx_onnx_interpreter.py b/torch/onnx/_internal/fx/fx_onnx_interpreter.py index f8c3d3efc10..6b66c38397b 100644 --- a/torch/onnx/_internal/fx/fx_onnx_interpreter.py +++ b/torch/onnx/_internal/fx/fx_onnx_interpreter.py @@ -597,7 +597,6 @@ class FxOnnxInterpreter: # Dispatch to ONNX op through OpShema. The input argument dtypes are compared to # function signature in OpSchema, and find the best matched overload. - # TODO(titaiwang): diagnostic rules. symbolic_fn = onnxfunction_dispatcher.dispatch( node=node, onnx_args=onnx_args, diff --git a/torch/onnx/_internal/fx/onnxfunction_dispatcher.py b/torch/onnx/_internal/fx/onnxfunction_dispatcher.py index 61b48c822ee..170112856fa 100644 --- a/torch/onnx/_internal/fx/onnxfunction_dispatcher.py +++ b/torch/onnx/_internal/fx/onnxfunction_dispatcher.py @@ -31,6 +31,7 @@ if TYPE_CHECKING: from torch.onnx import OnnxRegistry + # For beartype from onnxscript.function_libs.torch_lib import ( # type: ignore[import] graph_building as onnxscript_graph_building, @@ -67,7 +68,7 @@ def _find_operator_overloads_in_onnx_registry_disagnostic_message_formatter( class OnnxFunctionDispatcher: - """A dispatcher that finds the best ONNX Function for ATen operators. + """A dispatcher that finds the best ONNX Function for ATen/Custom operators. It uses the `torch.ops` name to find the function. If not found, it falls back to default. Otherwise, the best match is found among all function overloads. An exact match has @@ -85,6 +86,9 @@ class OnnxFunctionDispatcher: the potential wrongly annotated dtypes and attributes matching, we use nearest match to find the best function once the aten name is targeted. + 3. Tie-breaker: If there are multiple nearest matches, we will select the one with + the highest matching score. + NOTE: The nearest match `doesn't guarantee` a correct match, and a warning message is logged. """ @@ -219,8 +223,7 @@ class OnnxFunctionDispatcher: Raises: RuntimeError: If there are no overloaded functions available for the given FX node. """ - # TODO(justinchuby): Cache the OnnxSchemaChecker so we don't need to run the init logic everytime - overload_match_ranking: Dict[registration.ONNXFunction, int] = {} + overload_match_ranking: Dict[registration.ONNXFunction, Optional[int]] = {} diagnostic = diagnostic_context.inflight_diagnostic() # Iterate the overloaded functions in reverse order to prioritize the custom ones @@ -228,23 +231,42 @@ class OnnxFunctionDispatcher: for symbolic_function in reversed(default_and_custom_functions): function_opschema = _OnnxSchemaChecker(symbolic_function.onnx_function) + # NOTE: 1. If the perfect match is found, return the function if function_opschema.perfect_match_inputs( diagnostic, onnx_args, onnx_kwargs ): - # If the perfect match is found, return the function return symbolic_function.onnx_function # Record the match score for the nearest match if it's not the perfect match overload_match_ranking[symbolic_function] = function_opschema.match_score - # NOTE: If the perfect match is not found, find the nearest match + # NOTE: 2. If there is no perfect match, find the nearest match among the nearest matche candidates + # If there is no nearest match, raise an error + overload_match_ranking = { + k: v for k, v in overload_match_ranking.items() if v is not None + } + if not overload_match_ranking: + # If there are no overloaded functions available for the given FX node, raise an + # unsupported error + op_full_name = self._get_aten_name( + node, diagnostic_context + ).qualified_name() + diagnostic = diagnostics.UnsupportedFxNodeDiagnostic( + diagnostics.rules.no_symbolic_function_for_call_function, + diagnostics.levels.ERROR, + f"Cannot find any perfect/nearest match of symbolic function for {op_full_name}," + f"which should be registered under {node.target}.", + unsupported_fx_node=node, + ) + diagnostic_context.log(diagnostic) + raise diagnostics.RuntimeErrorWithDiagnostic(diagnostic) + diagnostic.with_additional_message( "### Exact match is not found!\n" "Cannot find a perfect match of symbolic overload, " "a nearest match is found. Please check the ONNX output carefully. \n", ) diagnostic.level = diagnostics.levels.WARNING - - # NOTE: Tie breaker: if there are multiple nearest matches, we will choose the one + # NOTE: 3. Tie breaker: if there are multiple nearest matches, we will choose the one # that is custom first. If there are multiple custom ones, we will choose the one # that is added lastly in the list. symbolic_function_list: List[registration.ONNXFunction] = sorted( @@ -355,7 +377,7 @@ class OnnxFunctionDispatcher: node=node, diagnostic_context=diagnostic_context ) - # NOTE: If the ATen/Custom operators are not registered, the group will be None. + # If the ATen/Custom operators are not registered, the group will be None. # And non-registerd ATen/Custom operators will trigger error in the next step. function_group: Optional[List[registration.ONNXFunction]] = None @@ -374,10 +396,6 @@ class OnnxFunctionDispatcher: overload=None, ) if function_group is not None: - # NOTE: Currently, most of torchlib functions are not registered with overload - # in ONNX registry. So we will only log a warning in SARIF if we can't find the overload - # to avoid spammy warnings in printout. - # TODO: https://github.com/microsoft/onnxscript/issues/828 op_full_name = internal_opname.qualified_name() diagnostic = diagnostic_context.inflight_diagnostic() diagnostic.with_additional_message( @@ -387,15 +405,13 @@ class OnnxFunctionDispatcher: ) diagnostic.level = diagnostics.levels.WARNING - # NOTE: If the ATen/Custom operators are not registered, the group will be None. if function_group is not None: - # If the input has complex dtype, we will only dispatch to the complex functions. + # NOTE: If the input has complex dtype, we will only dispatch to the complex functions. function_group = self._filter_or_keep_complex( node, function_group, diagnostic_context ) return function_group # type: ignore[return-value] - # If we can't find the function group, raise error. op_full_name = internal_opname.qualified_name() diagnostic = diagnostics.UnsupportedFxNodeDiagnostic( diagnostics.rules.no_symbolic_function_for_call_function, @@ -431,38 +447,84 @@ class _OnnxSchemaChecker: It provides methods to check for input compatibility based on the OpSchema. It also provides a matching score to indicate how well the OpSchema matches the input and - kwargs types. + kwargs types. A function will be evaluated as perfect match, nearest match eligible, + or no match. - There are three types of ONNX overloads in torchlib: + Here are some common examples in categories: - 1. Different types: Caused by the difference between the ONNX spec and PyTorch.The - matching system finds the correct one. + 1. [NOTE: Perfect match]: The number of inputs and attributes are exactly the same as + the OpSchema. The types of inputs and attributes are exactly the same as the + OpSchema. ```python - @torch_op("aten::mul") - def aten_mul(self: TReal, other: TReal) -> TReal: + inputs = (Tensor[2, 3], Tensor[2, 3]) + attributes = {"alpha": 1.0} + + @torch_op("aten::op") + def aten_op(self: TReal, other: TReal, alpha: float = 1) -> TReal: ... - @torch_op("aten::mul") - def aten_mul_bool(self: BOOL, other: BOOL) -> BOOL: - ... - ``` + ``` + Result: Perfect match. - 2. Optional dtype: dtype could be "unprovided". The difference from 2 is that dtype - would not be None. + 2. [NOTE: Optional input]: The dispatcher recognizes optional inputs. However, + the input can't be ignored. None must be provided. ```python - @torch_op("aten::new_full") - def aten_new_full(self: TTensor, size: INT64, fill_value: TTensor) -> TTensor: - ... + inputs = (Tensor([2, 3]), None) + attributes = {} - @torch_op("aten::new_full") - def aten_new_full_dtype(self: TTensor, size: INT64, fill_value: TTensor, dtype: int) -> TTensor: + aten_op(X: TTensor, Y: Optional[INT64]): ... ``` + Result: Perfect match. + Real example: `aten::convolution`. - Depends on dtype is provided or not, matching system will dispatch the ATen op to - the correct one. + 3. [NOTE: Different attributes]: If an attribute is provided with value, it's + a must to match the attribute in function signature. + ```python + inputs = (Tensor([2, 3]),) + attributes = {"a":1, "b":2} + + aten_op(X: TTensor, a: int): + ... + ``` + Result: No match. + Real example: `aten::div` vs `aten::div.Tensor_mode`. + + 4. [NOTE: Default attributes]: Default attribute will fill in the value into + inputs/attributes. + ```python + inputs = (Tensor([2, 3]),) + attributes = {} + + aten_op(X: TTensor, a: int = 3): + ... + ``` + Result: Perfect match. + Real example: `aten::clone` + + 5. [NOTE: Ignore attribute with None value]: The attributes with None value + will be ignored in matching. + ```python + inputs = (Tensor([2, 3]),) + attributes = {"a": None} + + aten_op(X: TTensor): + ... + ``` + Result: Perfect match. + + ```python + inputs = (Tensor([2, 3]),) + attributes = {"a": None} + + aten_op(X: TTensor, a: int = 3): + ... + ``` + Result: Nearest match eligible. + + Real example: `aten::div` vs `aten::div.Tensor_mode`. Attributes: onnxfunction: The OnnxFunction. @@ -496,12 +558,15 @@ class _OnnxSchemaChecker: for constraint in self.op_schema.type_constraints } self.attributes = self.op_schema.attributes - self._matching_score: int = 0 + self._matching_score: Optional[int] = None @property - def match_score(self) -> int: + def match_score(self) -> Optional[int]: """The matching score of the OnnxSchemaChecker . + If this remains None, it means the matching score has not been calculated, + and it's not a nearest match candidate. + Returns: The matching score of the OnnxSchemaChecker . """ @@ -522,6 +587,13 @@ class _OnnxSchemaChecker: constraints and the number of inputs matches the number of inputs in the OpSchema. + Checking steps: + 1. The function signature matches the inputs number, and attribute names. + 2. The input/attribute types are all in the type constraints. + + A function should at least pass the first step to be eligible for the + nearest matching. + Args: diagnostic: The diagnostic to use for logging detailed info. args: The input arguments organized in PyTorch inputs way. @@ -540,30 +612,54 @@ class _OnnxSchemaChecker: self.param_schema, args, kwargs, - fill_defaults=False, # NOTE: We don't want to change inputs + fill_defaults=True, # fill defaults for optional arguments to match ) - - # TODO(titaiwang): Currently the functions in torchlib are manully annotated, - # so there are quite a few functions that wrongly annotated or strctly annotated. - # The matching system relax the match while we fix them in the future. - self._record_matching_score(function_inputs, function_attributes) - diagnostic.with_additional_message("### Checking perfect match...\n") diagnostic.with_additional_message( f"{diagnostics.format_argument(self.onnxfunction)}" ) - diagnostic.with_additional_message(f"match score: {self.match_score}\n") - + # NOTE: 1. Check if the input number and attribute names match the + # OpSchema. If it's not, we know the function is not eligible to be a perfect + # match, nor a nearest match. + # We use is_perfect_match to postpone the return value to the end + # of the function, as we want to log all the mismatch info. + is_perfect_match = True if len(function_inputs) != len(self.op_schema.inputs): diagnostic.with_additional_message( f"#### Failed: input number mismatch! \n" f"Actual {len(function_inputs)} vs expected {len(self.op_schema.inputs)}\n" ) + diagnostic.with_additional_message( + "The function is not a nearest match candidate.\n" + ) + is_perfect_match = False + + if set(function_attributes) != set(self.attributes): + diagnostic.with_additional_message( + f"#### Failed: attribute mismatch! \n" + f"Actual {set(function_attributes)} vs\n" + f"expected {set(self.attributes)}\n" + ) + diagnostic.with_additional_message( + "The function is not a nearest match candidate.\n" + ) + is_perfect_match = False + + # If it's already not a perfect match, we can return False directly. Further + # checking is only for the functions that are eligible for nearest match. + if not is_perfect_match: return False + + # NOTE: 2. The dtypes of inputs and attributes should be in the + # type constraints of the OpSchema. If they are not, we know the function is not + # eligible to be a perfect match, but can be a nearest match candidate. for schema_input, torch_input in zip(self.op_schema.inputs, function_inputs): torch_input_compatible_types = _find_onnx_data_type(torch_input) allowed_types = self.type_constraints[schema_input.type_str] - if not allowed_types.intersection(torch_input_compatible_types): + if not allowed_types.intersection(torch_input_compatible_types) and not any( + fx_type_utils.is_optional_onnx_dtype_str(onnx_type_str) + for onnx_type_str in allowed_types + ): # If torch_input_compatible_types isn't in allowed_types # of this input defined in the OpSchema, we know the function # and the input are not compatible @@ -572,18 +668,8 @@ class _OnnxSchemaChecker: f"Actual {torch_input_compatible_types} vs\n" f"expected {allowed_types}\n" ) - return False - # Check attributes keys are the same - if set(function_attributes) != set(self.attributes): - # If the attributes of the OpSchema and the attributes don't match, - # we know the function and the input are not compatible - diagnostic.with_additional_message( - f"#### Failed: attribute mismatch! \n" - f"Actual {set(function_attributes)} vs\n" - f"expected {set(self.attributes)}\n" - ) - return False - # Check attribute dtypes + is_perfect_match = False + for attribute_name, attribute in function_attributes.items(): if not self._match_onnx_attribute_type(attribute_name, attribute): # If the attribute type of the OpSchema and the attribute type don't match, @@ -593,8 +679,12 @@ class _OnnxSchemaChecker: f"Actual {type(attribute)} vs\n" f"expected {self.attributes[attribute_name].type}\n" ) - return False - return True + is_perfect_match = False + + # NOTE: This is still a candidate for nearest match, as it only mismatches attributes on dtype. + self._record_matching_score(function_inputs, function_attributes) + diagnostic.with_additional_message(f"match score: {self.match_score}\n") + return is_perfect_match @_beartype.beartype def _match_onnx_attribute_type( @@ -632,13 +722,16 @@ class _OnnxSchemaChecker: ): """Calculate the inputs matching score of the OpSchema requirements to find the nearest match. + Only the functions which have the same number of inputs and attributes as the + OpSchema are eligible to be a nearest match candidate. Thus, we don't need to + check the length of inputs and attributes here, and only check the types of + inputs and attributes. + How the matchsing score is calculated: - 1. score += 1 if one input type is in the type constraints. - 2. score -= 1 if one kwarg is not symmetrically the same. + score += 1 if one input/attribute type is in the type constraints. Limitations: - 1. An Overload is punished if it doesn't have `default` attributes. - 2. None/NoeType/[] could result in zero matches, and the same score of overloads, + None/NoeType/[] could result in zero matches, and the same score of overloads, which will be recorded in SARIF. Args: @@ -648,7 +741,7 @@ class _OnnxSchemaChecker: Returns: True if the inputs match the requirements, False otherwise. """ - + self._matching_score = 0 # If they have different length of arguments, the score would be lower to those # functions which have the same length of arguments. for schema_input, torch_input in zip(self.op_schema.inputs, inputs): @@ -661,11 +754,6 @@ class _OnnxSchemaChecker: self._matching_score += 1 # NOTE: The penalty is applied to those functions which have different attributes. for attribute_name, attribute_proto in self.attributes.items(): - if attribute_name not in attributes: - # If the attribute of the OpSchema and the attribute don't match, - # we know the function and the input are not compatible - self._matching_score -= 1 - continue attribute = attributes[attribute_name] attribute_onnx_type = fx_type_utils.from_python_type_to_onnx_attribute_type( type(attribute) @@ -674,10 +762,6 @@ class _OnnxSchemaChecker: # If the attribute type of the OpSchema and the attribute type don't match, # we know the function and the input are not compatible self._matching_score -= 1 - # If there is any unexpected attribute in attributes, we know the function - # and the input are not compatible - extra_attrbute_counts = set(attributes).difference(set(self.attributes)) - self._matching_score -= len(extra_attrbute_counts) # NOTE: Referenced from onnxscript internal function. # Importing this function makes the code less robust, as it is not a public API. @@ -693,6 +777,12 @@ class _OnnxSchemaChecker: ) -> Tuple[List[Any], Dict[str, Any]]: """Separate Python args and kwargs into ONNX inputs and attributes. + Extra_kwargs are ignored if their values are None. For example, if the + OpSchema has an attribute "rounding_mode" and the caller provides + "rounding_mode=None", the attribute "rounding_mode" will not be included + in the returned attributes when the OnnxFunction signature doesn't have + "rounding_mode" as an attribute. + Args: param_schemas: The parameter schemas of an Op or a OnnxFunction. args: The Python positional arguments supplied by the caller. @@ -711,9 +801,13 @@ class _OnnxSchemaChecker: # args, kwargs and param_schemas should be all in order # user may not specify all inputs or attributes + # TODO: avoid circular dependency + import onnx + onnx_inputs: List[Any] = [] onnx_attributes: Dict[str, Any] = dict() - + # NOTE: We need to copy kwargs because we will mutate it + copy_kwargs = kwargs.copy() for i, param in enumerate(param_schemas): if param.is_variadic_input: # Exhaust all remaining args @@ -725,16 +819,32 @@ class _OnnxSchemaChecker: onnx_inputs.append(args[i]) else: onnx_attributes[param.name] = args[i] - elif param.name in kwargs: + elif param.name in copy_kwargs: if param.is_input: - onnx_inputs.append(kwargs[param.name]) + # Move the input from kwargs to inputs + onnx_inputs.append(copy_kwargs[param.name]) + copy_kwargs.pop(param.name) else: - onnx_attributes[param.name] = kwargs[param.name] - elif param.is_attribute and param.default is not object(): + onnx_attributes[param.name] = copy_kwargs[param.name] + elif ( + param.is_attribute + and self.attributes[param.name].default_value.type + != onnx.AttributeProto.UNDEFINED # type: ignore[attr-defined] + ): # User did not provide the attribute if fill_defaults: onnx_attributes[param.name] = param.default + # optional input + elif param.is_input: + # TODO: support optional input default in onnx-script? + if fill_defaults: + onnx_inputs.append(None) + # NOTE: Pick up extra kwargs if it's not None. None is not expected + # as an attribute value in torchlib. + for k, v in copy_kwargs.items(): + if k not in onnx_attributes and v is not None: + onnx_attributes[k] = v return onnx_inputs, onnx_attributes diff --git a/torch/onnx/_internal/fx/op_validation.py b/torch/onnx/_internal/fx/op_validation.py index 80a2f8ae3ef..6e03982971b 100644 --- a/torch/onnx/_internal/fx/op_validation.py +++ b/torch/onnx/_internal/fx/op_validation.py @@ -58,7 +58,6 @@ def validate_op_between_ort_torch( There are three signs can be found: 1. Blue: Pass 2. Yellow: Bypass - 3. Red: Fail Args: node (torch.fx.Node): The validated fx.node @@ -116,7 +115,6 @@ def validate_op_between_ort_torch( symbolic_fn.param_schemas(), torch_args, torch_kwargs, - fill_defaults=False, allow_extra_kwargs=True, ) # NOTE: Apply kwargs preprocessing AFTER they are split @@ -310,16 +308,17 @@ def _convert_torch_args_to_onnxfunction_args( param_schemas: Sequence[onnxscript.values.ParamSchema], args: List[fx_type_utils.Argument], kwargs: Dict[str, fx_type_utils.Argument], - fill_defaults: bool = False, allow_extra_kwargs: bool = False, ) -> Tuple[List[Any], Dict[str, Any],]: """Convert Python args and kwargs to OnnxFunction acceptable with matching ONNX ParamSchema. + NOTE: This is different from the param_schema separating in dispatcher, since at this point + we are already sure that the args and kwargs are in order and matched. + Args: param_schemas: The parameter schemas of an Op or a OnnxFunction. args: The Python positional arguments supplied by the caller. kwargs: The Python keyword arguments supplied by the caller. - fill_defaults: Whether to fill the default values for attributes. allow_extra_kwargs: Whether to allow extra keyword arguments. When set to True, extra/unknown arguments will be ignored. @@ -359,10 +358,6 @@ def _convert_torch_args_to_onnxfunction_args( tagged_kwargs[param.name] = _convert_tensor_to_numpy(kwargs[param.name]) else: tagged_kwargs[param.name] = kwargs[param.name] - elif param.default is not object(): - # User did not provide the input/attribute - if fill_defaults: - tagged_kwargs[param.name] = param.default elif param.required: raise TypeError(f"Required input/attribute '{param}' was not provided") diff --git a/torch/onnx/_internal/fx/type_utils.py b/torch/onnx/_internal/fx/type_utils.py index a32e6737d0e..b889d6120c7 100644 --- a/torch/onnx/_internal/fx/type_utils.py +++ b/torch/onnx/_internal/fx/type_utils.py @@ -47,6 +47,10 @@ def from_sym_value_to_torch_dtype(sym_value: SYM_VALUE_TYPE) -> torch.dtype: return _SYM_TYPE_TO_TORCH_DTYPE[type(sym_value)] +def is_optional_onnx_dtype_str(onnx_type_str: str) -> bool: + return onnx_type_str in _OPTIONAL_ONNX_DTYPE_STR + + def from_torch_dtype_to_onnx_dtype_str(dtype: Union[torch.dtype, type]) -> Set[str]: return _TORCH_DTYPE_TO_COMPATIBLE_ONNX_TYPE_STRINGS[dtype] @@ -113,6 +117,12 @@ _TORCH_DTYPE_TO_COMPATIBLE_ONNX_TYPE_STRINGS: Dict[ torch.complex128: {"tensor(double)"}, } +_OPTIONAL_ONNX_DTYPE_STR: Set[str] = { + f"optional({value})" + for value_set in _TORCH_DTYPE_TO_COMPATIBLE_ONNX_TYPE_STRINGS.values() + for value in value_set +} + _PYTHON_TYPE_TO_TORCH_DTYPE = { bool: torch.bool, int: torch.int64,