Identity elimination with graph output (#7312)

* Identity removal

* fix build

* fix build

* fix build

* fix build

* UTs

* fix UT

* fix UTs

* per comments

* fix UTs

* fix UTs

* per comments

Co-authored-by: Ethan Tao <ettao@OrtTrainingDev4.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net>
This commit is contained in:
ytaous 2021-04-19 16:36:35 -07:00 committed by GitHub
parent 265db2ad96
commit 7abe1fd392
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 266 additions and 5 deletions

View file

@ -10,16 +10,87 @@
namespace onnxruntime {
/**
Case to eliminate Identity node when
- the input nodearg has only one consumer, which is the Identity itself
- the input def is not a graph output
For examples:
OK to eliminate:
    Identity output is another node, and the Identity is the only consumer of X
        X ---> Identity ---> Y where Y could be graph output
    Identity input arg is not shared with other output arg of X
        + (arg0) ---> Identity0 ---> Z
        |
        X (arg1) ---> Identity1 ---> Y
Not OK to eliminate:
    Identity input arg, i.e., arg0, is also an input arg of other Identity
        + (arg0) ---> Identity0 ---> Z
        |
        X (arg0) ---> Identity1 ---> Y
    Identity input def, i.e., def0, is also a graph output
        + (def0) ---> Z where Z is graph output
        |
        X (def0/arg0) ---> Identity ---> Y
*/
Status EliminateIdentity::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger&) const {
  if (graph.GetNodeOutputsInGraphOutputs(node).empty()) {
    // The Identity does not produce a graph output: splice it out directly,
    // reconnecting its consumers to its producer.
    if (graph_utils::RemoveNode(graph, node)) {
      rule_effect = RewriteRuleEffect::kRemovedCurrentNode;
    }
  } else {
    // The Identity feeds a graph output. Keep a reference to the graph-output
    // def so the upstream producer can take it over once the Identity is gone.
    NodeArg* output = node.MutableOutputDefs()[0];
    // SatisfyCondition guarantees an input node exists (nullptr was rejected there).
    const Node* p_input_node = graph_utils::GetInputNode(node, 0);
    // get a mutable handle on the producer node
    Node& input_node = *graph.GetNode(p_input_node->Index());
    // locate which producer output def currently feeds this Identity
    int output_idx = graph_utils::GetNodeOutputIndexFromOutputName(input_node, node.MutableInputDefs()[0]->Name());
    // remove the Identity node and its input edge
    graph.RemoveNode(node.Index());
    // rewire: the producer now writes directly to the graph output def
    input_node.MutableOutputDefs()[output_idx] = output;
    rule_effect = RewriteRuleEffect::kRemovedCurrentNode;
  }
  return Status::OK();
}
// Returns true when it is safe to eliminate this Identity node.
// Either the generic removal check passes, or the Identity is the sole,
// edge-less producer of a graph output whose source arg is not itself a
// graph output and is not shared with any other consumer.
bool EliminateIdentity::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const {
  // NOTE(review): the pre-change unconditional `return CanRemoveNode(...)` was
  // removed here — it made every check below unreachable dead code.
  if (graph_utils::CanRemoveNode(graph, node, logger)) {
    return true;
  }
  // relax the condition if Identity is connecting to graph output:
  // it must have no output edges, exactly one output def, and that def
  // must actually be a graph output
  if (node.GetOutputEdgesCount() != 0 || node.OutputDefs().size() != 1 ||
      graph.GetNodeOutputsInGraphOutputs(node).empty()) {
    return false;
  }
  const Node* p_input_node = graph_utils::GetInputNode(node, 0);
  if (p_input_node == nullptr) {
    return false;
  }
  // skip if the src arg is also a graph output (rewiring it away would drop it)
  int src_arg_index = graph_utils::GetNodeOutputIndexFromOutputName(*p_input_node, node.InputDefs()[0]->Name());
  if (graph.IsOutput(p_input_node->OutputDefs()[src_arg_index])) {
    return false;
  }
  // count how many consumers are sharing the same src arg
  int count = 0;
  for (auto it = p_input_node->OutputEdgesBegin(), end = p_input_node->OutputEdgesEnd(); it != end; ++it) {
    if (it->GetSrcArgIndex() == src_arg_index) {
      count++;
    }
  }
  // condition not met if there is more than 1 consumer for the same src arg
  return count <= 1;
}
} // namespace onnxruntime

View file

