diff --git a/cmake/external/abseil-cpp.natvis b/cmake/external/abseil-cpp.natvis
index 708d6ba187..1e5a36fb9e 100644
--- a/cmake/external/abseil-cpp.natvis
+++ b/cmake/external/abseil-cpp.natvis
@@ -30,7 +30,6 @@
- empty
size={ _size() }
size=({_size()})
diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc
index 3763e0758c..d489a59c4b 100644
--- a/onnxruntime/core/graph/graph.cc
+++ b/onnxruntime/core/graph/graph.cc
@@ -4062,7 +4062,9 @@ static void ReassignSubgraphDependentNodeArgs(const InlinedHashMapExists()) {
auto hit = name_to_nodearg.find(input_def->Name());
if (hit != name_to_nodearg.cend()) {
- input_def = hit->second;
+ // Make sure we create a definition that is local to this subgraph
+ const auto* new_name_arg = hit->second;
+ input_def = &graph.GetOrCreateNodeArg(new_name_arg->Name(), input_def->TypeAsProto());
}
}
}
@@ -4088,7 +4090,7 @@ Status Graph::InlineIfSubgraph(bool condition_value, Node& if_node, const loggin
Graph& graph_to_inline = *sub_graph;
- std::string unique_id{if_node.Name()};
+ std::string unique_id{"_if_"};
if (condition_value) {
unique_id.append(then_branch);
} else {
@@ -4107,7 +4109,7 @@ Status Graph::InlineIfSubgraph(bool condition_value, Node& if_node, const loggin
// Reason: there are no explicit inputs to the subgraphs, and the subgraph's
// implicit inputs must be covered by the implicit inputs of the If node.
InlinedHashMap outer_scope_values;
- const auto if_implicit_inputs = if_node.MutableImplicitInputDefs();
+ const auto& if_implicit_inputs = if_node.MutableImplicitInputDefs();
outer_scope_values.reserve(if_implicit_inputs.size());
for (auto* input : if_implicit_inputs) {
@@ -4121,8 +4123,8 @@ Status Graph::InlineIfSubgraph(bool condition_value, Node& if_node, const loggin
// We are going to map the outputs of the graph to inline to the outputs of the If node.
// They are assumed to be in the same order.
- const auto node_output_defs = if_node.MutableOutputDefs();
- const auto graph_output_defs = graph_to_inline.GetOutputs();
+ const auto& node_output_defs = if_node.MutableOutputDefs();
+ const auto& graph_output_defs = graph_to_inline.GetOutputs();
for (size_t i = 0; i < graph_output_defs.size(); ++i) {
name_to_nodearg.emplace(graph_output_defs[i]->Name(), node_output_defs[i]);
}
@@ -4206,6 +4208,7 @@ Status Graph::InlineIfSubgraph(bool condition_value, Node& if_node, const loggin
}
}
+ auto* non_existing_arg = &GetOrCreateNodeArg(std::string(), nullptr);
// We want to make sure we get nodes in topological order
// because Constant folding may cause the nodes appear in
// a different order.
@@ -4216,68 +4219,94 @@ Status Graph::InlineIfSubgraph(bool condition_value, Node& if_node, const loggin
auto* node = graph_to_inline.GetNode(node_idx);
assert(node->OpType() != kConstant);
- InlinedVector new_node_input_defs;
- for (const auto* input_def : node->InputDefs()) {
+ // Inputs
+ // Chop off trailing non-existing defs, but preserve non-existing in the middle
+ auto& input_defs = node->MutableInputDefs();
+ auto last_existing = std::find_if(input_defs.rbegin(), input_defs.rend(),
+ [](const NodeArg* node_arg) { return node_arg->Exists(); });
+ input_defs.resize(std::distance(input_defs.begin(), last_existing.base()));
+
+ InlinedVector new_input_defs;
+ for (auto* input_def : node->InputDefs()) {
if (input_def->Exists()) {
// Check if this is one of the implicit graph inputs
- // then leave the name as is and re-use the NodeArg
+ // then re-assign the def to the outer scope value.
const auto& input_name = input_def->Name();
auto outer_hit = outer_scope_values.find(input_name);
if (outer_hit != outer_scope_values.cend()) {
- new_node_input_defs.push_back(outer_hit->second);
+ // get/create local definition
+ NodeArg* outer_arg = outer_hit->second;
+ auto& this_scope_arg = GetOrCreateNodeArg(outer_arg->Name(), input_def->TypeAsProto());
+ new_input_defs.push_back(&this_scope_arg);
} else {
auto hit = name_to_nodearg.find(input_name);
if (hit != name_to_nodearg.cend()) {
- // This is other node output, constant node or initializer that was renamed.
- new_node_input_defs.push_back(hit->second);
+ // This is other node output in the dest graph,
+ // constant node or initializer that was renamed.
+ new_input_defs.push_back(hit->second);
} else {
ORT_THROW("Node's: ", node->Name(), " input: ", input_name,
" is not If node's input or previous node output in this subgraph");
}
}
+ } else {
+ new_input_defs.push_back(non_existing_arg);
}
}
- InlinedVector new_node_output_defs;
- for (const auto* output_def : node->OutputDefs()) {
- const auto& output_name = output_def->Name();
- auto hit = name_to_nodearg.find(output_name);
- if (hit != name_to_nodearg.cend()) {
- // This is one of the graph outputs, we rename it to
- // If node output.
- new_node_output_defs.push_back(hit->second);
+ // Outputs
+ // Chop off trailing non-existing defs
+ auto& output_defs = node->MutableOutputDefs();
+ last_existing = std::find_if(output_defs.rbegin(), output_defs.rend(),
+ [](const NodeArg* node_arg) { return node_arg->Exists(); });
+ output_defs.resize(std::distance(output_defs.begin(), last_existing.base()));
+
+ InlinedVector new_output_defs;
+ for (auto* output_def : node->OutputDefs()) {
+ if (output_def->Exists()) {
+ const auto& output_name = output_def->Name();
+ auto hit = name_to_nodearg.find(output_name);
+ if (hit != name_to_nodearg.cend()) {
+ // This is one of the If node outputs, simply reassign the def.
+ // If node defs are already in the destination graph
+ new_output_defs.push_back(hit->second);
+ } else {
+ // We generate an output to downstream nodes.
+ auto new_name = GenerateNodeArgName(make_unique(output_name));
+ NodeArg& new_arg = GetOrCreateNodeArg(new_name, output_def->TypeAsProto());
+ new_output_defs.push_back(&new_arg);
+ ORT_IGNORE_RETURN_VALUE(name_to_nodearg.emplace(output_name, &new_arg));
+ }
} else {
- // We generate an output to downstream nodes.
- auto new_name = GenerateNodeArgName(make_unique(output_name));
- NodeArg& new_arg = GetOrCreateNodeArg(new_name, output_def->TypeAsProto());
- new_node_output_defs.push_back(&new_arg);
- ORT_IGNORE_RETURN_VALUE(name_to_nodearg.emplace(output_name, &new_arg));
+ new_output_defs.push_back(non_existing_arg);
}
}
const auto new_node_name = GenerateNodeName(make_unique(node->OpType()));
Node& new_node = AddNode(new_node_name, node->OpType(), node->Description(),
- new_node_input_defs,
- new_node_output_defs,
+ new_input_defs,
+ new_output_defs,
nullptr,
node->Domain());
+ new_node.SetSinceVersion(node->SinceVersion());
+ new_node.op_ = node->op_;
+
if (!is_this_main_graph) {
map_defs(new_node, input_args, true);
map_defs(new_node, output_args, false);
new_nodes.push_back(&new_node);
}
- new_node.SetSinceVersion(node->SinceVersion());
- new_node.op_ = node->op_;
-
if (node->ContainsSubgraph()) {
auto& subgraphs = node->MutableSubgraphs();
// Check if any of this node implicit inputs of this graph is in the renaming map
+ // that would mean they come from the destination graph, not from the parent
+ // of the destination graph.
int renames_subgraph_names = 0;
- auto& new_implicit_defs = node->MutableImplicitInputDefs();
- for (auto& input_def : new_implicit_defs) {
+ auto& implicit_defs = node->MutableImplicitInputDefs();
+ for (auto& input_def : implicit_defs) {
auto hit = name_to_nodearg.find(input_def->Name());
if (hit != name_to_nodearg.cend()) {
input_def = hit->second;
@@ -4298,7 +4327,7 @@ Status Graph::InlineIfSubgraph(bool condition_value, Node& if_node, const loggin
new_node.MutableSubgraphs() = std::move(subgraphs);
new_node.GetMutableMapOfAttributeNameToSubgraph() = std::move(node->GetMutableMapOfAttributeNameToSubgraph());
- new_node.MutableImplicitInputDefs() = std::move(new_implicit_defs);
+ new_node.MutableImplicitInputDefs() = std::move(implicit_defs);
}
new_node.GetMutableAttributes() = std::move(node->GetMutableAttributes());
diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc
index 17b26ed7ca..ef6e2d531b 100755
--- a/onnxruntime/test/optimizer/graph_transform_test.cc
+++ b/onnxruntime/test/optimizer/graph_transform_test.cc
@@ -1176,6 +1176,162 @@ TEST_F(GraphTransformationTests, ConstantFoldingIfConstantInliningRebuildEdges)
ASSERT_EQ(op_to_count["Cast"], 2);
}
+TEST_F(GraphTransformationTests, ConstantFoldingIfConstantInliningEdgesWithMiddleArgNonExisting) {
+ // This model contains a Resize() node whose middle input argument is omitted (non-existing).
+ // We want to make sure that the input edges of that Resize() node
+ // are properly rebuilt, with the middle argument still non-existing,
+ // when a constant If node is folded and its subgraph is inlined.
+ // The test is only valid if the Resize() node resides in a nested subgraph that gets inlined
+ // into a destination graph that is not the main graph. Then we check that the edges are rebuilt
+ // properly. Also, Resize() must not be the first node in the resulting subgraph, so that it has edges.
+ const char* code = R"(
+ <
+ ir_version: 8,
+ opset_import: [ "" : 16, "local" : 1 ]
+ >
+ agraph (float[128] x, float[128] x1) => (float[N] y)
+ {
+ y = local.aten_gather (x, x1)
+ }
+ <
+ opset_import: [ "" : 16, "local" : 1],
+ domain: "local"
+ >
+ aten_gather (self, index) => (result_16)
+ {
+ resize_scales = Constant ()
+ tmp_0 = Size (index)
+ int64_0 = Constant ()
+ int64_0_cast = CastLike (int64_0, tmp_0)
+ cond = Equal (tmp_0, int64_0_cast)
+ result_16 = If (cond) ( result) {
+ result = Identity (self)
+ }, else_branch: graph = elseGraph_10 () => ( result_15) {
+ tmp_1 = Shape (self)
+ tmp_2 = Size (tmp_1)
+ int64_0_3 = Constant ()
+ int64_0_3_cast = CastLike (int64_0_3, tmp_2)
+ cond_4 = Equal (tmp_2, int64_0_3_cast)
+ self_8 = If (cond_4) ( self_6) {
+ tmp_5 = Constant ()
+ self_6 = Reshape (self, tmp_5)
+ }, else_branch: graph = elseGraph_13 () => ( self_7) {
+ self_71 = Mul(self, self)
+ float_size = CastLike (tmp_0, resize_scales)
+ non_constant_resize_scales = Mul(float_size, resize_scales)
+ self_7 = Resize(self_71,, non_constant_resize_scales)
+ }>
+ tmp_9 = Size (index)
+ int64_0_10 = Constant ()
+ int64_0_10_cast = CastLike (int64_0_10, tmp_9)
+ cond_11 = Equal (tmp_9, int64_0_10_cast)
+ result_15 = If (cond_11) ( result_12) {
+ result_12 = CastLike (index, self_8)
+ }, else_branch: graph = elseGraph_15 () => ( result_14) {
+ index_13 = Cast (index)
+ result_14 = GatherElements (self_8, index_13)
+ }>
+ }>
+ }
+ )";
+
+ /** Optimized model graph
+ <
+ ir_version: 8,
+ opset_import: ["" : 16,
+ "local" : 1,
+ "com.microsoft.nchwc" : 1,
+ "ai.onnx.ml" : 4,
+ "ai.onnx.training" : 1,
+ "ai.onnx.preview.training" : 1,
+ "com.microsoft" : 1,
+ "com.microsoft.experimental" : 1, "org.pytorch.aten" : 1]
+ >
+ agraph (float[128] x, float[128] x1) => (float[128] y)
+
+ {
+ _inlfunc_aten_gather_tmp_0 = Size (x1)
+ _inlfunc_aten_gather_cond = Equal (_inlfunc_aten_gather_tmp_0, ortshared_7_0_1_0_token_8)
+ y = If (_inlfunc_aten_gather_cond)
+ (float[128] _inlfunc_aten_gather_result) {
+ _inlfunc_aten_gather_result = Identity (x)
+ }, else_branch: graph = elseGraph_10 () => (float[128] _inlfunc_aten_gather_result_15)
+
+ {
+ _if_else_branch__inlfunc_aten_gather_self_71 = Mul (x, x)
+ _if_else_branch__inlfunc_aten_gather_float_size = Cast (_inlfunc_aten_gather_tmp_0)
+ _if_else_branch__inlfunc_aten_gather_non_constant_resize_scales = Mul (
+ _if_else_branch__inlfunc_aten_gather_float_size, _inlfunc_aten_gather_resize_scales)
+ _inlfunc_aten_gather_self_8 = Resize (
+ _if_else_branch__inlfunc_aten_gather_self_71, ,
+ _if_else_branch__inlfunc_aten_gather_non_constant_resize_scales)
+ _inlfunc_aten_gather_tmp_9 = Size (x1)
+ _inlfunc_aten_gather_cond_11 = Equal (_inlfunc_aten_gather_tmp_9, _inlfunc_aten_gather_int64_0_10)
+ _inlfunc_aten_gather_result_15 = If (_inlfunc_aten_gather_cond_11)
+ (float[128] _inlfunc_aten_gather_result_12) {
+ _inlfunc_aten_gather_result_12 = Cast (x1)
+ }, else_branch: graph = elseGraph_15 () => (float[128] _inlfunc_aten_gather_result_14) {
+ _inlfunc_aten_gather_index_13 = Cast (x1)
+ _inlfunc_aten_gather_result_14 = GatherElements (
+ _inlfunc_aten_gather_self_8, _inlfunc_aten_gather_index_13)
+ }>
+ }>
+ }
+
+ */
+
+ ONNX_NAMESPACE::OnnxParser parser(code);
+ ONNX_NAMESPACE::ModelProto model_proto;
+ auto parse_status = parser.Parse(model_proto);
+ ASSERT_TRUE(parse_status.IsOK()) << parse_status.ErrorMessage();
+ ASSERT_TRUE(parser.EndOfInput()) << "Extra unparsed input unexpected.";
+
+ std::string serialized_model;
+ const bool serialization_status = model_proto.SerializeToString(&serialized_model);
+ ASSERT_TRUE(serialization_status) << "Failed to serialize proto to string";
+
+ // AOT inlining is necessary in this case, so the If nodes within the function
+ // are brought out to the outer scope. So we load this into a session object.
+ SessionOptions session_options;
+ InferenceSessionWrapper session_object{session_options, GetEnvironment()};
+ std::stringstream sstr(serialized_model);
+ ASSERT_STATUS_OK(session_object.Load(sstr));
+ ASSERT_STATUS_OK(session_object.Initialize());
+
+ // Let's verify the correctness of the rebuilt edges of the Resize node that still
+ // resides within the else branch subgraph of an If node.
+ auto& graph = session_object.GetModel().MainGraph();
+ auto op_to_count = CountOpsInGraph(graph);
+ ASSERT_EQ(op_to_count["If"], 2);
+ ASSERT_EQ(op_to_count["Resize"], 1);
+
+ auto if_node = std::find_if(graph.Nodes().begin(), graph.Nodes().end(),
+ [](const auto& node) { return node.OpType() == "If"; });
+ ASSERT_NE(graph.Nodes().cend(), if_node);
+ // Resize is in the else branch
+ auto subgraph_map = if_node->GetAttributeNameToSubgraphMap();
+ auto branch = subgraph_map.find("else_branch");
+ ASSERT_NE(subgraph_map.cend(), branch);
+
+ auto resize_node = std::find_if(branch->second->Nodes().begin(), branch->second->Nodes().end(),
+ [](const auto& node) { return node.OpType() == "Resize"; });
+ ASSERT_NE(branch->second->Nodes().cend(), resize_node);
+
+ // Check the edges: exactly two input edges are expected
+ ASSERT_EQ(2U, resize_node->GetInputEdgesCount());
+ // Their destination arg indices should be 0 and 2;
+ // index 1 is the missing (non-existing) input
+ InlinedHashSet dest_edges;
+ auto zero_edge = resize_node->InputEdgesBegin();
+ dest_edges.insert(zero_edge->GetDstArgIndex());
+ ++zero_edge;
+ dest_edges.insert(zero_edge->GetDstArgIndex());
+ ASSERT_TRUE(dest_edges.find(0) != dest_edges.end());
+ ASSERT_TRUE(dest_edges.find(2) != dest_edges.end());
+}
+
// Check transformations in the case of a subgraph with constant inputs.
TEST_F(GraphTransformationTests, SubgraphWithConstantInputs) {
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "constant-subgraph.onnx";