mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-16 21:00:14 +00:00
Identity elimination with graph output (#7312)
* Identity removal * fix build * fix build * fix build * fix build * UTs * fix UT * fix UTs * per comments * fix UTs * fix UTs * per comments Co-authored-by: Ethan Tao <ettao@OrtTrainingDev4.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net>
This commit is contained in:
parent
265db2ad96
commit
7abe1fd392
7 changed files with 266 additions and 5 deletions
|
|
@ -10,16 +10,87 @@
|
|||
|
||||
namespace onnxruntime {
|
||||
|
||||
Status EliminateIdentity::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger&) const {
|
||||
if (graph_utils::RemoveNode(graph, node)) {
|
||||
rule_effect = RewriteRuleEffect::kRemovedCurrentNode;
|
||||
}
|
||||
/**
|
||||
Case to eliminate Identity node when
|
||||
- the input nodearg has only one consumer, which is the Identity itself
|
||||
- the input def is not a graph output
|
||||
|
||||
For examples:
|
||||
|
||||
OK to eliminate:
|
||||
|
||||
Identity output is another node, and the Identity is the only consumer of X
|
||||
X ---> Identity ---> Y where Y could be graph output
|
||||
|
||||
Identity input arg is not shared with other output arg of X
|
||||
+ (arg0) ---> Identity0 ---> Z
|
||||
|
|
||||
X (arg1) ---> Identity1 ---> Y
|
||||
|
||||
Not OK to eliminate:
|
||||
|
||||
Identity input arg, i.e., arg0, is also an input arg of other Identity
|
||||
+ (arg0) ---> Identity0 ---> Z
|
||||
|
|
||||
X (arg0) ---> Identity1 ---> Y
|
||||
|
||||
Identity input def, i.e., def0, is also a graph output
|
||||
+ (def0) ---> Z where Z is graph output
|
||||
|
|
||||
X (def0/arg0) ---> Identity ---> Y
|
||||
*/
|
||||
Status EliminateIdentity::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger&) const {
|
||||
if (graph.GetNodeOutputsInGraphOutputs(node).empty()) {
|
||||
if (graph_utils::RemoveNode(graph, node)) {
|
||||
rule_effect = RewriteRuleEffect::kRemovedCurrentNode;
|
||||
}
|
||||
} else {
|
||||
// keep a reference of output def to the graph output
|
||||
NodeArg* output = node.MutableOutputDefs()[0];
|
||||
const Node* p_input_node = graph_utils::GetInputNode(node, 0);
|
||||
// get mutable input node
|
||||
Node& input_node = *graph.GetNode(p_input_node->Index());
|
||||
int output_idx = graph_utils::GetNodeOutputIndexFromOutputName(input_node, node.MutableInputDefs()[0]->Name());
|
||||
// remove Identity node and its input edge
|
||||
graph.RemoveNode(node.Index());
|
||||
// update input node's output def to the graph output
|
||||
input_node.MutableOutputDefs()[output_idx] = output;
|
||||
rule_effect = RewriteRuleEffect::kRemovedCurrentNode;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Decides whether the Identity `node` may be eliminated.
//
// Returns true when either
//   - graph_utils::CanRemoveNode accepts the node, or
//   - the node has no output edges, has exactly one output def, that output is
//     a graph output, and the producer's output arg feeding the Identity is
//     neither a graph output itself nor shared with any other consumer.
// Returns false otherwise, including when the Identity has no producer node.
bool EliminateIdentity::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const {
  if (graph_utils::CanRemoveNode(graph, node, logger)) {
    return true;
  }

  // relax the condition if Identity is connecting to graph output
  if (node.GetOutputEdgesCount() != 0 || node.OutputDefs().size() != 1 ||
      graph.GetNodeOutputsInGraphOutputs(node).empty()) {
    return false;
  }

  const Node* p_input_node = graph_utils::GetInputNode(node, 0);
  if (p_input_node == nullptr) {
    return false;
  }

  // skip if the src arg is also a graph output
  const int src_arg_index =
      graph_utils::GetNodeOutputIndexFromOutputName(*p_input_node, node.InputDefs()[0]->Name());
  if (graph.IsOutput(p_input_node->OutputDefs()[src_arg_index])) {
    return false;
  }

  // count how many consumers are sharing the same src arg
  int count = 0;
  for (auto it = p_input_node->OutputEdgesBegin(), end = p_input_node->OutputEdgesEnd(); it != end; ++it) {
    if (it->GetSrcArgIndex() == src_arg_index) {
      ++count;
    }
  }

  // condition not met if there is more than 1 consumer for the same src arg
  if (count > 1) {
    return false;
  }

  return true;
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -100,6 +100,70 @@ TEST_F(GraphTransformationTests, IdentityElimination) {
|
|||
ASSERT_TRUE(op_to_count["Identity"] == 0);
|
||||
}
|
||||
|
||||
// Loads abs-id.onnx, which contains a single Identity node, and checks that
// the EliminateIdentity rewrite rule removes it.
TEST_F(GraphTransformationTests, IdentityEliminationWithGraphOutput) {
  auto model_uri = MODEL_FOLDER "abs-id.onnx";
  std::shared_ptr<Model> model;
  ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_));
  Graph& graph = model->MainGraph();

  // Precondition: exactly one Identity in the loaded graph.
  std::map<std::string, int> ops_before = CountOpsInGraph(graph);
  ASSERT_EQ(ops_before["Identity"], 1);

  // Run only the EliminateIdentity rule as a Level1 rule-based transformer.
  onnxruntime::GraphTransformerManager transformer_mgr{5};
  auto identity_rules = onnxruntime::make_unique<RuleBasedGraphTransformer>("RuleTransformer1");
  identity_rules->Register(onnxruntime::make_unique<EliminateIdentity>());
  transformer_mgr.Register(std::move(identity_rules), TransformerLevel::Level1);
  ASSERT_STATUS_OK(transformer_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));

  // Postcondition: the Identity is gone.
  std::map<std::string, int> ops_after = CountOpsInGraph(graph);
  ASSERT_EQ(ops_after["Identity"], 0);
}
|
||||
|
||||
// Loads id-elim.onnx (two Add nodes, two Identity nodes) and checks that,
// once CommonSubexpressionElimination has merged the two Adds, neither
// Identity is removed by EliminateIdentity.
TEST_F(GraphTransformationTests, IdentityWithSharedNodeArgNotEliminated) {
  auto model_uri = MODEL_FOLDER "id-elim.onnx";
  std::shared_ptr<Model> model;
  ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_));
  Graph& graph = model->MainGraph();

  std::map<std::string, int> ops_before = CountOpsInGraph(graph);
  ASSERT_EQ(ops_before["Identity"], 2);
  ASSERT_EQ(ops_before["Add"], 2);

  // Registration order matters: CommonSubexpressionElimination is registered
  // ahead of the rule-based transformer carrying EliminateIdentity.
  onnxruntime::GraphTransformerManager transformer_mgr{5};
  transformer_mgr.Register(onnxruntime::make_unique<CommonSubexpressionElimination>(), TransformerLevel::Level1);
  auto identity_rules = onnxruntime::make_unique<RuleBasedGraphTransformer>("RuleTransformer1");
  identity_rules->Register(onnxruntime::make_unique<EliminateIdentity>());
  transformer_mgr.Register(std::move(identity_rules), TransformerLevel::Level1);
  ASSERT_STATUS_OK(transformer_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));

  // After CommonSubexpressionElimination the remaining Add has one output def
  // with two edges, so both Identity nodes consume the same node arg. Removing
  // either one would leave only one of the two graph outputs with data, so
  // both Identity nodes must be kept.
  std::map<std::string, int> ops_after = CountOpsInGraph(graph);
  ASSERT_EQ(ops_after["Identity"], 2);
  ASSERT_EQ(ops_after["Add"], 1);
}
|
||||
|
||||
// Loads scan9_sum.onnx and checks that its single Identity node survives
// EliminateIdentity, because the Identity's input is itself a (sub)graph output.
TEST_F(GraphTransformationTests, IdentityInputIsGraphOutputNotEliminated) {
  auto model_uri = MODEL_FOLDER "scan9_sum.onnx";
  std::shared_ptr<Model> model;
  ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_));
  Graph& graph = model->MainGraph();

  std::map<std::string, int> ops_before = CountOpsInGraph(graph);
  ASSERT_EQ(ops_before["Identity"], 1);

  // Tip: to inspect the subgraph, use the Python tool dump_subgraphs.py, or
  // click on one of the inputs in a model viewer to open the drop-down graph
  // list and view the subgraph.

  onnxruntime::GraphTransformerManager transformer_mgr{5};
  auto identity_rules = onnxruntime::make_unique<RuleBasedGraphTransformer>("RuleTransformer1");
  identity_rules->Register(onnxruntime::make_unique<EliminateIdentity>());
  transformer_mgr.Register(std::move(identity_rules), TransformerLevel::Level1);
  ASSERT_STATUS_OK(transformer_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));

  // The Identity's input in the subgraph is also a graph output, so the rule
  // must skip it.
  std::map<std::string, int> ops_after = CountOpsInGraph(graph);
  ASSERT_EQ(ops_after["Identity"], 1);
}
|
||||
|
||||
TEST_F(GraphTransformationTests, DropoutElimination) {
|
||||
auto model_uri = MODEL_FOLDER "dropout.onnx";
|
||||
std::shared_ptr<Model> model;
|
||||
|
|
|
|||
BIN
onnxruntime/test/testdata/transform/id-elim.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/id-elim.onnx
vendored
Normal file
Binary file not shown.
41
onnxruntime/test/testdata/transform/id-elim.py
vendored
Normal file
41
onnxruntime/test/testdata/transform/id-elim.py
vendored
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
import onnx
|
||||
from onnx import helper
|
||||
from onnx import TensorProto, GraphProto, OperatorSetIdProto
|
||||
from onnx import numpy_helper
|
||||
import numpy as np
|
||||
|
||||
# Generates id-elim.onnx: two Add nodes over the same inputs (x1, x2), each
# followed by its own Identity node whose output is a graph output.

TENSOR_SHAPE = [4, 4]


def _int64_value_info(name):
    """Describe an INT64 graph input/output called `name` with TENSOR_SHAPE."""
    return helper.make_tensor_value_info(name, TensorProto.INT64, TENSOR_SHAPE)


def _opset(domain, version):
    """Build an OperatorSetIdProto for `domain` at `version`."""
    opset = OperatorSetIdProto()
    opset.version = version
    opset.domain = domain
    return opset


graph_inputs = [_int64_value_info('x1'), _int64_value_info('x2')]
graph_outputs = [_int64_value_info('output1'), _int64_value_info('output2')]

nodes = [
    helper.make_node('Add', ['x1', 'x2'], ['add1'], name='add1'),
    helper.make_node('Add', ['x1', 'x2'], ['add2'], name='add2'),
    helper.make_node('Identity', ['add1'], ['output1'], name='id1'),
    helper.make_node('Identity', ['add2'], ['output2'], name='id2'),
]

# Create the graph (GraphProto)
graph_def = helper.make_graph(
    nodes,
    'identity_elimination_model',
    graph_inputs,
    graph_outputs,
)

opsets = [
    # The empty string ("") or absence of this field implies the operator set
    # that is defined as part of the ONNX specification.
    _opset("", 12),
    _opset('com.microsoft', 1),
]

# Create the model (ModelProto)
model_def = helper.make_model(graph_def, producer_name='onnx-example', opset_imports=opsets)
onnx.save(model_def, 'id-elim.onnx')
|
||||
61
onnxruntime/test/testdata/transform/id-scan9_sum.py
vendored
Normal file
61
onnxruntime/test/testdata/transform/id-scan9_sum.py
vendored
Normal file
|
|
@ -0,0 +1,61 @@
|
|||
import onnx
|
||||
from onnx import helper
|
||||
from onnx import TensorProto, GraphProto, OperatorSetIdProto
|
||||
from onnx import numpy_helper
|
||||
import numpy as np
|
||||
|
||||
# Generates scan9_sum.onnx: a single Scan node whose body adds the running sum
# to the next scan element (Add) and then copies the sum through an Identity.
# Both the Add output (sum_out) and the Identity output (scan_out) are outputs
# of the body graph.


def _float_value_info(name, shape):
    """Describe a FLOAT graph input/output called `name` with the given shape."""
    return helper.make_tensor_value_info(name, TensorProto.FLOAT, shape)


# Outer graph inputs/outputs.
initial = _float_value_info('initial', [2])
x = _float_value_info('x', [3, 2])
# NOTE(review): 'y' is the first Scan output, paired with the body's sum_out,
# which is declared as [2]; the [3, 2] shape here looks inconsistent - confirm.
y = _float_value_info('y', [3, 2])
z = _float_value_info('z', [3, 2])

# Scan body inputs/outputs.
sum_in = _float_value_info('sum_in', [2])
next_in = _float_value_info('next', [2])
sum_out = _float_value_info('sum_out', [2])
scan_out = _float_value_info('scan_out', [2])

body_nodes = [
    helper.make_node('Add', inputs=['sum_in', 'next'], outputs=['sum_out']),
    helper.make_node('Identity', inputs=['sum_out'], outputs=['scan_out']),
]
scan_body = helper.make_graph(
    body_nodes,
    'scan_body',
    [sum_in, next_in],
    [sum_out, scan_out],
)

# create scan op node
scan_node = helper.make_node(
    'Scan',
    inputs=['initial', 'x'],
    outputs=['y', 'z'],
    num_scan_inputs=1,
    body=scan_body,
)

# Create the graph (GraphProto)
graph_def = helper.make_graph(
    [scan_node],
    'test_scan9_sum',
    [initial, x],
    [y, z],
)

# The empty string ("") or absence of the domain field implies the operator
# set that is defined as part of the ONNX specification.
onnxdomain = OperatorSetIdProto()
onnxdomain.version = 9
onnxdomain.domain = ""
opsets = [onnxdomain]

# Create the model (ModelProto)
model_def = helper.make_model(graph_def, producer_name='onnx-example', opset_imports=opsets)
onnx.save(model_def, 'scan9_sum.onnx')
|
||||
BIN
onnxruntime/test/testdata/transform/scan9_sum.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/scan9_sum.onnx
vendored
Normal file
Binary file not shown.
|
|
@ -779,6 +779,30 @@ def test_mixed_nnmodule_ortmodules_training():
|
|||
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model2, pt_model2)
|
||||
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model3, pt_model3)
|
||||
|
||||
def test_identity_elimination():
    """Smoke test: an ORTModule whose forward aliases its result before returning still runs."""

    class NeuralNetSimpleIdentity(torch.nn.Module):
        def __init__(self, input_size, num_classes):
            super().__init__()
            self.fc = torch.nn.Linear(input_size, num_classes)

        def forward(self, x):
            # Aliasing the result (`z = y`) is intended to produce an Identity
            # node in front of the graph output, which the graph transformation
            # is then expected to eliminate.
            y = self.fc(x)
            z = y
            return z

    device = 'cuda'
    batch_size, input_dim, output_dim = 64, 784, 10

    model = ORTModule(NeuralNetSimpleIdentity(input_dim, output_dim).to(device))
    x = torch.randn(batch_size, input_dim, device=device)
    output = model(x)

    # Make sure model runs OK
    assert output is not None
|
||||
|
||||
def test_ortmodule_inputs_with_dynamic_shape():
|
||||
D_in, H, D_out = 784, 500, 10
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue