mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-22 22:01:08 +00:00
Revert to ignoring optional subgraph inputs (#306)
* Revert to ignoring optional subgraph inputs due to abandoning PR 216. Restores previous behaviour that changed a couple of days ago with the Scan v9 checkin. * Update to allow either all inputs, or just required inputs to be provided for the subgraph. * Update IterateSequence to prefer all inputs over required inputs.
This commit is contained in:
parent
6225d5fe1e
commit
f678f58750
3 changed files with 43 additions and 8 deletions
|
|
@ -1305,18 +1305,32 @@ Status Graph::InferAndVerifySubgraphTypes(const Node& node, Graph& subgraph,
|
|||
|
||||
output_types.clear();
|
||||
|
||||
auto& subgraph_inputs = subgraph.GetInputs();
|
||||
auto num_subgraph_inputs = subgraph_inputs.size();
|
||||
// the spec says all inputs should be provided for the subgraph so default to that first
|
||||
auto* subgraph_inputs = &subgraph.GetInputsIncludingInitializers();
|
||||
auto num_subgraph_inputs = subgraph_inputs->size();
|
||||
|
||||
if (num_subgraph_inputs != input_types.size()) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Size mismatch validating subgraph inputs. Got ",
|
||||
input_types.size(), " inputs but subgraph requires ", subgraph_inputs.size());
|
||||
// we also allow for just the required inputs to be provided to be user friendly due to ONNX requiring
|
||||
// initializers to have matching inputs (making them optional inputs that most likely the user doesn't want to
|
||||
// override).
|
||||
auto& required_subgraph_inputs = subgraph.GetInputs();
|
||||
auto num_required_subgraph_inputs = required_subgraph_inputs.size();
|
||||
|
||||
if (num_required_subgraph_inputs != input_types.size()) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Size mismatch validating subgraph inputs. Got ", input_types.size(),
|
||||
" inputs but subgraph has ", num_subgraph_inputs,
|
||||
" inputs and requires ", num_required_subgraph_inputs,
|
||||
" inputs. Either provide all subgraph inputs, or just the required inputs.");
|
||||
} else {
|
||||
subgraph_inputs = &required_subgraph_inputs;
|
||||
num_subgraph_inputs = num_required_subgraph_inputs;
|
||||
}
|
||||
}
|
||||
|
||||
// apply type/shape info to the subgraph's inputs
|
||||
for (size_t i = 0; i < num_subgraph_inputs; ++i) {
|
||||
const auto& input_type = *input_types[i];
|
||||
const auto& subgraph_input = *subgraph_inputs[i];
|
||||
const auto& subgraph_input = *subgraph_inputs->at(i);
|
||||
|
||||
NodeArg* mutable_nodearg = subgraph.GetNodeArg(subgraph_input.Name());
|
||||
status = mutable_nodearg->UpdateTypeAndShape(input_type);
|
||||
|
|
|
|||
|
|
@ -96,10 +96,19 @@ Status IterateSequence(OpKernelContextInternal& context,
|
|||
std::vector<std::string>& subgraph_output_names,
|
||||
std::vector<std::unique_ptr<OutputIterator>>& output_iterators) {
|
||||
Status status = Status::OK();
|
||||
auto& graph_inputs = subgraph.GetInputsIncludingInitializers();
|
||||
|
||||
// prefer matching all inputs to the subgraph as per the Scan spec,
|
||||
auto* graph_inputs = &subgraph.GetInputsIncludingInitializers();
|
||||
if (static_cast<size_t>(num_variadic_inputs) < graph_inputs->size()) {
|
||||
// fallback to just the required inputs.
|
||||
graph_inputs = &subgraph.GetInputs();
|
||||
ORT_ENFORCE(static_cast<size_t>(num_variadic_inputs) == graph_inputs->size(),
|
||||
"Graph::InferAndVerifySubgraphTypes should have already validated that "
|
||||
"num_variadic_inputs matched the subgraph inputs or required inputs.");
|
||||
}
|
||||
|
||||
NameMLValMap feeds;
|
||||
std::vector<MLValue> fetches;
|
||||
|
||||
feeds.reserve(num_variadic_inputs + implicit_inputs.size());
|
||||
fetches.resize(num_variadic_outputs);
|
||||
|
||||
|
|
@ -113,7 +122,7 @@ Status IterateSequence(OpKernelContextInternal& context,
|
|||
for (; seq_no < seq_length; ++seq_no) {
|
||||
for (int input = 0; input < num_variadic_inputs; ++input) {
|
||||
// the ordering of the Scan inputs should match the ordering of the subgraph inputs
|
||||
auto name = graph_inputs[input]->Name();
|
||||
auto name = (*graph_inputs)[input]->Name();
|
||||
|
||||
if (input < num_loop_state_variables) {
|
||||
// add loop state variable input
|
||||
|
|
|
|||
|
|
@ -267,6 +267,18 @@ static const ONNX_NAMESPACE::GraphProto CreateSubgraph(const RunOptions& options
|
|||
graph.SetInputOrder({&iter_num_in, &cond_in, &loop_var_0_in, &loop_var_1_in});
|
||||
graph.SetOutputOrder({cond_out, loop_var_0_out, loop_var_1_out, loop_out_0});
|
||||
|
||||
// optional input backed by an initializer to make sure that's handled too.
|
||||
// we expect that Graph::InferAndVerifySubgraphTypes will be able to ignore the optional input if not provided
|
||||
{
|
||||
TensorProto optional_input_tensor;
|
||||
optional_input_tensor.set_name("optional_float");
|
||||
optional_input_tensor.add_dims(1);
|
||||
optional_input_tensor.add_float_data(1.f);
|
||||
optional_input_tensor.set_data_type(onnx::TensorProto_DataType_FLOAT);
|
||||
|
||||
graph.AddInitializedTensor(optional_input_tensor);
|
||||
}
|
||||
|
||||
auto status = graph.Resolve();
|
||||
EXPECT_EQ(status, Status::OK());
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue