mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-04 04:07:22 +00:00
Create edges with arg positons correctly accounting for non-existing args (#18462)
### Description Truncate traling non-existing arguments. Make sure we do not skip on the non-existing arguments in the middle, because shape inferece relies on their proper position. This also affects the argument position in the Edges that must be properly rebuilt each time If node branch is inlined. Make sure that when we rename Defs in subgraphs, new renamed defs are created in those subgraphs instead of pointing to outer scope defs. Add unit test. ### Motivation and Context This is a follow up for https://github.com/microsoft/onnxruntime/pull/18105 Currently, the non-trailing arguments are simply ignored and the edges are created with potentially incorrect positions.
This commit is contained in:
parent
247ce21859
commit
cc542024ce
3 changed files with 217 additions and 33 deletions
1
cmake/external/abseil-cpp.natvis
vendored
1
cmake/external/abseil-cpp.natvis
vendored
|
|
@ -30,7 +30,6 @@
|
|||
<Intrinsic Name="_capacity" Expression="_commonfields().capacity_"/>
|
||||
<Intrinsic Name="_control" Expression="_commonfields().control_"/>
|
||||
<Intrinsic Name="_slots" Expression="(slot_type*)(_commonfields().slots_)"/>
|
||||
<DisplayString Condition="_size() == 0">empty</DisplayString>
|
||||
<DisplayString IncludeView="noparens">size={ _size() }</DisplayString>
|
||||
<DisplayString ExcludeView="noparens">size=({_size()})</DisplayString>
|
||||
<Expand>
|
||||
|
|
|
|||
|
|
@ -4062,7 +4062,9 @@ static void ReassignSubgraphDependentNodeArgs(const InlinedHashMap<std::string,
|
|||
if (input_def->Exists()) {
|
||||
auto hit = name_to_nodearg.find(input_def->Name());
|
||||
if (hit != name_to_nodearg.cend()) {
|
||||
input_def = hit->second;
|
||||
// Make sure we create a local to this subgraph definition
|
||||
const auto* new_name_arg = hit->second;
|
||||
input_def = &graph.GetOrCreateNodeArg(new_name_arg->Name(), input_def->TypeAsProto());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -4088,7 +4090,7 @@ Status Graph::InlineIfSubgraph(bool condition_value, Node& if_node, const loggin
|
|||
|
||||
Graph& graph_to_inline = *sub_graph;
|
||||
|
||||
std::string unique_id{if_node.Name()};
|
||||
std::string unique_id{"_if_"};
|
||||
if (condition_value) {
|
||||
unique_id.append(then_branch);
|
||||
} else {
|
||||
|
|
@ -4107,7 +4109,7 @@ Status Graph::InlineIfSubgraph(bool condition_value, Node& if_node, const loggin
|
|||
// Reason: there are no explicit inputs to the subgraphs, and the subgraph's
|
||||
// implicit inputs must be covered by the implicit inputs of the If node.
|
||||
InlinedHashMap<std::string_view, NodeArg*> outer_scope_values;
|
||||
const auto if_implicit_inputs = if_node.MutableImplicitInputDefs();
|
||||
const auto& if_implicit_inputs = if_node.MutableImplicitInputDefs();
|
||||
outer_scope_values.reserve(if_implicit_inputs.size());
|
||||
|
||||
for (auto* input : if_implicit_inputs) {
|
||||
|
|
@ -4121,8 +4123,8 @@ Status Graph::InlineIfSubgraph(bool condition_value, Node& if_node, const loggin
|
|||
|
||||
// We are going to map the outputs of the graph to inline to the outputs of the If node.
|
||||
// They are assumed to be in the same order.
|
||||
const auto node_output_defs = if_node.MutableOutputDefs();
|
||||
const auto graph_output_defs = graph_to_inline.GetOutputs();
|
||||
const auto& node_output_defs = if_node.MutableOutputDefs();
|
||||
const auto& graph_output_defs = graph_to_inline.GetOutputs();
|
||||
for (size_t i = 0; i < graph_output_defs.size(); ++i) {
|
||||
name_to_nodearg.emplace(graph_output_defs[i]->Name(), node_output_defs[i]);
|
||||
}
|
||||
|
|
@ -4206,6 +4208,7 @@ Status Graph::InlineIfSubgraph(bool condition_value, Node& if_node, const loggin
|
|||
}
|
||||
}
|
||||
|
||||
auto* non_existing_arg = &GetOrCreateNodeArg(std::string(), nullptr);
|
||||
// We want to make sure we get nodes in topological order
|
||||
// because Constant folding may cause the nodes appear in
|
||||
// a different order.
|
||||
|
|
@ -4216,68 +4219,94 @@ Status Graph::InlineIfSubgraph(bool condition_value, Node& if_node, const loggin
|
|||
auto* node = graph_to_inline.GetNode(node_idx);
|
||||
assert(node->OpType() != kConstant);
|
||||
|
||||
InlinedVector<NodeArg*> new_node_input_defs;
|
||||
for (const auto* input_def : node->InputDefs()) {
|
||||
// Inputs
|
||||
// Chop off trailing non-existing defs, but preserve non-existing in the middle
|
||||
auto& input_defs = node->MutableInputDefs();
|
||||
auto last_existing = std::find_if(input_defs.rbegin(), input_defs.rend(),
|
||||
[](const NodeArg* node_arg) { return node_arg->Exists(); });
|
||||
input_defs.resize(std::distance(input_defs.begin(), last_existing.base()));
|
||||
|
||||
InlinedVector<NodeArg*> new_input_defs;
|
||||
for (auto* input_def : node->InputDefs()) {
|
||||
if (input_def->Exists()) {
|
||||
// Check if this is one of the implicit graph inputs
|
||||
// then leave the name as is and re-use the NodeArg
|
||||
// then re-assign the def to the outer scope value.
|
||||
const auto& input_name = input_def->Name();
|
||||
auto outer_hit = outer_scope_values.find(input_name);
|
||||
if (outer_hit != outer_scope_values.cend()) {
|
||||
new_node_input_defs.push_back(outer_hit->second);
|
||||
// get/create local definition
|
||||
NodeArg* outer_arg = outer_hit->second;
|
||||
auto& this_scope_arg = GetOrCreateNodeArg(outer_arg->Name(), input_def->TypeAsProto());
|
||||
new_input_defs.push_back(&this_scope_arg);
|
||||
} else {
|
||||
auto hit = name_to_nodearg.find(input_name);
|
||||
if (hit != name_to_nodearg.cend()) {
|
||||
// This is other node output, constant node or initializer that was renamed.
|
||||
new_node_input_defs.push_back(hit->second);
|
||||
// This is other node output in the dest graph,
|
||||
// constant node or initializer that was renamed.
|
||||
new_input_defs.push_back(hit->second);
|
||||
} else {
|
||||
ORT_THROW("Node's: ", node->Name(), " input: ", input_name,
|
||||
" is not If node's input or previous node output in this subgraph");
|
||||
}
|
||||
}
|
||||
} else {
|
||||
new_input_defs.push_back(non_existing_arg);
|
||||
}
|
||||
}
|
||||
|
||||
InlinedVector<NodeArg*> new_node_output_defs;
|
||||
for (const auto* output_def : node->OutputDefs()) {
|
||||
const auto& output_name = output_def->Name();
|
||||
auto hit = name_to_nodearg.find(output_name);
|
||||
if (hit != name_to_nodearg.cend()) {
|
||||
// This is one of the graph outputs, we rename it to
|
||||
// If node output.
|
||||
new_node_output_defs.push_back(hit->second);
|
||||
// Outputs
|
||||
// Chop off trailing non-existing defs
|
||||
auto& output_defs = node->MutableOutputDefs();
|
||||
last_existing = std::find_if(output_defs.rbegin(), output_defs.rend(),
|
||||
[](const NodeArg* node_arg) { return node_arg->Exists(); });
|
||||
output_defs.resize(std::distance(output_defs.begin(), last_existing.base()));
|
||||
|
||||
InlinedVector<NodeArg*> new_output_defs;
|
||||
for (auto* output_def : node->OutputDefs()) {
|
||||
if (output_def->Exists()) {
|
||||
const auto& output_name = output_def->Name();
|
||||
auto hit = name_to_nodearg.find(output_name);
|
||||
if (hit != name_to_nodearg.cend()) {
|
||||
// This is one of the If node outputs, simply reassign the def.
|
||||
// If node defs are already in the destination graph
|
||||
new_output_defs.push_back(hit->second);
|
||||
} else {
|
||||
// We generate an output to downstream nodes.
|
||||
auto new_name = GenerateNodeArgName(make_unique(output_name));
|
||||
NodeArg& new_arg = GetOrCreateNodeArg(new_name, output_def->TypeAsProto());
|
||||
new_output_defs.push_back(&new_arg);
|
||||
ORT_IGNORE_RETURN_VALUE(name_to_nodearg.emplace(output_name, &new_arg));
|
||||
}
|
||||
} else {
|
||||
// We generate an output to downstream nodes.
|
||||
auto new_name = GenerateNodeArgName(make_unique(output_name));
|
||||
NodeArg& new_arg = GetOrCreateNodeArg(new_name, output_def->TypeAsProto());
|
||||
new_node_output_defs.push_back(&new_arg);
|
||||
ORT_IGNORE_RETURN_VALUE(name_to_nodearg.emplace(output_name, &new_arg));
|
||||
new_output_defs.push_back(non_existing_arg);
|
||||
}
|
||||
}
|
||||
|
||||
const auto new_node_name = GenerateNodeName(make_unique(node->OpType()));
|
||||
Node& new_node = AddNode(new_node_name, node->OpType(), node->Description(),
|
||||
new_node_input_defs,
|
||||
new_node_output_defs,
|
||||
new_input_defs,
|
||||
new_output_defs,
|
||||
nullptr,
|
||||
node->Domain());
|
||||
|
||||
new_node.SetSinceVersion(node->SinceVersion());
|
||||
new_node.op_ = node->op_;
|
||||
|
||||
if (!is_this_main_graph) {
|
||||
map_defs(new_node, input_args, true);
|
||||
map_defs(new_node, output_args, false);
|
||||
new_nodes.push_back(&new_node);
|
||||
}
|
||||
|
||||
new_node.SetSinceVersion(node->SinceVersion());
|
||||
new_node.op_ = node->op_;
|
||||
|
||||
if (node->ContainsSubgraph()) {
|
||||
auto& subgraphs = node->MutableSubgraphs();
|
||||
|
||||
// Check if any of this node implicit inputs of this graph is in the renaming map
|
||||
// that would mean they come from the destination graph, not from the parent
|
||||
// of the destination graph.
|
||||
int renames_subgraph_names = 0;
|
||||
auto& new_implicit_defs = node->MutableImplicitInputDefs();
|
||||
for (auto& input_def : new_implicit_defs) {
|
||||
auto& implicit_defs = node->MutableImplicitInputDefs();
|
||||
for (auto& input_def : implicit_defs) {
|
||||
auto hit = name_to_nodearg.find(input_def->Name());
|
||||
if (hit != name_to_nodearg.cend()) {
|
||||
input_def = hit->second;
|
||||
|
|
@ -4298,7 +4327,7 @@ Status Graph::InlineIfSubgraph(bool condition_value, Node& if_node, const loggin
|
|||
|
||||
new_node.MutableSubgraphs() = std::move(subgraphs);
|
||||
new_node.GetMutableMapOfAttributeNameToSubgraph() = std::move(node->GetMutableMapOfAttributeNameToSubgraph());
|
||||
new_node.MutableImplicitInputDefs() = std::move(new_implicit_defs);
|
||||
new_node.MutableImplicitInputDefs() = std::move(implicit_defs);
|
||||
}
|
||||
|
||||
new_node.GetMutableAttributes() = std::move(node->GetMutableAttributes());
|
||||
|
|
|
|||
|
|
@ -1176,6 +1176,162 @@ TEST_F(GraphTransformationTests, ConstantFoldingIfConstantInliningRebuildEdges)
|
|||
ASSERT_EQ(op_to_count["Cast"], 2);
|
||||
}
|
||||
|
||||
TEST_F(GraphTransformationTests, ConstantFoldingIfConstantInliningEdgesWithMiddleArgNonExisting) {
|
||||
// This model has a Resize() call with a middle argument non-existing.
|
||||
// We want to make sure that the input edges for that Resize() node
|
||||
// are properly rebuilt with a middle argument non-existing
|
||||
// during If constant folding
|
||||
// This test is only valid if Resize() node resides in the nested subgraph which gets inlined
|
||||
// however, the destination graph must not be the main graph. Then we test that the edges are rebuild
|
||||
// properly. Also Resize() should not be the first node in the resulting subgraph, so it has edges
|
||||
const char* code = R"(
|
||||
<
|
||||
ir_version: 8,
|
||||
opset_import: [ "" : 16, "local" : 1 ]
|
||||
>
|
||||
agraph (float[128] x, float[128] x1) => (float[N] y)
|
||||
{
|
||||
y = local.aten_gather <dim: int = 1, sparse_grad: int = 0> (x, x1)
|
||||
}
|
||||
<
|
||||
opset_import: [ "" : 16, "local" : 1],
|
||||
domain: "local"
|
||||
>
|
||||
aten_gather <dim>(self, index) => (result_16)
|
||||
{
|
||||
resize_scales = Constant <value_floats: floats = [1.5]> ()
|
||||
tmp_0 = Size (index)
|
||||
int64_0 = Constant <value: tensor = int64 int64_0 {0}> ()
|
||||
int64_0_cast = CastLike (int64_0, tmp_0)
|
||||
cond = Equal (tmp_0, int64_0_cast)
|
||||
result_16 = If (cond) <then_branch: graph = thenGraph_10 () => ( result) {
|
||||
result = Identity (self)
|
||||
}, else_branch: graph = elseGraph_10 () => ( result_15) {
|
||||
tmp_1 = Shape (self)
|
||||
tmp_2 = Size (tmp_1)
|
||||
int64_0_3 = Constant <value: tensor = int64 int64_0_3 {0}> ()
|
||||
int64_0_3_cast = CastLike (int64_0_3, tmp_2)
|
||||
cond_4 = Equal (tmp_2, int64_0_3_cast)
|
||||
self_8 = If (cond_4) <then_branch: graph = thenGraph_13 () => ( self_6) {
|
||||
tmp_5 = Constant <value_ints: ints = [-1]> ()
|
||||
self_6 = Reshape (self, tmp_5)
|
||||
}, else_branch: graph = elseGraph_13 () => ( self_7) {
|
||||
self_71 = Mul(self, self)
|
||||
float_size = CastLike (tmp_0, resize_scales)
|
||||
non_constant_resize_scales = Mul(float_size, resize_scales)
|
||||
self_7 = Resize(self_71,, non_constant_resize_scales)
|
||||
}>
|
||||
tmp_9 = Size (index)
|
||||
int64_0_10 = Constant <value: tensor = int64 int64_0_10 {0}> ()
|
||||
int64_0_10_cast = CastLike (int64_0_10, tmp_9)
|
||||
cond_11 = Equal (tmp_9, int64_0_10_cast)
|
||||
result_15 = If (cond_11) <then_branch: graph = thenGraph_15 () => ( result_12) {
|
||||
result_12 = CastLike (index, self_8)
|
||||
}, else_branch: graph = elseGraph_15 () => ( result_14) {
|
||||
index_13 = Cast <to: int = 7> (index)
|
||||
result_14 = GatherElements <axis: int = @dim> (self_8, index_13)
|
||||
}>
|
||||
}>
|
||||
}
|
||||
)";
|
||||
|
||||
/** Optimized model graph
|
||||
<
|
||||
ir_version: 8,
|
||||
opset_import: ["" : 16,
|
||||
"local" : 1,
|
||||
"com.microsoft.nchwc" : 1,
|
||||
"ai.onnx.ml" : 4,
|
||||
"ai.onnx.training" : 1,
|
||||
"ai.onnx.preview.training" : 1,
|
||||
"com.microsoft" : 1,
|
||||
"com.microsoft.experimental" : 1, "org.pytorch.aten" : 1]
|
||||
>
|
||||
agraph (float[128] x, float[128] x1) => (float[128] y)
|
||||
<float[1] _inlfunc_aten_gather_resize_scales = {1.5}, int64 ortshared_7_0_1_0_token_8 = {0}>
|
||||
{
|
||||
_inlfunc_aten_gather_tmp_0 = Size (x1)
|
||||
_inlfunc_aten_gather_cond = Equal (_inlfunc_aten_gather_tmp_0, ortshared_7_0_1_0_token_8)
|
||||
y = If (_inlfunc_aten_gather_cond) <then_branch: graph = thenGraph_10 () =>
|
||||
(float[128] _inlfunc_aten_gather_result) {
|
||||
_inlfunc_aten_gather_result = Identity (x)
|
||||
}, else_branch: graph = elseGraph_10 () => (float[128] _inlfunc_aten_gather_result_15)
|
||||
<int64 _inlfunc_aten_gather_int64_0_10 = {0}>
|
||||
{
|
||||
_if_else_branch__inlfunc_aten_gather_self_71 = Mul (x, x)
|
||||
_if_else_branch__inlfunc_aten_gather_float_size = Cast <to: int = 1> (_inlfunc_aten_gather_tmp_0)
|
||||
_if_else_branch__inlfunc_aten_gather_non_constant_resize_scales = Mul (
|
||||
_if_else_branch__inlfunc_aten_gather_float_size, _inlfunc_aten_gather_resize_scales)
|
||||
_inlfunc_aten_gather_self_8 = Resize <exclude_outside: int = 0, coordinate_transformation_mode:
|
||||
string = "half_pixel", cubic_coeff_a: float = -0.75, extrapolation_value: float = 0, mode:
|
||||
string = "nearest", nearest_mode: string = "round_prefer_floor"> (
|
||||
_if_else_branch__inlfunc_aten_gather_self_71, ,
|
||||
_if_else_branch__inlfunc_aten_gather_non_constant_resize_scales)
|
||||
_inlfunc_aten_gather_tmp_9 = Size (x1)
|
||||
_inlfunc_aten_gather_cond_11 = Equal (_inlfunc_aten_gather_tmp_9, _inlfunc_aten_gather_int64_0_10)
|
||||
_inlfunc_aten_gather_result_15 = If (_inlfunc_aten_gather_cond_11) <then_branch: graph = thenGraph_15 () =>
|
||||
(float[128] _inlfunc_aten_gather_result_12) {
|
||||
_inlfunc_aten_gather_result_12 = Cast <to: int = 1> (x1)
|
||||
}, else_branch: graph = elseGraph_15 () => (float[128] _inlfunc_aten_gather_result_14) {
|
||||
_inlfunc_aten_gather_index_13 = Cast <to: int = 7> (x1)
|
||||
_inlfunc_aten_gather_result_14 = GatherElements <axis: int = 1> (
|
||||
_inlfunc_aten_gather_self_8, _inlfunc_aten_gather_index_13)
|
||||
}>
|
||||
}>
|
||||
}
|
||||
|
||||
*/
|
||||
|
||||
ONNX_NAMESPACE::OnnxParser parser(code);
|
||||
ONNX_NAMESPACE::ModelProto model_proto;
|
||||
auto parse_status = parser.Parse(model_proto);
|
||||
ASSERT_TRUE(parse_status.IsOK()) << parse_status.ErrorMessage();
|
||||
ASSERT_TRUE(parser.EndOfInput()) << "Extra unparsed input unexpected.";
|
||||
|
||||
std::string serialized_model;
|
||||
const bool serialization_status = model_proto.SerializeToString(&serialized_model);
|
||||
ASSERT_TRUE(serialization_status) << "Failed to serialize proto to string";
|
||||
|
||||
// AOT inlining is necessary in this case, so the If nodes within the function
|
||||
// are brought out to the outer scope. So we load this into a session object.
|
||||
SessionOptions session_options;
|
||||
InferenceSessionWrapper session_object{session_options, GetEnvironment()};
|
||||
std::stringstream sstr(serialized_model);
|
||||
ASSERT_STATUS_OK(session_object.Load(sstr));
|
||||
ASSERT_STATUS_OK(session_object.Initialize());
|
||||
|
||||
// Let's verify the correctness of the rebuild edges in the Resize node that still
|
||||
// resides within an if else subgraph.
|
||||
auto& graph = session_object.GetModel().MainGraph();
|
||||
auto op_to_count = CountOpsInGraph(graph);
|
||||
ASSERT_EQ(op_to_count["If"], 2);
|
||||
ASSERT_EQ(op_to_count["Resize"], 1);
|
||||
|
||||
auto if_node = std::find_if(graph.Nodes().begin(), graph.Nodes().end(),
|
||||
[](const auto& node) { return node.OpType() == "If"; });
|
||||
ASSERT_NE(graph.Nodes().cend(), if_node);
|
||||
// Resize is in the else branch
|
||||
auto subgraph_map = if_node->GetAttributeNameToSubgraphMap();
|
||||
auto branch = subgraph_map.find("else_branch");
|
||||
ASSERT_NE(subgraph_map.cend(), branch);
|
||||
|
||||
auto resize_node = std::find_if(branch->second->Nodes().begin(), branch->second->Nodes().end(),
|
||||
[](const auto& node) { return node.OpType() == "Resize"; });
|
||||
ASSERT_NE(branch->second->Nodes().cend(), resize_node);
|
||||
|
||||
// Check the edges
|
||||
ASSERT_EQ(2U, resize_node->GetInputEdgesCount());
|
||||
// Should have input edges with arg_pos 0 and 2
|
||||
// With 1 is missing
|
||||
InlinedHashSet<size_t> dest_edges;
|
||||
auto zero_edge = resize_node->InputEdgesBegin();
|
||||
dest_edges.insert(zero_edge->GetDstArgIndex());
|
||||
++zero_edge;
|
||||
dest_edges.insert(zero_edge->GetDstArgIndex());
|
||||
ASSERT_TRUE(dest_edges.find(0) != dest_edges.end());
|
||||
ASSERT_TRUE(dest_edges.find(2) != dest_edges.end());
|
||||
}
|
||||
|
||||
// Check transformations in the case of a subgraph with constant inputs.
|
||||
TEST_F(GraphTransformationTests, SubgraphWithConstantInputs) {
|
||||
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "constant-subgraph.onnx";
|
||||
|
|
|
|||
Loading…
Reference in a new issue