mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-16 21:00:14 +00:00
Add support for more ops for padding elimination (#17217)
Add support for Gelu/ReduceMean/SimplifiedLayerNormalization for padding elimination
This commit is contained in:
parent
735cc8e6c8
commit
401129d484
2 changed files with 66 additions and 20 deletions
|
|
@ -300,13 +300,20 @@ void IterateSubgraphFromNode(Graph& graph,
|
|||
candidate_outputs.insert(cur);
|
||||
continue;
|
||||
}
|
||||
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "LayerNormalization", {1, 17}, kOnnxDomain)) {
|
||||
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "LayerNormalization", {1, 17}, kOnnxDomain) ||
|
||||
graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "SimplifiedLayerNormalization", {1}, kOnnxDomain)) {
|
||||
if (subgraph.find(cur->MutableInputDefs()[0]) == subgraph.end()) {
|
||||
LOG_DEBUG_INFO(logger, "PaddingElimination::First input of Normalization: " + cur->Name() +
|
||||
" is not in subgraph.");
|
||||
candidate_outputs.insert(cur);
|
||||
continue;
|
||||
}
|
||||
if (!cur->InputDefs()[0]->Shape()) {
|
||||
LOG_DEBUG_INFO(logger, "PaddingElimination::First input of Normalization: " + cur->Name() +
|
||||
" has no shape.");
|
||||
candidate_outputs.insert(cur);
|
||||
continue;
|
||||
}
|
||||
auto axis = static_cast<int64_t>(cur->GetAttributes().at("axis").i());
|
||||
axis = axis < 0 ? axis + cur->InputDefs()[0]->Shape()->dim_size() : axis;
|
||||
if (axis < 2) {
|
||||
|
|
@ -322,7 +329,8 @@ void IterateSubgraphFromNode(Graph& graph,
|
|||
subgraph.insert(cur->MutableOutputDefs()[0]);
|
||||
subgraph.insert(cur->MutableOutputDefs()[1]);
|
||||
PushAllOutputNode(graph, to_visit, cur, visited);
|
||||
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "Cast", {9, 13})) {
|
||||
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "Cast", {9, 13}) ||
|
||||
graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "Gelu", {1}, kMSDomain)) {
|
||||
ORT_ENFORCE(subgraph.find(cur->MutableInputDefs()[0]) != subgraph.end());
|
||||
subgraph.insert(cur->MutableOutputDefs()[0]);
|
||||
PushAllOutputNode(graph, to_visit, cur, visited);
|
||||
|
|
@ -361,6 +369,31 @@ void IterateSubgraphFromNode(Graph& graph,
|
|||
} else {
|
||||
candidate_outputs.insert(cur);
|
||||
}
|
||||
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "ReduceMean", {1, 11, 13, 18})) {
|
||||
if (cur->InputDefs()[0]->Shape()) {
|
||||
auto axes = cur->GetAttributes().at("axes").ints();
|
||||
bool axes_check = (axes.size() > 0);
|
||||
for (int64_t axis : axes) {
|
||||
axis = axis < 0 ? axis + cur->InputDefs()[0]->Shape()->dim_size() : axis;
|
||||
if (axis < 2) {
|
||||
LOG_DEBUG_INFO(logger, "PaddingElimination::axis of ReduceMean: " + cur->Name() + " is " +
|
||||
std::to_string(axis) + ", which blocks merging leading two dims.");
|
||||
axes_check = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (axes_check) {
|
||||
LOG_DEBUG_INFO(logger, "PaddingElimination::ReduceMean: " + cur->Name() + " is added to subgraph.");
|
||||
subgraph.insert(cur->MutableOutputDefs()[0]);
|
||||
PushAllOutputNode(graph, to_visit, cur, visited);
|
||||
} else {
|
||||
candidate_outputs.insert(cur);
|
||||
}
|
||||
} else {
|
||||
LOG_DEBUG_INFO(logger, "PaddingElimination::shape of input of ReduceMean: " + cur->Name() + " is unknown.");
|
||||
candidate_outputs.insert(cur);
|
||||
continue;
|
||||
}
|
||||
} else {
|
||||
candidate_outputs.insert(cur);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5758,6 +5758,9 @@ def test_runtime_inspector_label_and_embed_sparsity_detection(embed_is_sparse, l
|
|||
("LayerNormalization", 0),
|
||||
("Cast", 0),
|
||||
("BiasGelu", 0),
|
||||
("Gelu", 0),
|
||||
("ReduceMean", 0),
|
||||
("ReduceMean", 1),
|
||||
],
|
||||
)
|
||||
def test_ops_for_padding_elimination(test_cases):
|
||||
|
|
@ -5775,8 +5778,8 @@ def test_ops_for_padding_elimination(test_cases):
|
|||
|
||||
# test test_elementwise op for padding elimination
|
||||
# in case 0, the shapes of inputs of test_op are [batch_size, seqlen, hidden_size] and [hidden_size],
|
||||
# the test_op should be included in padding elimination subgraph and the GatherGrad should be added to
|
||||
# output of test_op.
|
||||
# the test_op should be included in padding elimination subgraph and the PadAndUnflatten should be
|
||||
# added to output of test_op.
|
||||
# in case 2, the shapes of inputs of test_op are [batch_size, seqlen, hidden_size] and [batch_size, 1, hidden_size],
|
||||
# the test_op should be included in padding elimination subgraph and a 'Expand + Reshape + ShrunkenGather'
|
||||
# pattern should be inserted into the arg of [batch_size, 1, hidden_size].
|
||||
|
|
@ -5784,7 +5787,7 @@ def test_ops_for_padding_elimination(test_cases):
|
|||
# the test_op should be included in padding elimination subgraph and a 'Expand + Reshape + ShrunkenGather'
|
||||
# pattern should be inserted into the arg of [batch_size, 1, hidden_size].
|
||||
# in case 4, the shapes of inputs of test_op are [batch_size, seqlen, hidden_size] and [batch_size, seqlen, hidden_size],
|
||||
# the test_op should be included in padding elimination subgraph and the GatherGrad should be added to
|
||||
# the test_op should be included in padding elimination subgraph and the PadAndUnflatten should be added to
|
||||
# output of test_op. Besides, the other input of Add should be added 'Reshape + ShrunkenGather' to
|
||||
# flatten and eliminate padding.
|
||||
def test_elementwise(self, input_ids):
|
||||
|
|
@ -5811,25 +5814,25 @@ def test_ops_for_padding_elimination(test_cases):
|
|||
return output
|
||||
|
||||
# test MatMul op for padding elimination
|
||||
# in case 0, the shapes of inputs of MatMul are [2, seqlen] and [batch_size, seqlen, hidden_size]
|
||||
# in case 0, the shapes of inputs of MatMul are [batch_size, seqlen, hidden_size] and [hidden_size, 128]
|
||||
# the MatMul should be included in padding elimination subgraph and the PadAndUnflatten should be
|
||||
# added to output of MatMul.
|
||||
# in case 1, the shapes of inputs of MatMul are [2, seqlen] and [batch_size, seqlen, hidden_size]
|
||||
# this case is not supported in padding elimination, so the MatMul should not be included in padding
|
||||
# elimination subgraph and the GatherGrad should be added before MatMul.
|
||||
# in case 1, the shapes of inputs of MatMul are [batch_size, seqlen, hidden_size] and [hidden_size, 128]
|
||||
# the MatMul should be included in padding elimination subgraph and the GatherGrad should be added to
|
||||
# output of MatMul.
|
||||
# elimination subgraph and the PadAndUnflatten should be added before MatMul.
|
||||
def test_matmul(self, input_ids):
|
||||
inputs_embeds = self.word_embeddings(input_ids)
|
||||
output = None
|
||||
if case == 0:
|
||||
matmul_input = torch.randn((2, input_ids.size(1))).to(device)
|
||||
output = torch.matmul(matmul_input, inputs_embeds)
|
||||
elif case == 1:
|
||||
matmul_input = torch.randn((self.hidden_size, 128)).to(device)
|
||||
output = torch.matmul(inputs_embeds, matmul_input)
|
||||
elif case == 1:
|
||||
matmul_input = torch.randn((2, input_ids.size(1))).to(device)
|
||||
output = torch.matmul(matmul_input, inputs_embeds)
|
||||
return output
|
||||
|
||||
# test other ops for padding elimination
|
||||
# all these ops should be included in padding elimination subgraph and the GatherGrad should be added to
|
||||
# all these ops should be included in padding elimination subgraph and the PadAndUnflatten should be added to
|
||||
# output of these ops.
|
||||
def test_other(self, input_ids):
|
||||
inputs_embeds = self.word_embeddings(input_ids)
|
||||
|
|
@ -5843,6 +5846,18 @@ def test_ops_for_padding_elimination(test_cases):
|
|||
elif test_op == "BiasGelu":
|
||||
bias = torch.randn((self.hidden_size,)).to(device)
|
||||
output = torch.nn.functional.gelu(inputs_embeds + bias)
|
||||
elif test_op == "Gelu":
|
||||
output = torch.nn.functional.gelu(inputs_embeds)
|
||||
elif test_op == "ReduceMean":
|
||||
# In case 0, the inputs_embeds are reduced at last dimension, the ReduceMean should be included in padding
|
||||
# elimination subgraph and the PadAndUnflatten should be added to output of ReduceMean.
|
||||
# In case 1, the inputs_embeds are reduced at first dimension which is not supported in padding elimination,
|
||||
# so the ReduceMean should not be included in padding elimination subgraph and the PadAndUnflatten should
|
||||
# be added before ReduceMean.
|
||||
if case == 0:
|
||||
output = torch.mean(inputs_embeds, dim=-1)
|
||||
elif case == 1:
|
||||
output = torch.mean(inputs_embeds, dim=0)
|
||||
return output
|
||||
|
||||
def forward(self, input_ids):
|
||||
|
|
@ -5900,13 +5915,11 @@ def test_ops_for_padding_elimination(test_cases):
|
|||
gathergrad_input_optypes = [find_input_node_type(training_model, arg) for arg in gathergrad_node.input]
|
||||
if test_op == "Add" or test_op == "Mul" or test_op == "Sub":
|
||||
assert test_op in gathergrad_input_optypes
|
||||
elif test_op == "MatMul":
|
||||
if case == 0:
|
||||
assert "ATen" in gathergrad_input_optypes
|
||||
elif case == 1:
|
||||
assert "MatMul" in gathergrad_input_optypes
|
||||
else:
|
||||
assert test_op in gathergrad_input_optypes
|
||||
if case == 0:
|
||||
assert test_op in gathergrad_input_optypes
|
||||
else:
|
||||
assert "ATen" in gathergrad_input_optypes
|
||||
|
||||
del os.environ["ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER"]
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue