From efa12626141db4c249b0bab87b4e0595b8b7172a Mon Sep 17 00:00:00 2001 From: "G. Ramalingam" Date: Thu, 23 Mar 2023 08:12:46 -0700 Subject: [PATCH] Handle unused function inputs (#15130) ### Description Fix issue relating to unused inputs of model-local functions. ORT creates a schema for all such functions. The creation of this schema does not handle unused function-inputs. The schema-creation relies on the use of the function-inputs to infer type-constraints for the input, and the code ends up creating an erroneous input-descriptor when there is no use of the function-input. The fix is to create an input with the given name, with a type-constraint that allows all types. ### Motivation and Context Fix https://github.com/microsoft/onnxruntime/issues/15046 Fix https://github.com/microsoft/onnx-script/issues/524 --------- Signed-off-by: Ganesan Ramalingam Co-authored-by: Justin Chu Co-authored-by: Scott McKay --- onnxruntime/core/graph/function_utils.cc | 13 ++++++++++++- onnxruntime/test/framework/function_test.cc | 18 ++++++++++++++++++ 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/graph/function_utils.cc b/onnxruntime/core/graph/function_utils.cc index a320d7b454..9eaa3097fd 100644 --- a/onnxruntime/core/graph/function_utils.cc +++ b/onnxruntime/core/graph/function_utils.cc @@ -229,7 +229,18 @@ static void IOTypeConstraintHelper(const ONNX_NAMESPACE::FunctionProto& onnx_fun int i = 0; for (auto& input : input_types_list) { - op_schema->Input(i, input.first, "", input.second); + if (!input.first.empty()) { + op_schema->Input(i, input.first, "", input.second); + } else { + // Handle unused input: its type can be anything. + std::string type_str = "Tin" + std::to_string(i); + op_schema->Input(i, onnx_func_proto.input(i), "", type_str); + auto& dest_types = type_constraint_map[type_str]; + dest_types.reserve(dest_types.size() + all_types.size()); + for (const auto& data_type : all_types) { + dest_types.emplace_back(data_type); + } + } ++i; } i = 0; diff --git a/onnxruntime/test/framework/function_test.cc b/onnxruntime/test/framework/function_test.cc index c74997930d..21bd301a25 100644 --- a/onnxruntime/test/framework/function_test.cc +++ b/onnxruntime/test/framework/function_test.cc @@ -372,5 +372,23 @@ TEST(FunctionTest, OuterScopeName) { Check(code, "x", {1.0, 2.0, 3.0}, "y", {1.0, 2.0, 3.0, 0.0, 0.0, 0.0}); } + +// Test use of functions with unused inputs: +TEST(FunctionTest, UnusedFunctionInputs) { + const char* code = R"( + + mymodel (float[3] x) => (float[3] y) { + y = local.func (x, x, x) + } + + + func (a, b, c) => (y) { + y = Mul (a, b) + } + )"; + + Check(code, "x", {1.0, 2.0, 3.0}, "y", {1.0, 4.0, 9.0}); +} + } // namespace test } // namespace onnxruntime