diff --git a/orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.cc b/orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.cc index 5950157acc..fe2f86edd0 100644 --- a/orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.cc +++ b/orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.cc @@ -292,7 +292,9 @@ void IterateSubgraphFromNode(Graph& graph, to_visit.pop(); visited.insert(cur); if (graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "Add", {7, 13, 14}) || - graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "BiasGelu", {1}, kMSDomain)) { + graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "BiasGelu", {1}, kMSDomain) || + graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "Sub", {7, 13, 14}) || + graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "Mul", {7, 13, 14})) { ORT_ENFORCE(subgraph.find(cur->MutableInputDefs()[0]) != subgraph.end() || subgraph.find(cur->MutableInputDefs()[1]) != subgraph.end()); NodeArg* arg_in_subgraph = nullptr; diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index bc21d89f5f..f7d1c17beb 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -5752,10 +5752,16 @@ def test_runtime_inspector_label_and_embed_sparsity_detection(embed_is_sparse, l ("Add", 0), ("Add", 1), ("Add", 2), + ("Sub", 0), + ("Sub", 1), + ("Sub", 2), + ("Mul", 0), + ("Mul", 1), + ("Mul", 2), ("MatMul", 0), ("MatMul", 1), ("Dropout", 0), - ("LayerNorm", 0), + ("LayerNormalization", 0), ("Cast", 0), ("BiasGelu", 0), ], @@ -5764,38 +5770,47 @@ def test_ops_for_padding_elimination(test_cases): os.environ["ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER"] = "1" test_op = test_cases[0] case = test_cases[1] + # test_op = "Sub" + # case = 2 class ToyModel(torch.nn.Module): def __init__(self, vocab_size, hidden_size, pad_token_id): super().__init__() self.word_embeddings = nn.Embedding(vocab_size, hidden_size, padding_idx=pad_token_id) - if test_op == "LayerNorm": + if test_op == "LayerNormalization": self.LayerNorm = nn.LayerNorm(hidden_size, eps=1e-05) self.hidden_size = hidden_size - # test Add op for padding elimination - # in case 0, the shapes of inputs of Add are [batch_size, seqlen, hidden_size] and [hidden_size], - # the Add should be included in padding elimination subgraph and the GatherGrad should be added to - # output of Add. - # in case 1, the shapes of inputs of Add are [batch_size, seqlen, hidden_size] and [batch_size, 1, hidden_size], - # this case is not support in padding elimination, so the Add should not be included in padding - # elimination subgraph and the GatherGrad should be added before Add. - # in case 2, the shapes of inputs of Add are [batch_size, seqlen, hidden_size] and [batch_size, seqlen, hidden_size], - # the Add should be included in padding elimination subgraph and the GatherGrad should be added to - # output of Add. Besides, the other input of Add should be added 'Reshape + ShrunkenGather' to + # test test_elementwise op for padding elimination + # in case 0, the shapes of inputs of test_op are [batch_size, seqlen, hidden_size] and [hidden_size], + # the test_op should be included in padding elimination subgraph and the GatherGrad should be added to + # output of test_op. + # in case 1, the shapes of inputs of test_op are [batch_size, seqlen, hidden_size] and [batch_size, 1, hidden_size], + # this case is not support in padding elimination, so the test_op should not be included in padding + # elimination subgraph and the GatherGrad should be added before test_op. + # in case 2, the shapes of inputs of test_op are [batch_size, seqlen, hidden_size] and [batch_size, seqlen, hidden_size], + # the test_op should be included in padding elimination subgraph and the GatherGrad should be added to + # output of test_op. Besides, the other input of Add should be added 'Reshape + ShrunkenGather' to # flatten and elimination padding. - def test_add(self, input_ids): + def test_elementwise(self, input_ids): input_shape = input_ids.size() - add_input = None + one_input = None if case == 0: - add_input = torch.ones(self.hidden_size, dtype=torch.long).to(device) + one_input = torch.ones(self.hidden_size, dtype=torch.long).to(device) elif case == 1: - add_input = torch.ones((input_shape[0], 1, self.hidden_size), dtype=torch.long).to(device) + one_input = torch.ones((input_shape[0], 1, self.hidden_size), dtype=torch.long).to(device) elif case == 2: - add_input = torch.ones(input_shape, dtype=torch.long).to(device) - add_input = add_input.unsqueeze(-1).expand(-1, -1, self.hidden_size) + one_input = torch.ones(input_shape, dtype=torch.long).to(device) + one_input = one_input.unsqueeze(-1).expand(-1, -1, self.hidden_size) inputs_embeds = self.word_embeddings(input_ids) - output = add_input + inputs_embeds + if test_op == "Add": + output = one_input + inputs_embeds + elif test_op == "Sub": + output = one_input - inputs_embeds + elif test_op == "Mul": + output = one_input * inputs_embeds + else: + output = None return output # test MatMul op for padding elimination @@ -5824,7 +5839,7 @@ def test_ops_for_padding_elimination(test_cases): output = None if test_op == "Dropout": output = torch.nn.functional.dropout(inputs_embeds, p=0.5, training=True) - elif test_op == "LayerNorm": + elif test_op == "LayerNormalization": output = self.LayerNorm(inputs_embeds) elif test_op == "Cast": output = inputs_embeds.to(torch.float16) @@ -5834,8 +5849,8 @@ def test_ops_for_padding_elimination(test_cases): return output def forward(self, input_ids): - if test_op == "Add": - output = self.test_add(input_ids) + if test_op in ["Add", "Mul", "Sub"]: + output = self.test_elementwise(input_ids) elif test_op == "MatMul": output = self.test_matmul(input_ids) else: @@ -5865,7 +5880,10 @@ def test_ops_for_padding_elimination(test_cases): model(x) training_model = model._torch_module._execution_manager(True)._onnx_models.optimized_model - assert len([node.op_type for node in training_model.graph.node if node.op_type == "Sub"]) == 1 + if test_op == "Sub": + assert len([node.op_type for node in training_model.graph.node if node.op_type == "Sub"]) == 2 + else: + assert len([node.op_type for node in training_model.graph.node if node.op_type == "Sub"]) == 1 assert len([node.op_type for node in training_model.graph.node if node.op_type == "NonZero"]) == 1 assert len([node.op_type for node in training_model.graph.node if node.op_type == "Squeeze"]) == 1 assert len([node.op_type for node in training_model.graph.node if node.op_type == "GatherGrad"]) == 1 @@ -5874,25 +5892,27 @@ def test_ops_for_padding_elimination(test_cases): else: assert len([node.op_type for node in training_model.graph.node if node.op_type == "ShrunkenGather"]) == 1 gathergrad_node = [node for node in training_model.graph.node if node.op_type == "GatherGrad"][0] - if test_op == "Add": + + def find_input_node_type(model, arg): + result = [] + for node in model.graph.node: + if arg in node.output: + result.append(node) + return result[0].op_type if len(result) == 1 else None + + gathergrad_input_optypes = [find_input_node_type(training_model, arg) for arg in gathergrad_node.input] + if test_op == "Add" or test_op == "Mul" or test_op == "Sub": if case == 0: - assert "/_original_module/Add_output_0" in gathergrad_node.input + assert test_op in gathergrad_input_optypes elif case == 1: - assert "/_original_module/word_embeddings/ATen_output_0" in gathergrad_node.input + assert "ATen" in gathergrad_input_optypes elif test_op == "MatMul": if case == 0: - assert "/_original_module/word_embeddings/ATen_output_0" in gathergrad_node.input + assert "ATen" in gathergrad_input_optypes elif case == 1: - assert "/_original_module/MatMul_output_0" in gathergrad_node.input + assert "MatMul" in gathergrad_input_optypes else: - if test_op == "Dropout": - assert "/_original_module/Dropout_output_0" in gathergrad_node.input - elif test_op == "LayerNorm": - assert "/_original_module/LayerNorm/Add_1_output_0" in gathergrad_node.input - elif test_op == "Cast": - assert "/_original_module/Cast_output_0" in gathergrad_node.input - elif test_op == "BiasGelu": - assert "/_original_module/Mul_1_output_0" in gathergrad_node.input + assert test_op in gathergrad_input_optypes del os.environ["ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER"]