From cc542024ce3bd94dfaaabd6100c281cfc4bd2595 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Mon, 20 Nov 2023 14:49:09 -0800 Subject: [PATCH] Create edges with arg positons correctly accounting for non-existing args (#18462) ### Description Truncate traling non-existing arguments. Make sure we do not skip on the non-existing arguments in the middle, because shape inferece relies on their proper position. This also affects the argument position in the Edges that must be properly rebuilt each time If node branch is inlined. Make sure that when we rename Defs in subgraphs, new renamed defs are created in those subgraphs instead of pointing to outer scope defs. Add unit test. ### Motivation and Context This is a follow up for https://github.com/microsoft/onnxruntime/pull/18105 Currently, the non-trailing arguments are simply ignored and the edges are created with potentially incorrect positions. --- cmake/external/abseil-cpp.natvis | 1 - onnxruntime/core/graph/graph.cc | 93 +++++++---- .../test/optimizer/graph_transform_test.cc | 156 ++++++++++++++++++ 3 files changed, 217 insertions(+), 33 deletions(-) diff --git a/cmake/external/abseil-cpp.natvis b/cmake/external/abseil-cpp.natvis index 708d6ba187..1e5a36fb9e 100644 --- a/cmake/external/abseil-cpp.natvis +++ b/cmake/external/abseil-cpp.natvis @@ -30,7 +30,6 @@ - empty size={ _size() } size=({_size()}) diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 3763e0758c..d489a59c4b 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -4062,7 +4062,9 @@ static void ReassignSubgraphDependentNodeArgs(const InlinedHashMapExists()) { auto hit = name_to_nodearg.find(input_def->Name()); if (hit != name_to_nodearg.cend()) { - input_def = hit->second; + // Make sure we create a local to this subgraph definition + const auto* new_name_arg = hit->second; + input_def = &graph.GetOrCreateNodeArg(new_name_arg->Name(), input_def->TypeAsProto()); } } } @@ -4088,7 +4090,7 @@ Status Graph::InlineIfSubgraph(bool condition_value, Node& if_node, const loggin Graph& graph_to_inline = *sub_graph; - std::string unique_id{if_node.Name()}; + std::string unique_id{"_if_"}; if (condition_value) { unique_id.append(then_branch); } else { @@ -4107,7 +4109,7 @@ Status Graph::InlineIfSubgraph(bool condition_value, Node& if_node, const loggin // Reason: there are no explicit inputs to the subgraphs, and the subgraph's // implicit inputs must be covered by the implicit inputs of the If node. InlinedHashMap outer_scope_values; - const auto if_implicit_inputs = if_node.MutableImplicitInputDefs(); + const auto& if_implicit_inputs = if_node.MutableImplicitInputDefs(); outer_scope_values.reserve(if_implicit_inputs.size()); for (auto* input : if_implicit_inputs) { @@ -4121,8 +4123,8 @@ Status Graph::InlineIfSubgraph(bool condition_value, Node& if_node, const loggin // We are going to map the outputs of the graph to inline to the outputs of the If node. // They are assumed to be in the same order. - const auto node_output_defs = if_node.MutableOutputDefs(); - const auto graph_output_defs = graph_to_inline.GetOutputs(); + const auto& node_output_defs = if_node.MutableOutputDefs(); + const auto& graph_output_defs = graph_to_inline.GetOutputs(); for (size_t i = 0; i < graph_output_defs.size(); ++i) { name_to_nodearg.emplace(graph_output_defs[i]->Name(), node_output_defs[i]); } @@ -4206,6 +4208,7 @@ Status Graph::InlineIfSubgraph(bool condition_value, Node& if_node, const loggin } } + auto* non_existing_arg = &GetOrCreateNodeArg(std::string(), nullptr); // We want to make sure we get nodes in topological order // because Constant folding may cause the nodes appear in // a different order. @@ -4216,68 +4219,94 @@ Status Graph::InlineIfSubgraph(bool condition_value, Node& if_node, const loggin auto* node = graph_to_inline.GetNode(node_idx); assert(node->OpType() != kConstant); - InlinedVector new_node_input_defs; - for (const auto* input_def : node->InputDefs()) { + // Inputs + // Chop off trailing non-existing defs, but preserve non-existing in the middle + auto& input_defs = node->MutableInputDefs(); + auto last_existing = std::find_if(input_defs.rbegin(), input_defs.rend(), + [](const NodeArg* node_arg) { return node_arg->Exists(); }); + input_defs.resize(std::distance(input_defs.begin(), last_existing.base())); + + InlinedVector new_input_defs; + for (auto* input_def : node->InputDefs()) { if (input_def->Exists()) { // Check if this is one of the implicit graph inputs - // then leave the name as is and re-use the NodeArg + // then re-assign the def to the outer scope value. const auto& input_name = input_def->Name(); auto outer_hit = outer_scope_values.find(input_name); if (outer_hit != outer_scope_values.cend()) { - new_node_input_defs.push_back(outer_hit->second); + // get/create local definition + NodeArg* outer_arg = outer_hit->second; + auto& this_scope_arg = GetOrCreateNodeArg(outer_arg->Name(), input_def->TypeAsProto()); + new_input_defs.push_back(&this_scope_arg); } else { auto hit = name_to_nodearg.find(input_name); if (hit != name_to_nodearg.cend()) { - // This is other node output, constant node or initializer that was renamed. - new_node_input_defs.push_back(hit->second); + // This is other node output in the dest graph, + // constant node or initializer that was renamed. + new_input_defs.push_back(hit->second); } else { ORT_THROW("Node's: ", node->Name(), " input: ", input_name, " is not If node's input or previous node output in this subgraph"); } } + } else { + new_input_defs.push_back(non_existing_arg); } } - InlinedVector new_node_output_defs; - for (const auto* output_def : node->OutputDefs()) { - const auto& output_name = output_def->Name(); - auto hit = name_to_nodearg.find(output_name); - if (hit != name_to_nodearg.cend()) { - // This is one of the graph outputs, we rename it to - // If node output. - new_node_output_defs.push_back(hit->second); + // Outputs + // Chop off trailing non-existing defs + auto& output_defs = node->MutableOutputDefs(); + last_existing = std::find_if(output_defs.rbegin(), output_defs.rend(), + [](const NodeArg* node_arg) { return node_arg->Exists(); }); + output_defs.resize(std::distance(output_defs.begin(), last_existing.base())); + + InlinedVector new_output_defs; + for (auto* output_def : node->OutputDefs()) { + if (output_def->Exists()) { + const auto& output_name = output_def->Name(); + auto hit = name_to_nodearg.find(output_name); + if (hit != name_to_nodearg.cend()) { + // This is one of the If node outputs, simply reassign the def. + // If node defs are already in the destination graph + new_output_defs.push_back(hit->second); + } else { + // We generate an output to downstream nodes. + auto new_name = GenerateNodeArgName(make_unique(output_name)); + NodeArg& new_arg = GetOrCreateNodeArg(new_name, output_def->TypeAsProto()); + new_output_defs.push_back(&new_arg); + ORT_IGNORE_RETURN_VALUE(name_to_nodearg.emplace(output_name, &new_arg)); + } } else { - // We generate an output to downstream nodes. - auto new_name = GenerateNodeArgName(make_unique(output_name)); - NodeArg& new_arg = GetOrCreateNodeArg(new_name, output_def->TypeAsProto()); - new_node_output_defs.push_back(&new_arg); - ORT_IGNORE_RETURN_VALUE(name_to_nodearg.emplace(output_name, &new_arg)); + new_output_defs.push_back(non_existing_arg); } } const auto new_node_name = GenerateNodeName(make_unique(node->OpType())); Node& new_node = AddNode(new_node_name, node->OpType(), node->Description(), - new_node_input_defs, - new_node_output_defs, + new_input_defs, + new_output_defs, nullptr, node->Domain()); + new_node.SetSinceVersion(node->SinceVersion()); + new_node.op_ = node->op_; + if (!is_this_main_graph) { map_defs(new_node, input_args, true); map_defs(new_node, output_args, false); new_nodes.push_back(&new_node); } - new_node.SetSinceVersion(node->SinceVersion()); - new_node.op_ = node->op_; - if (node->ContainsSubgraph()) { auto& subgraphs = node->MutableSubgraphs(); // Check if any of this node implicit inputs of this graph is in the renaming map + // that would mean they come from the destination graph, not from the parent + // of the destination graph. int renames_subgraph_names = 0; - auto& new_implicit_defs = node->MutableImplicitInputDefs(); - for (auto& input_def : new_implicit_defs) { + auto& implicit_defs = node->MutableImplicitInputDefs(); + for (auto& input_def : implicit_defs) { auto hit = name_to_nodearg.find(input_def->Name()); if (hit != name_to_nodearg.cend()) { input_def = hit->second; @@ -4298,7 +4327,7 @@ Status Graph::InlineIfSubgraph(bool condition_value, Node& if_node, const loggin new_node.MutableSubgraphs() = std::move(subgraphs); new_node.GetMutableMapOfAttributeNameToSubgraph() = std::move(node->GetMutableMapOfAttributeNameToSubgraph()); - new_node.MutableImplicitInputDefs() = std::move(new_implicit_defs); + new_node.MutableImplicitInputDefs() = std::move(implicit_defs); } new_node.GetMutableAttributes() = std::move(node->GetMutableAttributes()); diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 17b26ed7ca..ef6e2d531b 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -1176,6 +1176,162 @@ TEST_F(GraphTransformationTests, ConstantFoldingIfConstantInliningRebuildEdges) ASSERT_EQ(op_to_count["Cast"], 2); } +TEST_F(GraphTransformationTests, ConstantFoldingIfConstantInliningEdgesWithMiddleArgNonExisting) { + // This model has a Resize() call with a middle argument non-existing. + // We want to make sure that the input edges for that Resize() node + // are properly rebuilt with a middle argument non-existing + // during If constant folding + // This test is only valid if Resize() node resides in the nested subgraph which gets inlined + // however, the destination graph must not be the main graph. Then we test that the edges are rebuild + // properly. Also Resize() should not be the first node in the resulting subgraph, so it has edges + const char* code = R"( + < + ir_version: 8, + opset_import: [ "" : 16, "local" : 1 ] + > + agraph (float[128] x, float[128] x1) => (float[N] y) + { + y = local.aten_gather (x, x1) + } + < + opset_import: [ "" : 16, "local" : 1], + domain: "local" + > + aten_gather (self, index) => (result_16) + { + resize_scales = Constant () + tmp_0 = Size (index) + int64_0 = Constant () + int64_0_cast = CastLike (int64_0, tmp_0) + cond = Equal (tmp_0, int64_0_cast) + result_16 = If (cond) ( result) { + result = Identity (self) + }, else_branch: graph = elseGraph_10 () => ( result_15) { + tmp_1 = Shape (self) + tmp_2 = Size (tmp_1) + int64_0_3 = Constant () + int64_0_3_cast = CastLike (int64_0_3, tmp_2) + cond_4 = Equal (tmp_2, int64_0_3_cast) + self_8 = If (cond_4) ( self_6) { + tmp_5 = Constant () + self_6 = Reshape (self, tmp_5) + }, else_branch: graph = elseGraph_13 () => ( self_7) { + self_71 = Mul(self, self) + float_size = CastLike (tmp_0, resize_scales) + non_constant_resize_scales = Mul(float_size, resize_scales) + self_7 = Resize(self_71,, non_constant_resize_scales) + }> + tmp_9 = Size (index) + int64_0_10 = Constant () + int64_0_10_cast = CastLike (int64_0_10, tmp_9) + cond_11 = Equal (tmp_9, int64_0_10_cast) + result_15 = If (cond_11) ( result_12) { + result_12 = CastLike (index, self_8) + }, else_branch: graph = elseGraph_15 () => ( result_14) { + index_13 = Cast (index) + result_14 = GatherElements (self_8, index_13) + }> + }> + } + )"; + + /** Optimized model graph + < + ir_version: 8, + opset_import: ["" : 16, + "local" : 1, + "com.microsoft.nchwc" : 1, + "ai.onnx.ml" : 4, + "ai.onnx.training" : 1, + "ai.onnx.preview.training" : 1, + "com.microsoft" : 1, + "com.microsoft.experimental" : 1, "org.pytorch.aten" : 1] + > + agraph (float[128] x, float[128] x1) => (float[128] y) + + { + _inlfunc_aten_gather_tmp_0 = Size (x1) + _inlfunc_aten_gather_cond = Equal (_inlfunc_aten_gather_tmp_0, ortshared_7_0_1_0_token_8) + y = If (_inlfunc_aten_gather_cond) + (float[128] _inlfunc_aten_gather_result) { + _inlfunc_aten_gather_result = Identity (x) + }, else_branch: graph = elseGraph_10 () => (float[128] _inlfunc_aten_gather_result_15) + + { + _if_else_branch__inlfunc_aten_gather_self_71 = Mul (x, x) + _if_else_branch__inlfunc_aten_gather_float_size = Cast (_inlfunc_aten_gather_tmp_0) + _if_else_branch__inlfunc_aten_gather_non_constant_resize_scales = Mul ( + _if_else_branch__inlfunc_aten_gather_float_size, _inlfunc_aten_gather_resize_scales) + _inlfunc_aten_gather_self_8 = Resize ( + _if_else_branch__inlfunc_aten_gather_self_71, , + _if_else_branch__inlfunc_aten_gather_non_constant_resize_scales) + _inlfunc_aten_gather_tmp_9 = Size (x1) + _inlfunc_aten_gather_cond_11 = Equal (_inlfunc_aten_gather_tmp_9, _inlfunc_aten_gather_int64_0_10) + _inlfunc_aten_gather_result_15 = If (_inlfunc_aten_gather_cond_11) + (float[128] _inlfunc_aten_gather_result_12) { + _inlfunc_aten_gather_result_12 = Cast (x1) + }, else_branch: graph = elseGraph_15 () => (float[128] _inlfunc_aten_gather_result_14) { + _inlfunc_aten_gather_index_13 = Cast (x1) + _inlfunc_aten_gather_result_14 = GatherElements ( + _inlfunc_aten_gather_self_8, _inlfunc_aten_gather_index_13) + }> + }> + } + + */ + + ONNX_NAMESPACE::OnnxParser parser(code); + ONNX_NAMESPACE::ModelProto model_proto; + auto parse_status = parser.Parse(model_proto); + ASSERT_TRUE(parse_status.IsOK()) << parse_status.ErrorMessage(); + ASSERT_TRUE(parser.EndOfInput()) << "Extra unparsed input unexpected."; + + std::string serialized_model; + const bool serialization_status = model_proto.SerializeToString(&serialized_model); + ASSERT_TRUE(serialization_status) << "Failed to serialize proto to string"; + + // AOT inlining is necessary in this case, so the If nodes within the function + // are brought out to the outer scope. So we load this into a session object. + SessionOptions session_options; + InferenceSessionWrapper session_object{session_options, GetEnvironment()}; + std::stringstream sstr(serialized_model); + ASSERT_STATUS_OK(session_object.Load(sstr)); + ASSERT_STATUS_OK(session_object.Initialize()); + + // Let's verify the correctness of the rebuild edges in the Resize node that still + // resides within an if else subgraph. + auto& graph = session_object.GetModel().MainGraph(); + auto op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count["If"], 2); + ASSERT_EQ(op_to_count["Resize"], 1); + + auto if_node = std::find_if(graph.Nodes().begin(), graph.Nodes().end(), + [](const auto& node) { return node.OpType() == "If"; }); + ASSERT_NE(graph.Nodes().cend(), if_node); + // Resize is in the else branch + auto subgraph_map = if_node->GetAttributeNameToSubgraphMap(); + auto branch = subgraph_map.find("else_branch"); + ASSERT_NE(subgraph_map.cend(), branch); + + auto resize_node = std::find_if(branch->second->Nodes().begin(), branch->second->Nodes().end(), + [](const auto& node) { return node.OpType() == "Resize"; }); + ASSERT_NE(branch->second->Nodes().cend(), resize_node); + + // Check the edges + ASSERT_EQ(2U, resize_node->GetInputEdgesCount()); + // Should have input edges with arg_pos 0 and 2 + // With 1 is missing + InlinedHashSet dest_edges; + auto zero_edge = resize_node->InputEdgesBegin(); + dest_edges.insert(zero_edge->GetDstArgIndex()); + ++zero_edge; + dest_edges.insert(zero_edge->GetDstArgIndex()); + ASSERT_TRUE(dest_edges.find(0) != dest_edges.end()); + ASSERT_TRUE(dest_edges.find(2) != dest_edges.end()); +} + // Check transformations in the case of a subgraph with constant inputs. TEST_F(GraphTransformationTests, SubgraphWithConstantInputs) { constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "constant-subgraph.onnx";