Add support for more ops for padding elimination (#17217)

Add support for Gelu/ReduceMean/SimplifiedLayerNormalization for padding
elimination
This commit is contained in:
guyang3532 2023-08-25 18:02:15 +08:00 committed by GitHub
parent 735cc8e6c8
commit 401129d484
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 66 additions and 20 deletions

View file

@ -300,13 +300,20 @@ void IterateSubgraphFromNode(Graph& graph,
candidate_outputs.insert(cur);
continue;
}
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "LayerNormalization", {1, 17}, kOnnxDomain)) {
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "LayerNormalization", {1, 17}, kOnnxDomain) ||
graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "SimplifiedLayerNormalization", {1}, kOnnxDomain)) {
if (subgraph.find(cur->MutableInputDefs()[0]) == subgraph.end()) {
LOG_DEBUG_INFO(logger, "PaddingElimination::First input of Normalization: " + cur->Name() +
" is not in subgraph.");
candidate_outputs.insert(cur);
continue;
}
if (!cur->InputDefs()[0]->Shape()) {
LOG_DEBUG_INFO(logger, "PaddingElimination::First input of Normalization: " + cur->Name() +
" has no shape.");
candidate_outputs.insert(cur);
continue;
}
auto axis = static_cast<int64_t>(cur->GetAttributes().at("axis").i());
axis = axis < 0 ? axis + cur->InputDefs()[0]->Shape()->dim_size() : axis;
if (axis < 2) {
@ -322,7 +329,8 @@ void IterateSubgraphFromNode(Graph& graph,
subgraph.insert(cur->MutableOutputDefs()[0]);
subgraph.insert(cur->MutableOutputDefs()[1]);
PushAllOutputNode(graph, to_visit, cur, visited);
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "Cast", {9, 13})) {
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "Cast", {9, 13}) ||
graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "Gelu", {1}, kMSDomain)) {
ORT_ENFORCE(subgraph.find(cur->MutableInputDefs()[0]) != subgraph.end());
subgraph.insert(cur->MutableOutputDefs()[0]);
PushAllOutputNode(graph, to_visit, cur, visited);
@ -361,6 +369,31 @@ void IterateSubgraphFromNode(Graph& graph,
} else {
candidate_outputs.insert(cur);
}
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "ReduceMean", {1, 11, 13, 18})) {
if (cur->InputDefs()[0]->Shape()) {
auto axes = cur->GetAttributes().at("axes").ints();
bool axes_check = (axes.size() > 0);
for (int64_t axis : axes) {
axis = axis < 0 ? axis + cur->InputDefs()[0]->Shape()->dim_size() : axis;
if (axis < 2) {
LOG_DEBUG_INFO(logger, "PaddingElimination::axis of ReduceMean: " + cur->Name() + " is " +
std::to_string(axis) + ", which blocks merging leading two dims.");
axes_check = false;
break;
}
}
if (axes_check) {
LOG_DEBUG_INFO(logger, "PaddingElimination::ReduceMean: " + cur->Name() + " is added to subgraph.");
subgraph.insert(cur->MutableOutputDefs()[0]);
PushAllOutputNode(graph, to_visit, cur, visited);
} else {
candidate_outputs.insert(cur);
}
} else {
LOG_DEBUG_INFO(logger, "PaddingElimination::shape of input of ReduceMean: " + cur->Name() + " is unknown.");
candidate_outputs.insert(cur);
continue;
}
} else {
candidate_outputs.insert(cur);
}

View file

