Handle unused function inputs (#15130)

### Description

Fix issue relating to unused inputs of model-local functions. ORT
creates a schema for all such functions. The creation of this schema
does not handle unused function-inputs. The schema-creation relies on
the use of the function-inputs to infer type-constraints for the input,
and the code ends up creating an erroneous input-descriptor when there
is no use of the function-input.

The fix is to create an input with the given name, with a
type-constraint that allows all types.

### Motivation and Context
Fix https://github.com/microsoft/onnxruntime/issues/15046

Fix https://github.com/microsoft/onnx-script/issues/524

---------

Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>
Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
Co-authored-by: Scott McKay <skottmckay@gmail.com>
This commit is contained in:
G. Ramalingam 2023-03-23 08:12:46 -07:00 committed by GitHub
parent 896ab94780
commit efa1262614
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 30 additions and 1 deletions

View file

@ -229,7 +229,18 @@ static void IOTypeConstraintHelper(const ONNX_NAMESPACE::FunctionProto& onnx_fun
int i = 0;
for (auto& input : input_types_list) {
op_schema->Input(i, input.first, "", input.second);
if (!input.first.empty()) {
op_schema->Input(i, input.first, "", input.second);
} else {
// Handle unused input: its type can be anything.
std::string type_str = "Tin" + std::to_string(i);
op_schema->Input(i, onnx_func_proto.input(i), "", type_str);
auto& dest_types = type_constraint_map[type_str];
dest_types.reserve(dest_types.size() + all_types.size());
for (const auto& data_type : all_types) {
dest_types.emplace_back(data_type);
}
}
++i;
}
i = 0;

View file

@ -372,5 +372,23 @@ TEST(FunctionTest, OuterScopeName) {
Check(code, "x", {1.0, 2.0, 3.0}, "y", {1.0, 2.0, 3.0, 0.0, 0.0, 0.0});
}
// Test use of functions with unused inputs:
TEST(FunctionTest, UnusedFunctionInputs) {
const char* code = R"(
<ir_version: 8, opset_import: ["" : 17, "local" : 1]>
mymodel (float[3] x) => (float[3] y) {
y = local.func (x, x, x)
}
<opset_import: ["" : 17 ], domain: "local">
func (a, b, c) => (y) {
y = Mul (a, b)
}
)";
Check(code, "x", {1.0, 2.0, 3.0}, "y", {1.0, 4.0, 9.0});
}
} // namespace test
} // namespace onnxruntime