diff --git a/cmake/external/abseil-cpp.natvis b/cmake/external/abseil-cpp.natvis index 708d6ba187..1e5a36fb9e 100644 --- a/cmake/external/abseil-cpp.natvis +++ b/cmake/external/abseil-cpp.natvis @@ -30,7 +30,6 @@ - empty size={ _size() } size=({_size()}) diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 3763e0758c..d489a59c4b 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -4062,7 +4062,9 @@ static void ReassignSubgraphDependentNodeArgs(const InlinedHashMapExists()) { auto hit = name_to_nodearg.find(input_def->Name()); if (hit != name_to_nodearg.cend()) { - input_def = hit->second; + // Make sure we create a local to this subgraph definition + const auto* new_name_arg = hit->second; + input_def = &graph.GetOrCreateNodeArg(new_name_arg->Name(), input_def->TypeAsProto()); } } } @@ -4088,7 +4090,7 @@ Status Graph::InlineIfSubgraph(bool condition_value, Node& if_node, const loggin Graph& graph_to_inline = *sub_graph; - std::string unique_id{if_node.Name()}; + std::string unique_id{"_if_"}; if (condition_value) { unique_id.append(then_branch); } else { @@ -4107,7 +4109,7 @@ Status Graph::InlineIfSubgraph(bool condition_value, Node& if_node, const loggin // Reason: there are no explicit inputs to the subgraphs, and the subgraph's // implicit inputs must be covered by the implicit inputs of the If node. InlinedHashMap outer_scope_values; - const auto if_implicit_inputs = if_node.MutableImplicitInputDefs(); + const auto& if_implicit_inputs = if_node.MutableImplicitInputDefs(); outer_scope_values.reserve(if_implicit_inputs.size()); for (auto* input : if_implicit_inputs) { @@ -4121,8 +4123,8 @@ Status Graph::InlineIfSubgraph(bool condition_value, Node& if_node, const loggin // We are going to map the outputs of the graph to inline to the outputs of the If node. // They are assumed to be in the same order. - const auto node_output_defs = if_node.MutableOutputDefs(); - const auto graph_output_defs = graph_to_inline.GetOutputs(); + const auto& node_output_defs = if_node.MutableOutputDefs(); + const auto& graph_output_defs = graph_to_inline.GetOutputs(); for (size_t i = 0; i < graph_output_defs.size(); ++i) { name_to_nodearg.emplace(graph_output_defs[i]->Name(), node_output_defs[i]); } @@ -4206,6 +4208,7 @@ Status Graph::InlineIfSubgraph(bool condition_value, Node& if_node, const loggin } } + auto* non_existing_arg = &GetOrCreateNodeArg(std::string(), nullptr); // We want to make sure we get nodes in topological order // because Constant folding may cause the nodes appear in // a different order. @@ -4216,68 +4219,94 @@ Status Graph::InlineIfSubgraph(bool condition_value, Node& if_node, const loggin auto* node = graph_to_inline.GetNode(node_idx); assert(node->OpType() != kConstant); - InlinedVector new_node_input_defs; - for (const auto* input_def : node->InputDefs()) { + // Inputs + // Chop off trailing non-existing defs, but preserve non-existing in the middle + auto& input_defs = node->MutableInputDefs(); + auto last_existing = std::find_if(input_defs.rbegin(), input_defs.rend(), + [](const NodeArg* node_arg) { return node_arg->Exists(); }); + input_defs.resize(std::distance(input_defs.begin(), last_existing.base())); + + InlinedVector new_input_defs; + for (auto* input_def : node->InputDefs()) { if (input_def->Exists()) { // Check if this is one of the implicit graph inputs - // then leave the name as is and re-use the NodeArg + // then re-assign the def to the outer scope value. const auto& input_name = input_def->Name(); auto outer_hit = outer_scope_values.find(input_name); if (outer_hit != outer_scope_values.cend()) { - new_node_input_defs.push_back(outer_hit->second); + // get/create local definition + NodeArg* outer_arg = outer_hit->second; + auto& this_scope_arg = GetOrCreateNodeArg(outer_arg->Name(), input_def->TypeAsProto()); + new_input_defs.push_back(&this_scope_arg); } else { auto hit = name_to_nodearg.find(input_name); if (hit != name_to_nodearg.cend()) { - // This is other node output, constant node or initializer that was renamed. - new_node_input_defs.push_back(hit->second); + // This is other node output in the dest graph, + // constant node or initializer that was renamed. + new_input_defs.push_back(hit->second); } else { ORT_THROW("Node's: ", node->Name(), " input: ", input_name, " is not If node's input or previous node output in this subgraph"); } } + } else { + new_input_defs.push_back(non_existing_arg); } } - InlinedVector new_node_output_defs; - for (const auto* output_def : node->OutputDefs()) { - const auto& output_name = output_def->Name(); - auto hit = name_to_nodearg.find(output_name); - if (hit != name_to_nodearg.cend()) { - // This is one of the graph outputs, we rename it to - // If node output. - new_node_output_defs.push_back(hit->second); + // Outputs + // Chop off trailing non-existing defs + auto& output_defs = node->MutableOutputDefs(); + last_existing = std::find_if(output_defs.rbegin(), output_defs.rend(), + [](const NodeArg* node_arg) { return node_arg->Exists(); }); + output_defs.resize(std::distance(output_defs.begin(), last_existing.base())); + + InlinedVector new_output_defs; + for (auto* output_def : node->OutputDefs()) { + if (output_def->Exists()) { + const auto& output_name = output_def->Name(); + auto hit = name_to_nodearg.find(output_name); + if (hit != name_to_nodearg.cend()) { + // This is one of the If node outputs, simply reassign the def. + // If node defs are already in the destination graph + new_output_defs.push_back(hit->second); + } else { + // We generate an output to downstream nodes. + auto new_name = GenerateNodeArgName(make_unique(output_name)); + NodeArg& new_arg = GetOrCreateNodeArg(new_name, output_def->TypeAsProto()); + new_output_defs.push_back(&new_arg); + ORT_IGNORE_RETURN_VALUE(name_to_nodearg.emplace(output_name, &new_arg)); + } } else { - // We generate an output to downstream nodes. - auto new_name = GenerateNodeArgName(make_unique(output_name)); - NodeArg& new_arg = GetOrCreateNodeArg(new_name, output_def->TypeAsProto()); - new_node_output_defs.push_back(&new_arg); - ORT_IGNORE_RETURN_VALUE(name_to_nodearg.emplace(output_name, &new_arg)); + new_output_defs.push_back(non_existing_arg); } } const auto new_node_name = GenerateNodeName(make_unique(node->OpType())); Node& new_node = AddNode(new_node_name, node->OpType(), node->Description(), - new_node_input_defs, - new_node_output_defs, + new_input_defs, + new_output_defs, nullptr, node->Domain()); + new_node.SetSinceVersion(node->SinceVersion()); + new_node.op_ = node->op_; + if (!is_this_main_graph) { map_defs(new_node, input_args, true); map_defs(new_node, output_args, false); new_nodes.push_back(&new_node); } - new_node.SetSinceVersion(node->SinceVersion()); - new_node.op_ = node->op_; - if (node->ContainsSubgraph()) { auto& subgraphs = node->MutableSubgraphs(); // Check if any of this node implicit inputs of this graph is in the renaming map + // that would mean they come from the destination graph, not from the parent + // of the destination graph. int renames_subgraph_names = 0; - auto& new_implicit_defs = node->MutableImplicitInputDefs(); - for (auto& input_def : new_implicit_defs) { + auto& implicit_defs = node->MutableImplicitInputDefs(); + for (auto& input_def : implicit_defs) { auto hit = name_to_nodearg.find(input_def->Name()); if (hit != name_to_nodearg.cend()) { input_def = hit->second; @@ -4298,7 +4327,7 @@ Status Graph::InlineIfSubgraph(bool condition_value, Node& if_node, const loggin new_node.MutableSubgraphs() = std::move(subgraphs); new_node.GetMutableMapOfAttributeNameToSubgraph() = std::move(node->GetMutableMapOfAttributeNameToSubgraph()); - new_node.MutableImplicitInputDefs() = std::move(new_implicit_defs); + new_node.MutableImplicitInputDefs() = std::move(implicit_defs); } new_node.GetMutableAttributes() = std::move(node->GetMutableAttributes()); diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 17b26ed7ca..ef6e2d531b 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -1176,6 +1176,162 @@ TEST_F(GraphTransformationTests, ConstantFoldingIfConstantInliningRebuildEdges) ASSERT_EQ(op_to_count["Cast"], 2); } +TEST_F(GraphTransformationTests, ConstantFoldingIfConstantInliningEdgesWithMiddleArgNonExisting) { + // This model has a Resize() call with a middle argument non-existing. + // We want to make sure that the input edges for that Resize() node + // are properly rebuilt with a middle argument non-existing + // during If constant folding + // This test is only valid if Resize() node resides in the nested subgraph which gets inlined + // however, the destination graph must not be the main graph. Then we test that the edges are rebuild + // properly. Also Resize() should not be the first node in the resulting subgraph, so it has edges + const char* code = R"( + < + ir_version: 8, + opset_import: [ "" : 16, "local" : 1 ] + > + agraph (float[128] x, float[128] x1) => (float[N] y) + { + y = local.aten_gather (x, x1) + } + < + opset_import: [ "" : 16, "local" : 1], + domain: "local" + > + aten_gather (self, index) => (result_16) + { + resize_scales = Constant () + tmp_0 = Size (index) + int64_0 = Constant () + int64_0_cast = CastLike (int64_0, tmp_0) + cond = Equal (tmp_0, int64_0_cast) + result_16 = If (cond) ( result) { + result = Identity (self) + }, else_branch: graph = elseGraph_10 () => ( result_15) { + tmp_1 = Shape (self) + tmp_2 = Size (tmp_1) + int64_0_3 = Constant () + int64_0_3_cast = CastLike (int64_0_3, tmp_2) + cond_4 = Equal (tmp_2, int64_0_3_cast) + self_8 = If (cond_4) ( self_6) { + tmp_5 = Constant () + self_6 = Reshape (self, tmp_5) + }, else_branch: graph = elseGraph_13 () => ( self_7) { + self_71 = Mul(self, self) + float_size = CastLike (tmp_0, resize_scales) + non_constant_resize_scales = Mul(float_size, resize_scales) + self_7 = Resize(self_71,, non_constant_resize_scales) + }> + tmp_9 = Size (index) + int64_0_10 = Constant () + int64_0_10_cast = CastLike (int64_0_10, tmp_9) + cond_11 = Equal (tmp_9, int64_0_10_cast) + result_15 = If (cond_11) ( result_12) { + result_12 = CastLike (index, self_8) + }, else_branch: graph = elseGraph_15 () => ( result_14) { + index_13 = Cast (index) + result_14 = GatherElements (self_8, index_13) + }> + }> + } + )"; + + /** Optimized model graph + < + ir_version: 8, + opset_import: ["" : 16, + "local" : 1, + "com.microsoft.nchwc" : 1, + "ai.onnx.ml" : 4, + "ai.onnx.training" : 1, + "ai.onnx.preview.training" : 1, + "com.microsoft" : 1, + "com.microsoft.experimental" : 1, "org.pytorch.aten" : 1] + > + agraph (float[128] x, float[128] x1) => (float[128] y) + + { + _inlfunc_aten_gather_tmp_0 = Size (x1) + _inlfunc_aten_gather_cond = Equal (_inlfunc_aten_gather_tmp_0, ortshared_7_0_1_0_token_8) + y = If (_inlfunc_aten_gather_cond) + (float[128] _inlfunc_aten_gather_result) { + _inlfunc_aten_gather_result = Identity (x) + }, else_branch: graph = elseGraph_10 () => (float[128] _inlfunc_aten_gather_result_15) + + { + _if_else_branch__inlfunc_aten_gather_self_71 = Mul (x, x) + _if_else_branch__inlfunc_aten_gather_float_size = Cast (_inlfunc_aten_gather_tmp_0) + _if_else_branch__inlfunc_aten_gather_non_constant_resize_scales = Mul ( + _if_else_branch__inlfunc_aten_gather_float_size, _inlfunc_aten_gather_resize_scales) + _inlfunc_aten_gather_self_8 = Resize ( + _if_else_branch__inlfunc_aten_gather_self_71, , + _if_else_branch__inlfunc_aten_gather_non_constant_resize_scales) + _inlfunc_aten_gather_tmp_9 = Size (x1) + _inlfunc_aten_gather_cond_11 = Equal (_inlfunc_aten_gather_tmp_9, _inlfunc_aten_gather_int64_0_10) + _inlfunc_aten_gather_result_15 = If (_inlfunc_aten_gather_cond_11) + (float[128] _inlfunc_aten_gather_result_12) { + _inlfunc_aten_gather_result_12 = Cast (x1) + }, else_branch: graph = elseGraph_15 () => (float[128] _inlfunc_aten_gather_result_14) { + _inlfunc_aten_gather_index_13 = Cast (x1) + _inlfunc_aten_gather_result_14 = GatherElements ( + _inlfunc_aten_gather_self_8, _inlfunc_aten_gather_index_13) + }> + }> + } + + */ + + ONNX_NAMESPACE::OnnxParser parser(code); + ONNX_NAMESPACE::ModelProto model_proto; + auto parse_status = parser.Parse(model_proto); + ASSERT_TRUE(parse_status.IsOK()) << parse_status.ErrorMessage(); + ASSERT_TRUE(parser.EndOfInput()) << "Extra unparsed input unexpected."; + + std::string serialized_model; + const bool serialization_status = model_proto.SerializeToString(&serialized_model); + ASSERT_TRUE(serialization_status) << "Failed to serialize proto to string"; + + // AOT inlining is necessary in this case, so the If nodes within the function + // are brought out to the outer scope. So we load this into a session object. + SessionOptions session_options; + InferenceSessionWrapper session_object{session_options, GetEnvironment()}; + std::stringstream sstr(serialized_model); + ASSERT_STATUS_OK(session_object.Load(sstr)); + ASSERT_STATUS_OK(session_object.Initialize()); + + // Let's verify the correctness of the rebuild edges in the Resize node that still + // resides within an if else subgraph. + auto& graph = session_object.GetModel().MainGraph(); + auto op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count["If"], 2); + ASSERT_EQ(op_to_count["Resize"], 1); + + auto if_node = std::find_if(graph.Nodes().begin(), graph.Nodes().end(), + [](const auto& node) { return node.OpType() == "If"; }); + ASSERT_NE(graph.Nodes().cend(), if_node); + // Resize is in the else branch + auto subgraph_map = if_node->GetAttributeNameToSubgraphMap(); + auto branch = subgraph_map.find("else_branch"); + ASSERT_NE(subgraph_map.cend(), branch); + + auto resize_node = std::find_if(branch->second->Nodes().begin(), branch->second->Nodes().end(), + [](const auto& node) { return node.OpType() == "Resize"; }); + ASSERT_NE(branch->second->Nodes().cend(), resize_node); + + // Check the edges + ASSERT_EQ(2U, resize_node->GetInputEdgesCount()); + // Should have input edges with arg_pos 0 and 2 + // With 1 is missing + InlinedHashSet dest_edges; + auto zero_edge = resize_node->InputEdgesBegin(); + dest_edges.insert(zero_edge->GetDstArgIndex()); + ++zero_edge; + dest_edges.insert(zero_edge->GetDstArgIndex()); + ASSERT_TRUE(dest_edges.find(0) != dest_edges.end()); + ASSERT_TRUE(dest_edges.find(2) != dest_edges.end()); +} + // Check transformations in the case of a subgraph with constant inputs. TEST_F(GraphTransformationTests, SubgraphWithConstantInputs) { constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "constant-subgraph.onnx";