diff --git a/onnxruntime/core/optimizer/identity_elimination.cc b/onnxruntime/core/optimizer/identity_elimination.cc index 944d01928d..77a58f7c57 100644 --- a/onnxruntime/core/optimizer/identity_elimination.cc +++ b/onnxruntime/core/optimizer/identity_elimination.cc @@ -10,16 +10,87 @@ namespace onnxruntime { -Status EliminateIdentity::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger&) const { - if (graph_utils::RemoveNode(graph, node)) { - rule_effect = RewriteRuleEffect::kRemovedCurrentNode; - } +/** + Case to eliminate Identity node when + - the input nodearg has only one consumer, which is the Identity itself + - the input def is not a graph output + + For examples: + OK to eliminate: + + Identity output is another node, and the Identity is the only consumer of X + X ---> Identity ---> Y where Y could be graph output + + Identity input arg is not shared with other output arg of X + + (arg0) ---> Identity0 ---> Z + | + X (arg1) ---> Identity1 ---> Y + + Not OK to eliminate: + + Identity input arg, i.e., arg0, is also an input arg of other Identity + + (arg0) ---> Identity0 ---> Z + | + X (arg0) ---> Identity1 ---> Y + + Identity input def, i.e., def0, is also a graph output + + (def0) ---> Z where Z is graph output + | + X (def0/arg0) ---> Identity ---> Y + */ +Status EliminateIdentity::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger&) const { + if (graph.GetNodeOutputsInGraphOutputs(node).empty()) { + if (graph_utils::RemoveNode(graph, node)) { + rule_effect = RewriteRuleEffect::kRemovedCurrentNode; + } + } else { + // keep a reference of output def to the graph output + NodeArg* output = node.MutableOutputDefs()[0]; + const Node* p_input_node = graph_utils::GetInputNode(node, 0); + // get mutable input node + Node& input_node = *graph.GetNode(p_input_node->Index()); + int output_idx = graph_utils::GetNodeOutputIndexFromOutputName(input_node, node.MutableInputDefs()[0]->Name()); + // remove Identity node and its input edge + graph.RemoveNode(node.Index()); + // update input node's output def to the graph output + input_node.MutableOutputDefs()[output_idx] = output; + rule_effect = RewriteRuleEffect::kRemovedCurrentNode; + } return Status::OK(); } bool EliminateIdentity::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const { - return graph_utils::CanRemoveNode(graph, node, logger); + if (graph_utils::CanRemoveNode(graph, node, logger)) { + return true; + } + + // relax the condition if Identity is connecting to graph output + if (node.GetOutputEdgesCount() != 0 || node.OutputDefs().size() != 1 || + graph.GetNodeOutputsInGraphOutputs(node).empty()) + return false; + + const Node* p_input_node = graph_utils::GetInputNode(node, 0); + if (p_input_node == nullptr) + return false; + + // skip if the src arg is also a graph output + int src_arg_index = graph_utils::GetNodeOutputIndexFromOutputName(*p_input_node, node.InputDefs()[0]->Name()); + if (graph.IsOutput(p_input_node->OutputDefs()[src_arg_index])) + return false; + + // count how many consumers are sharing the same src arg + int count = 0; + for (auto it = p_input_node->OutputEdgesBegin(), end = p_input_node->OutputEdgesEnd(); it != end; ++it) { + if (it->GetSrcArgIndex() == src_arg_index) { + count++; + } + } + // condition not met if there are more than 1 consumer for the same src arg + if (count > 1) + return false; + + return true; } } // namespace onnxruntime diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 5a98c9d2db..fa79304ef3 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -100,6 +100,70 @@ TEST_F(GraphTransformationTests, IdentityElimination) { ASSERT_TRUE(op_to_count["Identity"] == 0); } +TEST_F(GraphTransformationTests, IdentityEliminationWithGraphOutput) { + auto model_uri = MODEL_FOLDER "abs-id.onnx"; + std::shared_ptr model; + ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_)); + Graph& graph = model->MainGraph(); + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count["Identity"] == 1); + + auto rule_transformer_L1 = onnxruntime::make_unique("RuleTransformer1"); + rule_transformer_L1->Register(onnxruntime::make_unique()); + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); + + op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count["Identity"] == 0); +} + +TEST_F(GraphTransformationTests, IdentityWithSharedNodeArgNotEliminated) { + auto model_uri = MODEL_FOLDER "id-elim.onnx"; + std::shared_ptr model; + ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_)); + Graph& graph = model->MainGraph(); + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count["Identity"] == 2); + ASSERT_TRUE(op_to_count["Add"] == 2); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level1); + auto rule_transformer_L1 = onnxruntime::make_unique("RuleTransformer1"); + rule_transformer_L1->Register(onnxruntime::make_unique()); + graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); + + // after CommonSubexpressionElimination, Add would have 1 output def and 2 edges + // each edge would share the same input node arg 0. Thus after execution, only one of the 2 outputs + // has data. Thus skip. + op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count["Identity"] == 2); + ASSERT_TRUE(op_to_count["Add"] == 1); +} + +TEST_F(GraphTransformationTests, IdentityInputIsGraphOutputNotEliminated) { + auto model_uri = MODEL_FOLDER "scan9_sum.onnx"; + std::shared_ptr model; + ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_)); + Graph& graph = model->MainGraph(); + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count["Identity"] == 1); + + // tips: to dump the subgraph, can use python tool - dump_subgraphs.py + // or click on one of the input to see the drop down graph list and view subgraph + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + auto rule_transformer_L1 = onnxruntime::make_unique("RuleTransformer1"); + rule_transformer_L1->Register(onnxruntime::make_unique()); + graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); + + // Identity's input in subgraph is also graph output. Thus skip. + op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count["Identity"] == 1); +} + TEST_F(GraphTransformationTests, DropoutElimination) { auto model_uri = MODEL_FOLDER "dropout.onnx"; std::shared_ptr model; diff --git a/onnxruntime/test/testdata/transform/id-elim.onnx b/onnxruntime/test/testdata/transform/id-elim.onnx new file mode 100644 index 0000000000..95c7ad12eb Binary files /dev/null and b/onnxruntime/test/testdata/transform/id-elim.onnx differ diff --git a/onnxruntime/test/testdata/transform/id-elim.py b/onnxruntime/test/testdata/transform/id-elim.py new file mode 100644 index 0000000000..105e3a1071 --- /dev/null +++ b/onnxruntime/test/testdata/transform/id-elim.py @@ -0,0 +1,41 @@ +import onnx +from onnx import helper +from onnx import TensorProto, GraphProto, OperatorSetIdProto +from onnx import numpy_helper +import numpy as np + +X1 = helper.make_tensor_value_info('x1', TensorProto.INT64, [4, 4]) +X2 = helper.make_tensor_value_info('x2', TensorProto.INT64, [4, 4]) +Y1 = helper.make_tensor_value_info('output1', TensorProto.INT64, [4, 4]) +Y2 = helper.make_tensor_value_info('output2', TensorProto.INT64, [4, 4]) + +add1 = helper.make_node('Add', ['x1', 'x2'], ['add1'], name='add1') +add2 = helper.make_node('Add', ['x1', 'x2'], ['add2'], name='add2') +id1 = helper.make_node('Identity', ['add1'], ['output1'], name='id1') +id2 = helper.make_node('Identity', ['add2'], ['output2'], name='id2') + +# Create the graph (GraphProto) +graph_def = helper.make_graph( + [add1, add2, id1, id2], + 'identity_elimination_model', + [X1, X2], + [Y1, Y2] +) + +opsets = [] +onnxdomain = OperatorSetIdProto() +onnxdomain.version = 12 +onnxdomain.domain = "" # The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification. +opsets.append(onnxdomain) + +msdomain = OperatorSetIdProto() +msdomain.version = 1 +msdomain.domain = 'com.microsoft' + +opsets.append(msdomain) +kwargs={} +kwargs['opset_imports'] = opsets + +# Create the model (ModelProto) +model_def = helper.make_model(graph_def, producer_name='onnx-example', **kwargs) +onnx.save(model_def, 'id-elim.onnx') diff --git a/onnxruntime/test/testdata/transform/id-scan9_sum.py b/onnxruntime/test/testdata/transform/id-scan9_sum.py new file mode 100644 index 0000000000..798c4afd39 --- /dev/null +++ b/onnxruntime/test/testdata/transform/id-scan9_sum.py @@ -0,0 +1,61 @@ +import onnx +from onnx import helper +from onnx import TensorProto, GraphProto, OperatorSetIdProto +from onnx import numpy_helper +import numpy as np + +initial = helper.make_tensor_value_info('initial', TensorProto.FLOAT, [2]) +x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 2]) +y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 2]) +z = helper.make_tensor_value_info('z', TensorProto.FLOAT, [3, 2]) + +sum_in = helper.make_tensor_value_info('sum_in', TensorProto.FLOAT, [2]) +next = helper.make_tensor_value_info('next', TensorProto.FLOAT, [2]) +sum_out = helper.make_tensor_value_info('sum_out', TensorProto.FLOAT, [2]) +scan_out = helper.make_tensor_value_info('scan_out', TensorProto.FLOAT, [2]) + +add_node = helper.make_node( + 'Add', + inputs=['sum_in', 'next'], + outputs=['sum_out'] +) +id_node = helper.make_node( + 'Identity', + inputs=['sum_out'], + outputs=['scan_out'] +) +scan_body = helper.make_graph( + [add_node, id_node], + 'scan_body', + [sum_in, next], + [sum_out, scan_out] +) +# create scan op node +scan_node = helper.make_node( + 'Scan', + inputs=['initial', 'x'], + outputs=['y', 'z'], + num_scan_inputs=1, + body=scan_body +) + +# Create the graph (GraphProto) +graph_def = helper.make_graph( + [scan_node], + 'test_scan9_sum', + [initial, x], + [y, z] +) + +opsets = [] +onnxdomain = OperatorSetIdProto() +onnxdomain.version = 9 +onnxdomain.domain = "" # The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification. +opsets.append(onnxdomain) + +kwargs={} +kwargs['opset_imports'] = opsets + +# Create the model (ModelProto) +model_def = helper.make_model(graph_def, producer_name='onnx-example', **kwargs) +onnx.save(model_def, 'scan9_sum.onnx') diff --git a/onnxruntime/test/testdata/transform/scan9_sum.onnx b/onnxruntime/test/testdata/transform/scan9_sum.onnx new file mode 100644 index 0000000000..5a32b08b83 Binary files /dev/null and b/onnxruntime/test/testdata/transform/scan9_sum.onnx differ diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index 391cf92c76..62e4160ea6 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -779,6 +779,30 @@ def test_mixed_nnmodule_ortmodules_training(): _test_helpers.assert_gradients_match_and_reset_gradient(ort_model2, pt_model2) _test_helpers.assert_gradients_match_and_reset_gradient(ort_model3, pt_model3) +def test_identity_elimination(): + class NeuralNetSimpleIdentity(torch.nn.Module): + def __init__(self, input_size, num_classes): + super(NeuralNetSimpleIdentity, self).__init__() + + self.fc = torch.nn.Linear(input_size, num_classes) + + # Identity node will be created between ReduceSum and graph output + # and then eliminated after transformation + def forward(self, x): + y = self.fc(x) + z = y + return z + + device = 'cuda' + N, D_in, H, D_out = 64, 784, 500, 10 + model = NeuralNetSimpleIdentity(D_in, D_out).to(device) + model = ORTModule(model) + x = torch.randn(N, D_in, device=device) + output = model(x) + + # Make sure model runs OK + assert output is not None + def test_ortmodule_inputs_with_dynamic_shape(): D_in, H, D_out = 784, 500, 10