@ -5758,6 +5758,9 @@ def test_runtime_inspector_label_and_embed_sparsity_detection(embed_is_sparse, l
("LayerNormalization", 0),
("Cast", 0),
("BiasGelu", 0),
("Gelu", 0),
("ReduceMean", 0),
("ReduceMean", 1),
],
)
def test_ops_for_padding_elimination(test_cases):
@ -5775,8 +5778,8 @@ def test_ops_for_padding_elimination(test_cases):
# test test_elementwise op for padding elimination
# in case 0, the shapes of inputs of test_op are [batch_size, seqlen, hidden_size] and [hidden_size],
# the test_op should be included in padding elimination subgraph and the GatherGrad should be added to
# output of test_op.
# the test_op should be included in padding elimination subgraph and the PadAndUnflatten should be
# added to output of test_op.
# in case 2, the shapes of inputs of test_op are [batch_size, seqlen, hidden_size] and [batch_size, 1, hidden_size],
# the test_op should be included in padding elimination subgraph and a 'Expand + Reshape + ShrunkenGather'
# pattern should be inserted into the arg of [batch_size, 1, hidden_size].
@ -5784,7 +5787,7 @@ def test_ops_for_padding_elimination(test_cases):
# the test_op should be included in padding elimination subgraph and a 'Expand + Reshape + ShrunkenGather'
# pattern should be inserted into the arg of [batch_size, 1, hidden_size].
# in case 4, the shapes of inputs of test_op are [batch_size, seqlen, hidden_size] and [batch_size, seqlen, hidden_size],
# the test_op should be included in padding elimination subgraph and the GatherGrad should be added to
# the test_op should be included in padding elimination subgraph and the PadAndUnflatten should be added to
# output of test_op. Besides, the other input of Add should have 'Reshape + ShrunkenGather' added to
# flatten it and eliminate padding.
def test_elementwise(self, input_ids):
@ -5811,25 +5814,25 @@ def test_ops_for_padding_elimination(test_cases):
return output
# test MatMul op for padding elimination
# in case 0, the shapes of inputs of MatMul are [2, seqlen] and [batch_size, seqlen, hidden_size]
# in case 0, the shapes of inputs of MatMul are [batch_size, seqlen, hidden_size] and [hidden_size, 128]
# the MatMul should be included in padding elimination subgraph and the PadAndUnflatten should be
# added to output of MatMul.
# in case 1, the shapes of inputs of MatMul are [2, seqlen] and [batch_size, seqlen, hidden_size]
# this case is not supported in padding elimination, so the MatMul should not be included in padding
# elimination subgraph and the GatherGrad should be added before MatMul.
# in case 1, the shapes of inputs of MatMul are [batch_size, seqlen, hidden_size] and [hidden_size, 128]
# the MatMul should be included in padding elimination subgraph and the GatherGrad should be added to
# output of MatMul.
# elimination subgraph and the PadAndUnflatten should be added before MatMul.
def test_matmul(self, input_ids):
inputs_embeds = self.word_embeddings(input_ids)
output = None
if case == 0:
matmul_input = torch.randn((2, input_ids.size(1))).to(device)
output = torch.matmul(matmul_input, inputs_embeds)
elif case == 1:
matmul_input = torch.randn((self.hidden_size, 128)).to(device)
output = torch.matmul(inputs_embeds, matmul_input)
elif case == 1:
matmul_input = torch.randn((2, input_ids.size(1))).to(device)
output = torch.matmul(matmul_input, inputs_embeds)
return output
# test other ops for padding elimination
# all these ops should be included in padding elimination subgraph and the GatherGrad should be added to
# all these ops should be included in padding elimination subgraph and the PadAndUnflatten should be added to
# output of these ops.
def test_other(self, input_ids):
inputs_embeds = self.word_embeddings(input_ids)
@ -5843,6 +5846,18 @@ def test_ops_for_padding_elimination(test_cases):
elif test_op == "BiasGelu":
bias = torch.randn((self.hidden_size,)).to(device)
output = torch.nn.functional.gelu(inputs_embeds + bias)
elif test_op == "Gelu":
output = torch.nn.functional.gelu(inputs_embeds)
elif test_op == "ReduceMean":
# In case 0, the inputs_embeds are reduced at last dimension, the ReduceMean should be included in padding
# elimination subgraph and the PadAndUnflatten should be added to output of ReduceMean.
# In case 1, the inputs_embeds are reduced at first dimension which is not supported in padding elimination,
# so the ReduceMean should not be included in padding elimination subgraph and the PadAndUnflatten should
# be added before ReduceMean.
if case == 0:
output = torch.mean(inputs_embeds, dim=-1)
elif case == 1:
output = torch.mean(inputs_embeds, dim=0)
return output
def forward(self, input_ids):
@ -5900,13 +5915,11 @@ def test_ops_for_padding_elimination(test_cases):
gathergrad_input_optypes = [find_input_node_type(training_model, arg) for arg in gathergrad_node.input]
if test_op == "Add" or test_op == "Mul" or test_op == "Sub":
assert test_op in gathergrad_input_optypes
elif test_op == "MatMul":
if case == 0:
assert "ATen" in gathergrad_input_optypes
elif case == 1:
assert "MatMul" in gathergrad_input_optypes
else:
assert test_op in gathergrad_input_optypes
if case == 0:
assert test_op in gathergrad_input_optypes
else:
assert "ATen" in gathergrad_input_optypes
del os.environ["ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER"]