@ -100,6 +100,70 @@ TEST_F(GraphTransformationTests, IdentityElimination) {
ASSERT_TRUE(op_to_count["Identity"] == 0);
}
// An Identity whose output is a graph output should still be eliminated:
// the producer node is rewired to write the graph output directly.
TEST_F(GraphTransformationTests, IdentityEliminationWithGraphOutput) {
  auto model_uri = MODEL_FOLDER "abs-id.onnx";
  std::shared_ptr<Model> model;
  ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_));
  Graph& graph = model->MainGraph();

  // Model starts with exactly one Identity, feeding the graph output.
  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
  ASSERT_EQ(op_to_count["Identity"], 1);

  onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
  auto rule_transformer = onnxruntime::make_unique<RuleBasedGraphTransformer>("RuleTransformer1");
  rule_transformer->Register(onnxruntime::make_unique<EliminateIdentity>());
  graph_transformation_mgr.Register(std::move(rule_transformer), TransformerLevel::Level1);
  ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));

  // The Identity should be gone after the transformation.
  op_to_count = CountOpsInGraph(graph);
  ASSERT_EQ(op_to_count["Identity"], 0);
}
// When CSE merges the two Adds, both Identity nodes end up consuming the
// SAME src arg of the surviving Add; eliminating either would break the
// other, so both must be kept.
TEST_F(GraphTransformationTests, IdentityWithSharedNodeArgNotEliminated) {
  auto model_uri = MODEL_FOLDER "id-elim.onnx";
  std::shared_ptr<Model> model;
  ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_));
  Graph& graph = model->MainGraph();

  // Model starts with two Add nodes, each feeding its own Identity.
  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
  ASSERT_EQ(op_to_count["Identity"], 2);
  ASSERT_EQ(op_to_count["Add"], 2);

  onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
  graph_transformation_mgr.Register(onnxruntime::make_unique<CommonSubexpressionElimination>(), TransformerLevel::Level1);
  auto rule_transformer = onnxruntime::make_unique<RuleBasedGraphTransformer>("RuleTransformer1");
  rule_transformer->Register(onnxruntime::make_unique<EliminateIdentity>());
  graph_transformation_mgr.Register(std::move(rule_transformer), TransformerLevel::Level1);
  ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));

  // after CommonSubexpressionElimination, Add would have 1 output def and 2 edges
  // each edge would share the same input node arg 0. Thus after execution, only one of the 2 outputs
  // has data. Thus skip.
  op_to_count = CountOpsInGraph(graph);
  ASSERT_EQ(op_to_count["Identity"], 2);
  ASSERT_EQ(op_to_count["Add"], 1);
}
// An Identity must be kept when its input def is itself a (sub)graph output:
// rewiring the producer would remove that output from the graph.
TEST_F(GraphTransformationTests, IdentityInputIsGraphOutputNotEliminated) {
  auto model_uri = MODEL_FOLDER "scan9_sum.onnx";
  std::shared_ptr<Model> model;
  ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_));
  Graph& graph = model->MainGraph();

  // One Identity lives inside the Scan body subgraph.
  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
  ASSERT_EQ(op_to_count["Identity"], 1);

  // tips: to dump the subgraph, can use python tool - dump_subgraphs.py
  // or click on one of the input to see the drop down graph list and view subgraph
  onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
  auto rule_transformer = onnxruntime::make_unique<RuleBasedGraphTransformer>("RuleTransformer1");
  rule_transformer->Register(onnxruntime::make_unique<EliminateIdentity>());
  graph_transformation_mgr.Register(std::move(rule_transformer), TransformerLevel::Level1);
  ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));

  // Identity's input in subgraph is also graph output. Thus skip.
  op_to_count = CountOpsInGraph(graph);
  ASSERT_EQ(op_to_count["Identity"], 1);
}
TEST_F(GraphTransformationTests, DropoutElimination) {
auto model_uri = MODEL_FOLDER "dropout.onnx";
std::shared_ptr<Model> model;

Binary file not shown.

View file

@ -0,0 +1,41 @@
"""Generate id-elim.onnx: two identical Add nodes, each feeding an Identity
that produces a graph output. Used to test that Identity elimination skips
Identities sharing a src arg after common-subexpression elimination."""
import onnx
from onnx import helper
from onnx import TensorProto, GraphProto, OperatorSetIdProto
from onnx import numpy_helper
import numpy as np

# Graph inputs and outputs: 4x4 int64 tensors.
x1_info = helper.make_tensor_value_info('x1', TensorProto.INT64, [4, 4])
x2_info = helper.make_tensor_value_info('x2', TensorProto.INT64, [4, 4])
out1_info = helper.make_tensor_value_info('output1', TensorProto.INT64, [4, 4])
out2_info = helper.make_tensor_value_info('output2', TensorProto.INT64, [4, 4])

# Two identical Adds (CSE candidates), each followed by an Identity that
# feeds a distinct graph output.
nodes = [
    helper.make_node('Add', ['x1', 'x2'], ['add1'], name='add1'),
    helper.make_node('Add', ['x1', 'x2'], ['add2'], name='add2'),
    helper.make_node('Identity', ['add1'], ['output1'], name='id1'),
    helper.make_node('Identity', ['add2'], ['output2'], name='id2'),
]

# Create the graph (GraphProto)
graph_def = helper.make_graph(
    nodes,
    'identity_elimination_model',
    [x1_info, x2_info],
    [out1_info, out2_info],
)

# Opset imports: default ONNX domain at version 12 plus com.microsoft v1.
onnx_opset = OperatorSetIdProto()
onnx_opset.version = 12
onnx_opset.domain = ""  # The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification.
ms_opset = OperatorSetIdProto()
ms_opset.version = 1
ms_opset.domain = 'com.microsoft'

# Create the model (ModelProto) and write it next to the script.
model_def = helper.make_model(graph_def, producer_name='onnx-example',
                              opset_imports=[onnx_opset, ms_opset])
onnx.save(model_def, 'id-elim.onnx')

View file

@ -0,0 +1,61 @@
"""Generate scan9_sum.onnx: a Scan (opset 9) whose body graph contains an
Identity whose input def ('sum_out') is also a body-graph output. Used to
test that Identity elimination skips this case."""
import onnx
from onnx import helper
from onnx import TensorProto, GraphProto, OperatorSetIdProto
from onnx import numpy_helper
import numpy as np

# Main-graph value infos.
initial_info = helper.make_tensor_value_info('initial', TensorProto.FLOAT, [2])
x_info = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 2])
y_info = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 2])
z_info = helper.make_tensor_value_info('z', TensorProto.FLOAT, [3, 2])

