From ca1bbff5d4c2b995cf8fb16d6f57757653735352 Mon Sep 17 00:00:00 2001 From: ytaous <4484531+ytaous@users.noreply.github.com> Date: Fri, 17 Apr 2020 19:33:34 -0700 Subject: [PATCH] subgraph type override handling and unit test (#3560) * unit test for subgraph type override * unit test - re-wire input properly to subgraph * update args Co-authored-by: Ethan Tao --- include/onnxruntime/core/graph/graph.h | 3 +- include/onnxruntime/core/graph/node_arg.h | 8 +- onnxruntime/core/graph/graph.cc | 80 +++++++++---- .../providers/cpu/controlflow/loop_test.cc | 110 ++++++++++++++++++ .../test/providers/provider_test_utils.cc | 12 +- .../test/providers/provider_test_utils.h | 6 +- 6 files changed, 184 insertions(+), 35 deletions(-) diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index d7398b4af5..766c6d96b7 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -987,7 +987,8 @@ class Graph { // perform type and shape inferencing on the subgraph and Resolve to validate static common::Status InferAndVerifySubgraphTypes(const Node& node, Graph& subgraph, const std::vector& input_types, - std::vector& output_types); + std::vector& output_types, + const Graph::ResolveOptions& options); // Apply type-inference and type-checking to all inputs and initializers: common::Status TypeCheckInputsAndInitializers(); diff --git a/include/onnxruntime/core/graph/node_arg.h b/include/onnxruntime/core/graph/node_arg.h index ef161858f6..6a71ecd697 100644 --- a/include/onnxruntime/core/graph/node_arg.h +++ b/include/onnxruntime/core/graph/node_arg.h @@ -67,15 +67,17 @@ class NodeArg { /** Validate and merge type [and shape] info from input_type. @param strict If true, the shape update will fail if there are incompatible values. If false, will be lenient and merge only shape info that can be validly processed. 
+ @param override_types If true, a tensor element type mismatch is resolved by overriding the existing type with the type from input_type instead of failing. @returns Success unless there is existing type or shape info that can't be successfully updated. */ - common::Status UpdateTypeAndShape(const ONNX_NAMESPACE::TypeProto& input_type, bool strict, const logging::Logger& logger); + common::Status UpdateTypeAndShape(const ONNX_NAMESPACE::TypeProto& input_type, bool strict, bool override_types, const logging::Logger& logger); /** Validate and merge type [and shape] info from node_arg. @param strict If true, the shape update will fail if there are incompatible values. If false, will be lenient and merge only shape info that can be validly processed. + @param override_types If true, a tensor element type mismatch is resolved by overriding the existing type with the type from node_arg instead of failing. @returns Success unless there is existing type or shape info that can't be successfully updated. */ - common::Status UpdateTypeAndShape(const NodeArg& node_arg, bool strict, const logging::Logger& logger); - + common::Status UpdateTypeAndShape(const NodeArg& node_arg, bool strict, bool override_types, const logging::Logger& logger); + /** Gets this NodeArg as a ValueInfoProto. 
*/ const NodeArgInfo& ToProto() const noexcept { return node_arg_info_; } diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index f5c7427605..d3f0e31ff6 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -208,7 +208,8 @@ void NodeArg::ClearShape() { } } -common::Status NodeArg::UpdateTypeAndShape(const ONNX_NAMESPACE::TypeProto& input_type, bool strict, const logging::Logger& logger) { +common::Status NodeArg::UpdateTypeAndShape(const ONNX_NAMESPACE::TypeProto& input_type, bool strict, bool override_types, + const logging::Logger& logger) { if (!utils::HasType(node_arg_info_)) { *node_arg_info_.mutable_type() = input_type; type_ = DataTypeUtils::ToType(node_arg_info_.type()); @@ -229,10 +230,24 @@ common::Status NodeArg::UpdateTypeAndShape(const ONNX_NAMESPACE::TypeProto& inpu const auto& input_tensor_elem_type = input_tensor_type.elem_type(); const auto& current_tensor_elem_type = current_type.tensor_type().elem_type(); - if (input_tensor_elem_type != current_tensor_elem_type) - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Tensor element type mismatch. ", - static_cast(input_tensor_elem_type), " != ", - static_cast(current_tensor_elem_type)); + if (input_tensor_elem_type != current_tensor_elem_type) { + if (override_types) { + DataType inferred_type = DataTypeUtils::ToType(input_type); + // The "SetType" call will override the shape information to empty. + // If the original tensor has shape information, need to set it back. + if (Shape()) { + auto old_shape = *Shape(); + SetType(inferred_type); + SetShape(old_shape); + } else { + SetType(inferred_type); + } + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Tensor element type mismatch. 
", + static_cast(input_tensor_elem_type), " != ", + static_cast(current_tensor_elem_type)); + } + } if (utils::HasShape(input_tensor_type)) { auto& current_tensor_type = *current_type.mutable_tensor_type(); @@ -249,11 +264,24 @@ common::Status NodeArg::UpdateTypeAndShape(const ONNX_NAMESPACE::TypeProto& inpu const auto& input_tensor_type = input_type.sparse_tensor_type(); const auto input_tensor_elem_type = input_tensor_type.elem_type(); const auto current_tensor_elem_type = current_type.sparse_tensor_type().elem_type(); + if (input_tensor_elem_type != current_tensor_elem_type) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "SparseTensor element type mismatch. ", - static_cast(input_tensor_elem_type), " != ", - static_cast(current_tensor_elem_type)); + if (override_types) { + DataType inferred_type = DataTypeUtils::ToType(input_type); + if (Shape()) { + auto old_shape = *Shape(); + SetType(inferred_type); + SetShape(old_shape); + } else { + SetType(inferred_type); + } + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "SparseTensor element type mismatch. 
", + static_cast(input_tensor_elem_type), " != ", + static_cast(current_tensor_elem_type)); + } } + if (utils::HasShape(input_tensor_type)) { auto& current_tensor_type = *current_type.mutable_sparse_tensor_type(); if (utils::HasShape(current_tensor_type)) { @@ -275,11 +303,12 @@ common::Status NodeArg::UpdateTypeAndShape(const ONNX_NAMESPACE::TypeProto& inpu return Status::OK(); } -common::Status NodeArg::UpdateTypeAndShape(const NodeArg& node_arg, bool strict, const logging::Logger& logger) { +common::Status NodeArg::UpdateTypeAndShape(const NodeArg& node_arg, bool strict, bool override_types, + const logging::Logger& logger) { auto status = Status::OK(); if (utils::HasType(node_arg.node_arg_info_)) - status = UpdateTypeAndShape(node_arg.node_arg_info_.type(), strict, logger); + status = UpdateTypeAndShape(node_arg.node_arg_info_.type(), strict, override_types, logger); return status; } @@ -771,7 +800,7 @@ Graph::Graph(const Model& owning_model, // so we prefer the shape from the initializer name_to_type_map[tensor.name()] = t; if (matching_graph_input != nullptr) { - ORT_THROW_IF_ERROR(matching_graph_input->UpdateTypeAndShape(t, true, logger)); + ORT_THROW_IF_ERROR(matching_graph_input->UpdateTypeAndShape(t, true, false, logger)); } } else { // v4 and later allows a constant initializer with no matching graph input. create a NodeArg for these. @@ -1398,12 +1427,12 @@ bool FullyDefinedType(const TypeProto& type_proto) { // parameters are the Graph instance for the subgraph, the input types from the control flow node that contains // the subgraph, and the vector to write the output from the inferencing. 
using SubgraphInferencingFunc = - std::function&, std::vector&)>; + std::function&, std::vector&, const Graph::ResolveOptions&)>; class GraphInferencerImpl : public ONNX_NAMESPACE::GraphInferencer { public: - GraphInferencerImpl(const Node& node, Graph& graph, SubgraphInferencingFunc& inferencing_func) - : node_(node), graph_(graph), inferencing_func_(inferencing_func) { + GraphInferencerImpl(const Node& node, Graph& graph, SubgraphInferencingFunc& inferencing_func, const Graph::ResolveOptions& options) + : node_(node), graph_(graph), inferencing_func_(inferencing_func), options_(options) { } // Perform inferencing on the graph contained in GraphInferencer. @@ -1413,7 +1442,7 @@ class GraphInferencerImpl : public ONNX_NAMESPACE::GraphInferencer { const std::vector& /*input_data*/) override { std::vector output_types; - auto status = inferencing_func_(node_, graph_, input_types, output_types); + auto status = inferencing_func_(node_, graph_, input_types, output_types, options_); if (status != Status::OK()) { fail_type_inference("Graph attribute inferencing failed: ", status.ErrorMessage()); @@ -1426,6 +1455,7 @@ class GraphInferencerImpl : public ONNX_NAMESPACE::GraphInferencer { const Node& node_; Graph& graph_; SubgraphInferencingFunc& inferencing_func_; + const Graph::ResolveOptions& options_; }; // An implementation of the InferenceContext interface required by operator-specific @@ -1436,10 +1466,12 @@ class InferenceContextImpl : public ONNX_NAMESPACE::InferenceContext { public: InferenceContextImpl(Node& node, SubgraphInferencingFunc subgraph_inferencing_func, - const Graph& graph) noexcept + const Graph& graph, + const Graph::ResolveOptions& options) noexcept : node_(node), subgraph_inferencing_func_(subgraph_inferencing_func), - graph_(graph) { + graph_(graph), + options_(options) { node_output_types_.resize(node.OutputDefs().size()); } @@ -1500,7 +1532,7 @@ class InferenceContextImpl : public ONNX_NAMESPACE::InferenceContext { auto* subgraph = 
node_.GetMutableGraphAttribute(attribute_name); if (subgraph) { - auto inferencer = onnxruntime::make_unique(node_, *subgraph, subgraph_inferencing_func_); + auto inferencer = onnxruntime::make_unique(node_, *subgraph, subgraph_inferencing_func_, options_); graph_inferencer = inferencer.get(); graph_inferencers_.push_back(std::move(inferencer)); } else { @@ -1518,11 +1550,13 @@ class InferenceContextImpl : public ONNX_NAMESPACE::InferenceContext { SubgraphInferencingFunc subgraph_inferencing_func_; std::vector> graph_inferencers_; const Graph& graph_; + const Graph::ResolveOptions& options_; }; Status Graph::InferAndVerifySubgraphTypes(const Node& node, Graph& subgraph, const std::vector& input_types, - std::vector& output_types) { + std::vector& output_types, + const Graph::ResolveOptions& options) { auto status = Status::OK(); output_types.clear(); @@ -1555,7 +1589,7 @@ Status Graph::InferAndVerifySubgraphTypes(const Node& node, Graph& subgraph, const auto& subgraph_input = *subgraph_inputs->at(i); NodeArg* mutable_nodearg = subgraph.GetNodeArg(subgraph_input.Name()); - status = mutable_nodearg->UpdateTypeAndShape(input_type, true, subgraph.logger_); + status = mutable_nodearg->UpdateTypeAndShape(input_type, true, options.override_types, subgraph.logger_); if (!status.IsOK()) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Node:", node.Name(), " ", status.ErrorMessage()); } @@ -1576,7 +1610,7 @@ Status Graph::InferAndVerifySubgraphTypes(const Node& node, Graph& subgraph, if (!subgraph_nodearg) continue; - status = subgraph_nodearg->UpdateTypeAndShape(*implicit_node_arg, true, subgraph.logger_); + status = subgraph_nodearg->UpdateTypeAndShape(*implicit_node_arg, true, options.override_types, subgraph.logger_); if (!status.IsOK()) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Node:", node.Name(), " ", status.ErrorMessage()); } @@ -1588,8 +1622,6 @@ Status Graph::InferAndVerifySubgraphTypes(const Node& node, Graph& subgraph, // now that we have handled the input types, 
do the type/shape inferencing for the subgraph // to flow the type/shape info through it - // TODO: Handle override-type option correctly for subgraphs. - Graph::ResolveOptions options; status = subgraph.PerformTypeAndShapeInferencing(options); ORT_RETURN_IF_ERROR(status); @@ -1695,7 +1727,7 @@ Status Graph::InferAndVerifyTypeMatch(Node& node, const OpSchema& op, const Reso // Once that completes, the outputs from the node containing the subgraph will be updated, and the final values // returned here. SubgraphInferencingFunc func(Graph::InferAndVerifySubgraphTypes); - InferenceContextImpl context(node, func, *this); + InferenceContextImpl context(node, func, *this, options); try { context.RunInferencing(); diff --git a/onnxruntime/test/providers/cpu/controlflow/loop_test.cc b/onnxruntime/test/providers/cpu/controlflow/loop_test.cc index 0c7b53fef9..5eab21c501 100644 --- a/onnxruntime/test/providers/cpu/controlflow/loop_test.cc +++ b/onnxruntime/test/providers/cpu/controlflow/loop_test.cc @@ -574,6 +574,116 @@ TEST(Loop, InfiniteLoopTermination) { terminator_thread.join(); } +// Basic test that triggers the type-override logic in Graph::InferAndVerifySubgraphTypes as well as the +// type/shape inferencing that flows the type/shape info through the subgraph via +// subgraph.PerformTypeAndShapeInferencing(options). +// In this test, the main graph defines the original input/expected output as "double" whereas the subgraph uses "float". +// The expectation is that the types are propagated properly through the subgraph and yield the correct output. +// +// TODO - when the input/output type in the main graph is float16, extra Cast nodes will be added and the input type +// will be changed by InsertCastTransformer for graph execution, which causes a type mismatch failure. +// Need to investigate how InsertCastTransformer works in the future. 
+TEST(Loop, SubgraphTypeOverride) { + auto create_subgraph = [](const RunOptions&) { + Model model("Loop subgraph", false, DefaultLoggingManager().DefaultLogger()); + auto& graph = model.MainGraph(); + + std::vector inputs; + std::vector outputs; + + /* + Inputs: iter_num, cond_in, fake_in, loop carried state variables. + + iter_num_in cond_in fake_in [outer_scope_0] + (unused) | | | + [Identity] [Identity] [Identity] + | | | + cond_out fake_out loop_var_0_out + */ + + // graph inputs types. + TypeProto int64_scalar; + int64_scalar.mutable_tensor_type()->set_elem_type(TensorProto_DataType_INT64); + int64_scalar.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1); + + TypeProto bool_scalar; + bool_scalar.mutable_tensor_type()->set_elem_type(TensorProto_DataType_BOOL); + bool_scalar.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1); + + TypeProto float_tensor; + float_tensor.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT); + float_tensor.mutable_tensor_type()->mutable_shape()->add_dim(); + + // graph inputs + auto& iter_num_in = graph.GetOrCreateNodeArg("iter_num_in", &int64_scalar); + auto& cond_in = graph.GetOrCreateNodeArg("cond_in", &bool_scalar); + auto& fake_in = graph.GetOrCreateNodeArg("fake_in", &float_tensor); + + // outer scope value. need type but not shape. 
+ auto& outer_scope_0 = graph.GetOrCreateNodeArg("outer_scope_0", &float_tensor); + + // add so that we don't end up with it being considered a graph input + graph.AddOuterScopeNodeArg("outer_scope_0"); + + // graph outputs + auto& cond_out = graph.GetOrCreateNodeArg("cond_out", &bool_scalar); + auto& fake_out = graph.GetOrCreateNodeArg("fake_out", &float_tensor); + auto& loop_var_0_out = graph.GetOrCreateNodeArg("loop_var_0_out", &float_tensor); + + // cond_in -> cond_out + { + inputs = {&cond_in}; + outputs = {&cond_out}; + + graph.AddNode("cond_in_identity", "Identity", "Forward cond_in to cond_out", inputs, outputs); + } + + // fake_in -> fake_out + { + inputs = {&fake_in}; + outputs = {&fake_out}; + + graph.AddNode("fake_in_identity", "Identity", "Forward fake_in to fake_out", inputs, outputs); + } + + // outer_scope_0 -> loop_var_0_out + { + inputs = {&outer_scope_0}; + outputs = {&loop_var_0_out}; + + graph.AddNode("loop_var_out", "Identity", "Forward outer_scope_0 to loop_var_0_out", inputs, outputs); + } + + graph.SetInputs({&iter_num_in, &cond_in, &fake_in}); + graph.SetOutputs({&cond_out, &fake_out, &loop_var_0_out}); + + auto status = graph.Resolve(); + EXPECT_EQ(status, Status::OK()); + + return graph.ToGraphProto(); + }; + + LoopOpTester test{{}, create_subgraph}; + + test.AddInput("M", {1}, {1}); + test.AddInput("cond", {1}, {true}); + test.AddInput("fake", {1}, {0.f}); + test.AddInput("outer_scope_0", {1}, {kOuterNodeAddValue}); + + test.AddOutput("loop_fake_final", {1}, {0.f}); + test.AddOutput("loop_var_0_final", {1, 1}, {kOuterNodeAddValue}); + test.AddOutput("outer_scope_0_out", {1}, {int64_t(kOuterNodeAddValue)}); + + OrtRunOptions session_run_options; + session_run_options.run_tag = "Loop.SubgraphTypeOverride"; + + Graph::ResolveOptions options; + options.override_types = true; + test.Run(OpTester::ExpectResult::kExpectSuccess, "", + {kTensorrtExecutionProvider}, &session_run_options, nullptr, + ExecutionMode::ORT_SEQUENTIAL, {}, options); +} 
+ // Regression test that a subgraph input overrides an outer scope value of the same name. // Replicate issue from https://github.com/onnx/onnx/issues/2082 TEST(Loop, SubgraphInputShadowsOuterScopeValue) { diff --git a/onnxruntime/test/providers/provider_test_utils.cc b/onnxruntime/test/providers/provider_test_utils.cc index dc23707d5b..a833816f7e 100644 --- a/onnxruntime/test/providers/provider_test_utils.cc +++ b/onnxruntime/test/providers/provider_test_utils.cc @@ -626,14 +626,15 @@ void OpTester::Run( const RunOptions* run_options, std::vector>* execution_providers, ExecutionMode execution_mode, - const CustomOutputVerifierFn& custom_output_verifier) { + const CustomOutputVerifierFn& custom_output_verifier, + const Graph::ResolveOptions& options) { SessionOptions so; so.session_logid = op_; so.session_log_verbosity_level = 1; so.execution_mode = execution_mode; so.graph_optimization_level = TransformerLevel::Default; // 'Default' == off Run(so, expect_result, expected_failure_string, excluded_provider_types, - run_options, execution_providers, custom_output_verifier); + run_options, execution_providers, custom_output_verifier, options); } void OpTester::Run( @@ -643,7 +644,8 @@ void OpTester::Run( const std::unordered_set& excluded_provider_types, const RunOptions* run_options, std::vector>* execution_providers, - const CustomOutputVerifierFn& custom_output_verifier) { + const CustomOutputVerifierFn& custom_output_verifier, + const Graph::ResolveOptions& options) { std::string cur_provider = "not set"; try { #ifndef NDEBUG @@ -660,12 +662,12 @@ void OpTester::Run( expect_result == ExpectResult::kExpectFailure) { // capture possible exceptions from shape inference for invalid testcase try { - status = graph.Resolve(); + status = graph.Resolve(options); } catch (const std::exception& ex) { status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, ex.what()); } } else { - status = graph.Resolve(); + status = graph.Resolve(options); } if (!status.IsOK()) { diff --git 
a/onnxruntime/test/providers/provider_test_utils.h b/onnxruntime/test/providers/provider_test_utils.h index d94aeed7c5..0f393a8b26 100644 --- a/onnxruntime/test/providers/provider_test_utils.h +++ b/onnxruntime/test/providers/provider_test_utils.h @@ -413,7 +413,8 @@ class OpTester { const RunOptions* run_options = nullptr, std::vector>* execution_providers = nullptr, ExecutionMode execution_mode = ExecutionMode::ORT_SEQUENTIAL, - const CustomOutputVerifierFn& custom_output_verifier = {}); + const CustomOutputVerifierFn& custom_output_verifier = {}, + const Graph::ResolveOptions& resolve_options = {}); void Run(SessionOptions session_options, ExpectResult expect_result = ExpectResult::kExpectSuccess, @@ -421,7 +422,8 @@ class OpTester { const std::unordered_set& excluded_provider_types = {}, const RunOptions* run_options = nullptr, std::vector>* execution_providers = nullptr, - const CustomOutputVerifierFn& custom_output_verifier = {}); + const CustomOutputVerifierFn& custom_output_verifier = {}, + const Graph::ResolveOptions& resolve_options = {}); std::vector GetFetches() { return fetches_; }