From 401129d4848c302eff995f08725ca9aa5daef15c Mon Sep 17 00:00:00 2001 From: guyang3532 <62738430+guyang3532@users.noreply.github.com> Date: Fri, 25 Aug 2023 18:02:15 +0800 Subject: [PATCH] Add support for more ops for padding elimination (#17217) Add support for Gelu/ReduceMean/SimplifiedLayerNormalization for padding elimination --- .../compute_optimizer/padding_elimination.cc | 37 +++++++++++++- .../python/orttraining_test_ortmodule_api.py | 49 ++++++++++++------- 2 files changed, 66 insertions(+), 20 deletions(-) diff --git a/orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.cc b/orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.cc index 758ec7b8eb..74247c059c 100644 --- a/orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.cc +++ b/orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.cc @@ -300,13 +300,20 @@ void IterateSubgraphFromNode(Graph& graph, candidate_outputs.insert(cur); continue; } - } else if (graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "LayerNormalization", {1, 17}, kOnnxDomain)) { + } else if (graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "LayerNormalization", {1, 17}, kOnnxDomain) || + graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "SimplifiedLayerNormalization", {1}, kOnnxDomain)) { if (subgraph.find(cur->MutableInputDefs()[0]) == subgraph.end()) { LOG_DEBUG_INFO(logger, "PaddingElimination::First input of Normalization: " + cur->Name() + " is not in subgraph."); candidate_outputs.insert(cur); continue; } + if (!cur->InputDefs()[0]->Shape()) { + LOG_DEBUG_INFO(logger, "PaddingElimination::First input of Normalization: " + cur->Name() + + " has no shape."); + candidate_outputs.insert(cur); + continue; + } auto axis = static_cast(cur->GetAttributes().at("axis").i()); axis = axis < 0 ? axis + cur->InputDefs()[0]->Shape()->dim_size() : axis; if (axis < 2) { @@ -322,7 +329,8 @@ void IterateSubgraphFromNode(Graph& graph, subgraph.insert(cur->MutableOutputDefs()[0]); subgraph.insert(cur->MutableOutputDefs()[1]); PushAllOutputNode(graph, to_visit, cur, visited); - } else if (graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "Cast", {9, 13})) { + } else if (graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "Cast", {9, 13}) || + graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "Gelu", {1}, kMSDomain)) { ORT_ENFORCE(subgraph.find(cur->MutableInputDefs()[0]) != subgraph.end()); subgraph.insert(cur->MutableOutputDefs()[0]); PushAllOutputNode(graph, to_visit, cur, visited); @@ -361,6 +369,31 @@ void IterateSubgraphFromNode(Graph& graph, } else { candidate_outputs.insert(cur); } + } else if (graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "ReduceMean", {1, 11, 13, 18})) { + if (cur->InputDefs()[0]->Shape()) { + auto axes = cur->GetAttributes().at("axes").ints(); + bool axes_check = (axes.size() > 0); + for (int64_t axis : axes) { + axis = axis < 0 ? axis + cur->InputDefs()[0]->Shape()->dim_size() : axis; + if (axis < 2) { + LOG_DEBUG_INFO(logger, "PaddingElimination::axis of ReduceMean: " + cur->Name() + " is " + + std::to_string(axis) + ", which blocks merging leading two dims."); + axes_check = false; + break; + } + } + if (axes_check) { + LOG_DEBUG_INFO(logger, "PaddingElimination::ReduceMean: " + cur->Name() + " is added to subgraph."); + subgraph.insert(cur->MutableOutputDefs()[0]); + PushAllOutputNode(graph, to_visit, cur, visited); + } else { + candidate_outputs.insert(cur); + } + } else { + LOG_DEBUG_INFO(logger, "PaddingElimination::shape of input of ReduceMean: " + cur->Name() + " is unknown."); + candidate_outputs.insert(cur); + continue; + } } else { candidate_outputs.insert(cur); } diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index 0a398bd7b4..64cdb957f4 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -5758,6 +5758,9 @@ def test_runtime_inspector_label_and_embed_sparsity_detection(embed_is_sparse, l ("LayerNormalization", 0), ("Cast", 0), ("BiasGelu", 0), + ("Gelu", 0), + ("ReduceMean", 0), + ("ReduceMean", 1), ], ) def test_ops_for_padding_elimination(test_cases): @@ -5775,8 +5778,8 @@ def test_ops_for_padding_elimination(test_cases): # test test_elementwise op for padding elimination # in case 0, the shapes of inputs of test_op are [batch_size, seqlen, hidden_size] and [hidden_size], - # the test_op should be included in padding elimination subgraph and the GatherGrad should be added to - # output of test_op. + # the test_op should be included in padding elimination subgraph and the PadAndUnflatten should be + # added to output of test_op. # in case 2, the shapes of inputs of test_op are [batch_size, seqlen, hidden_size] and [batch_size, 1, hidden_size], # the test_op should be included in padding elimination subgraph and a 'Expand + Reshape + ShrunkenGather' # pattern should be insert to the arg of [batch_size, 1, hidden_size]. @@ -5784,7 +5787,7 @@ def test_ops_for_padding_elimination(test_cases): # the test_op should be included in padding elimination subgraph and a 'Expand + Reshape + ShrunkenGather' # pattern should be insert to the arg of [batch_size, 1, hidden_size]. # in case 4, the shapes of inputs of test_op are [batch_size, seqlen, hidden_size] and [batch_size, seqlen, hidden_size], - # the test_op should be included in padding elimination subgraph and the GatherGrad should be added to + # the test_op should be included in padding elimination subgraph and the PadAndUnflatten should be added to # output of test_op. Besides, the other input of Add should be added 'Reshape + ShrunkenGather' to # flatten and elimination padding. def test_elementwise(self, input_ids): @@ -5811,25 +5814,25 @@ def test_ops_for_padding_elimination(test_cases): return output # test MatMul op for padding elimination - # in case 0, the shapes of inputs of MatMul are [2, seqlen] and [batch_size, seqlen, hidden_size] + # in case 0, the shapes of inputs of MatMul are [batch_size, seqlen, hidden_size] and [hidden_size, 128] + # the MatMul should be included in padding elimination subgraph and the PadAndUnflatten should be + # added to output of MatMul. + # in case 1, the shapes of inputs of MatMul are [2, seqlen] and [batch_size, seqlen, hidden_size] # this case is not support in padding elimination, so the MatMul should not be included in padding - # elimination subgraph and the GatherGrad should be added before MatMul. - # in case 1, the shapes of inputs of MatMul are [batch_size, seqlen, hidden_size] and [hidden_size, 128] - # the MatMul should be included in padding elimination subgraph and the GatherGrad should be added to - # output of MatMul. + # elimination subgraph and the PadAndUnflatten should be added before MatMul. def test_matmul(self, input_ids): inputs_embeds = self.word_embeddings(input_ids) output = None if case == 0: - matmul_input = torch.randn((2, input_ids.size(1))).to(device) - output = torch.matmul(matmul_input, inputs_embeds) - elif case == 1: matmul_input = torch.randn((self.hidden_size, 128)).to(device) output = torch.matmul(inputs_embeds, matmul_input) + elif case == 1: + matmul_input = torch.randn((2, input_ids.size(1))).to(device) + output = torch.matmul(matmul_input, inputs_embeds) return output # test other ops for padding elimination - # all these ops should be included in padding elimination subgraph and the GatherGrad should be added to + # all these ops should be included in padding elimination subgraph and the PadAndUnflatten should be added to # output of these ops. def test_other(self, input_ids): inputs_embeds = self.word_embeddings(input_ids) @@ -5843,6 +5846,18 @@ def test_ops_for_padding_elimination(test_cases): elif test_op == "BiasGelu": bias = torch.randn((self.hidden_size,)).to(device) output = torch.nn.functional.gelu(inputs_embeds + bias) + elif test_op == "Gelu": + output = torch.nn.functional.gelu(inputs_embeds) + elif test_op == "ReduceMean": + # In case 0, the inputs_embeds are reduced at last dimension, the ReduceMean should be included in padding + # elimination subgraph and the PadAndUnflatten should be added to output of ReduceMean. + # In case 1, the inputs_embeds are reduced at first dimension which is not supported in padding elimination, + # so the ReduceMean should not be included in padding elimination subgraph and the PadAndUnflatten should + # be added before ReduceMean. + if case == 0: + output = torch.mean(inputs_embeds, dim=-1) + elif case == 1: + output = torch.mean(inputs_embeds, dim=0) return output def forward(self, input_ids): @@ -5900,13 +5915,11 @@ def test_ops_for_padding_elimination(test_cases): gathergrad_input_optypes = [find_input_node_type(training_model, arg) for arg in gathergrad_node.input] if test_op == "Add" or test_op == "Mul" or test_op == "Sub": assert test_op in gathergrad_input_optypes - elif test_op == "MatMul": - if case == 0: - assert "ATen" in gathergrad_input_optypes - elif case == 1: - assert "MatMul" in gathergrad_input_optypes else: - assert test_op in gathergrad_input_optypes + if case == 0: + assert test_op in gathergrad_input_optypes + else: + assert "ATen" in gathergrad_input_optypes del os.environ["ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER"]