# Scan-body value infos: running sum state plus per-step scan output.
sum_in_info = helper.make_tensor_value_info('sum_in', TensorProto.FLOAT, [2])
next_info = helper.make_tensor_value_info('next', TensorProto.FLOAT, [2])
sum_out_info = helper.make_tensor_value_info('sum_out', TensorProto.FLOAT, [2])
scan_out_info = helper.make_tensor_value_info('scan_out', TensorProto.FLOAT, [2])

# Body: sum_out = sum_in + next; scan_out = Identity(sum_out).
# Note 'sum_out' is BOTH the Identity input and a body-graph output.
add_node = helper.make_node('Add', inputs=['sum_in', 'next'], outputs=['sum_out'])
id_node = helper.make_node('Identity', inputs=['sum_out'], outputs=['scan_out'])
scan_body = helper.make_graph(
    [add_node, id_node],
    'scan_body',
    [sum_in_info, next_info],
    [sum_out_info, scan_out_info],
)

# create scan op node
scan_node = helper.make_node(
    'Scan',
    inputs=['initial', 'x'],
    outputs=['y', 'z'],
    num_scan_inputs=1,
    body=scan_body,
)

# Create the graph (GraphProto)
graph_def = helper.make_graph(
    [scan_node],
    'test_scan9_sum',
    [initial_info, x_info],
    [y_info, z_info],
)

# Single opset import: default ONNX domain at version 9 (Scan-9 semantics).
onnx_opset = OperatorSetIdProto()
onnx_opset.version = 9
onnx_opset.domain = ""  # The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification.

# Create the model (ModelProto) and write it next to the script.
model_def = helper.make_model(graph_def, producer_name='onnx-example',
                              opset_imports=[onnx_opset])
onnx.save(model_def, 'scan9_sum.onnx')

Binary file not shown.

View file

@ -779,6 +779,30 @@ def test_mixed_nnmodule_ortmodules_training():
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model2, pt_model2)
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model3, pt_model3)
def test_identity_elimination():
    """ORTModule should run a model whose forward merely aliases its output;
    the alias exports as an Identity feeding the graph output, which the
    Identity-elimination transformer then removes."""
    class NeuralNetSimpleIdentity(torch.nn.Module):
        def __init__(self, input_size, num_classes):
            super().__init__()
            self.fc = torch.nn.Linear(input_size, num_classes)

        # Identity node will be created between ReduceSum and graph output
        # and then eliminated after transformation
        def forward(self, x):
            y = self.fc(x)
            z = y
            return z

    device = 'cuda'
    N, D_in, H, D_out = 64, 784, 500, 10
    pt_model = NeuralNetSimpleIdentity(D_in, D_out).to(device)
    model = ORTModule(pt_model)
    x = torch.randn(N, D_in, device=device)
    output = model(x)

    # Make sure model runs OK
    assert output is not None
def test_ortmodule_inputs_with_dynamic_shape():
D_in, H, D_out = 784, 500, 10