Revert to ignoring optional subgraph inputs (#306)

* Revert to ignoring optional subgraph inputs due to abandoning PR 216. Restores previous behaviour that changed a couple of days ago with the Scan v9 checkin.

* Update to allow either all inputs, or just required inputs to be provided for the subgraph.

* Update IterateSequence to prefer all inputs over required inputs.
This commit is contained in:
Scott McKay 2019-01-16 11:58:19 +10:00 committed by GitHub
parent 6225d5fe1e
commit f678f58750
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 43 additions and 8 deletions

View file

@ -1305,18 +1305,32 @@ Status Graph::InferAndVerifySubgraphTypes(const Node& node, Graph& subgraph,
output_types.clear();
auto& subgraph_inputs = subgraph.GetInputs();
auto num_subgraph_inputs = subgraph_inputs.size();
// the spec says all inputs should be provided for the subgraph so default to that first
auto* subgraph_inputs = &subgraph.GetInputsIncludingInitializers();
auto num_subgraph_inputs = subgraph_inputs->size();
if (num_subgraph_inputs != input_types.size()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Size mismatch validating subgraph inputs. Got ",
input_types.size(), " inputs but subgraph requires ", subgraph_inputs.size());
// we also allow for just the required inputs to be provided to be user friendly due to ONNX requiring
// initializers to have matching inputs (making them optional inputs that most likely the user doesn't want to
// override).
auto& required_subgraph_inputs = subgraph.GetInputs();
auto num_required_subgraph_inputs = required_subgraph_inputs.size();
if (num_required_subgraph_inputs != input_types.size()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Size mismatch validating subgraph inputs. Got ", input_types.size(),
" inputs but subgraph has ", num_subgraph_inputs,
" inputs and requires ", num_required_subgraph_inputs,
" inputs. Either provide all subgraph inputs, or just the required inputs.");
} else {
subgraph_inputs = &required_subgraph_inputs;
num_subgraph_inputs = num_required_subgraph_inputs;
}
}
// apply type/shape info to the subgraph's inputs
for (size_t i = 0; i < num_subgraph_inputs; ++i) {
const auto& input_type = *input_types[i];
const auto& subgraph_input = *subgraph_inputs[i];
const auto& subgraph_input = *subgraph_inputs->at(i);
NodeArg* mutable_nodearg = subgraph.GetNodeArg(subgraph_input.Name());
status = mutable_nodearg->UpdateTypeAndShape(input_type);

View file

@ -96,10 +96,19 @@ Status IterateSequence(OpKernelContextInternal& context,
std::vector<std::string>& subgraph_output_names,
std::vector<std::unique_ptr<OutputIterator>>& output_iterators) {
Status status = Status::OK();
auto& graph_inputs = subgraph.GetInputsIncludingInitializers();
// prefer matching all inputs to the subgraph as per the Scan spec,
auto* graph_inputs = &subgraph.GetInputsIncludingInitializers();
if (static_cast<size_t>(num_variadic_inputs) < graph_inputs->size()) {
// fallback to just the required inputs.
graph_inputs = &subgraph.GetInputs();
ORT_ENFORCE(static_cast<size_t>(num_variadic_inputs) == graph_inputs->size(),
"Graph::InferAndVerifySubgraphTypes should have already validated that "
"num_variadic_inputs matched the subgraph inputs or required inputs.");
}
NameMLValMap feeds;
std::vector<MLValue> fetches;
feeds.reserve(num_variadic_inputs + implicit_inputs.size());
fetches.resize(num_variadic_outputs);
@ -113,7 +122,7 @@ Status IterateSequence(OpKernelContextInternal& context,
for (; seq_no < seq_length; ++seq_no) {
for (int input = 0; input < num_variadic_inputs; ++input) {
// the ordering of the Scan inputs should match the ordering of the subgraph inputs
auto name = graph_inputs[input]->Name();
auto name = (*graph_inputs)[input]->Name();
if (input < num_loop_state_variables) {
// add loop state variable input

View file

@ -267,6 +267,18 @@ static const ONNX_NAMESPACE::GraphProto CreateSubgraph(const RunOptions& options
graph.SetInputOrder({&iter_num_in, &cond_in, &loop_var_0_in, &loop_var_1_in});
graph.SetOutputOrder({cond_out, loop_var_0_out, loop_var_1_out, loop_out_0});
// optional input backed by an initializer to make sure that's handled too.
// we expect that Graph::InferAndVerifySubgraphTypes will be able to ignore the optional input if not provided
{
TensorProto optional_input_tensor;
optional_input_tensor.set_name("optional_float");
optional_input_tensor.add_dims(1);
optional_input_tensor.add_float_data(1.f);
optional_input_tensor.set_data_type(onnx::TensorProto_DataType_FLOAT);
graph.AddInitializedTensor(optional_input_tensor);
}
auto status = graph.Resolve();
EXPECT_EQ(status, Status::OK());