diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 38b2df4c78..44b283b8ce 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -1305,18 +1305,32 @@ Status Graph::InferAndVerifySubgraphTypes(const Node& node, Graph& subgraph, output_types.clear(); - auto& subgraph_inputs = subgraph.GetInputs(); - auto num_subgraph_inputs = subgraph_inputs.size(); + // the spec says all inputs should be provided for the subgraph so default to that first + auto* subgraph_inputs = &subgraph.GetInputsIncludingInitializers(); + auto num_subgraph_inputs = subgraph_inputs->size(); if (num_subgraph_inputs != input_types.size()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Size mismatch validating subgraph inputs. Got ", - input_types.size(), " inputs but subgraph requires ", subgraph_inputs.size()); + // we also allow for just the required inputs to be provided to be user friendly due to ONNX requiring + // initializers to have matching inputs (making them optional inputs that most likely the user doesn't want to + // override). + auto& required_subgraph_inputs = subgraph.GetInputs(); + auto num_required_subgraph_inputs = required_subgraph_inputs.size(); + + if (num_required_subgraph_inputs != input_types.size()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Size mismatch validating subgraph inputs. Got ", input_types.size(), + " inputs but subgraph has ", num_subgraph_inputs, + " inputs and requires ", num_required_subgraph_inputs, + " inputs. Either provide all subgraph inputs, or just the required inputs."); + } else { + subgraph_inputs = &required_subgraph_inputs; + num_subgraph_inputs = num_required_subgraph_inputs; + } } // apply type/shape info to the subgraph's inputs for (size_t i = 0; i < num_subgraph_inputs; ++i) { const auto& input_type = *input_types[i]; - const auto& subgraph_input = *subgraph_inputs[i]; + const auto& subgraph_input = *subgraph_inputs->at(i); NodeArg* mutable_nodearg = subgraph.GetNodeArg(subgraph_input.Name()); status = mutable_nodearg->UpdateTypeAndShape(input_type); diff --git a/onnxruntime/core/providers/cpu/controlflow/scan_utils.cc b/onnxruntime/core/providers/cpu/controlflow/scan_utils.cc index 444c471b42..d341314006 100644 --- a/onnxruntime/core/providers/cpu/controlflow/scan_utils.cc +++ b/onnxruntime/core/providers/cpu/controlflow/scan_utils.cc @@ -96,10 +96,19 @@ Status IterateSequence(OpKernelContextInternal& context, std::vector& subgraph_output_names, std::vector>& output_iterators) { Status status = Status::OK(); - auto& graph_inputs = subgraph.GetInputsIncludingInitializers(); + + // prefer matching all inputs to the subgraph as per the Scan spec, + auto* graph_inputs = &subgraph.GetInputsIncludingInitializers(); + if (static_cast(num_variadic_inputs) < graph_inputs->size()) { + // fallback to just the required inputs. + graph_inputs = &subgraph.GetInputs(); + ORT_ENFORCE(static_cast(num_variadic_inputs) == graph_inputs->size(), + "Graph::InferAndVerifySubgraphTypes should have already validated that " + "num_variadic_inputs matched the subgraph inputs or required inputs."); + } + NameMLValMap feeds; std::vector fetches; - feeds.reserve(num_variadic_inputs + implicit_inputs.size()); fetches.resize(num_variadic_outputs); @@ -113,7 +122,7 @@ Status IterateSequence(OpKernelContextInternal& context, for (; seq_no < seq_length; ++seq_no) { for (int input = 0; input < num_variadic_inputs; ++input) { // the ordering of the Scan inputs should match the ordering of the subgraph inputs - auto name = graph_inputs[input]->Name(); + auto name = (*graph_inputs)[input]->Name(); if (input < num_loop_state_variables) { // add loop state variable input diff --git a/onnxruntime/test/providers/cpu/controlflow/loop_test.cc b/onnxruntime/test/providers/cpu/controlflow/loop_test.cc index ad13abc936..98e2ddf98b 100644 --- a/onnxruntime/test/providers/cpu/controlflow/loop_test.cc +++ b/onnxruntime/test/providers/cpu/controlflow/loop_test.cc @@ -267,6 +267,18 @@ static const ONNX_NAMESPACE::GraphProto CreateSubgraph(const RunOptions& options graph.SetInputOrder({&iter_num_in, &cond_in, &loop_var_0_in, &loop_var_1_in}); graph.SetOutputOrder({cond_out, loop_var_0_out, loop_var_1_out, loop_out_0}); + // optional input backed by an initializer to make sure that's handled too. + // we expect that Graph::InferAndVerifySubgraphTypes will be able to ignore the optional input if not provided + { + TensorProto optional_input_tensor; + optional_input_tensor.set_name("optional_float"); + optional_input_tensor.add_dims(1); + optional_input_tensor.add_float_data(1.f); + optional_input_tensor.set_data_type(onnx::TensorProto_DataType_FLOAT); + + graph.AddInitializedTensor(optional_input_tensor); + } + auto status = graph.Resolve(); EXPECT_EQ(status, Status::OK());