Support Mul and Sub in padding elimination (#16478)

### Description
Support Mul and Sub in padding elimination
This commit is contained in:
guyang3532 2023-06-27 07:43:29 +08:00 committed by GitHub
parent 4a331ef667
commit eb4e6d2062
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 59 additions and 37 deletions

View file

@ -292,7 +292,9 @@ void IterateSubgraphFromNode(Graph& graph,
to_visit.pop();
visited.insert(cur);
if (graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "Add", {7, 13, 14}) ||
graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "BiasGelu", {1}, kMSDomain)) {
graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "BiasGelu", {1}, kMSDomain) ||
graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "Sub", {7, 13, 14}) ||
graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "Mul", {7, 13, 14})) {
ORT_ENFORCE(subgraph.find(cur->MutableInputDefs()[0]) != subgraph.end() ||
subgraph.find(cur->MutableInputDefs()[1]) != subgraph.end());
NodeArg* arg_in_subgraph = nullptr;

View file

@ -5752,10 +5752,16 @@ def test_runtime_inspector_label_and_embed_sparsity_detection(embed_is_sparse, l
("Add", 0),
("Add", 1),
("Add", 2),
("Sub", 0),
("Sub", 1),
("Sub", 2),
("Mul", 0),
("Mul", 1),
("Mul", 2),
("MatMul", 0),
("MatMul", 1),
("Dropout", 0),
("LayerNorm", 0),
("LayerNormalization", 0),
("Cast", 0),
("BiasGelu", 0),
],
@ -5764,38 +5770,47 @@ def test_ops_for_padding_elimination(test_cases):
os.environ["ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER"] = "1"
test_op = test_cases[0]
case = test_cases[1]
# test_op = "Sub"
# case = 2
class ToyModel(torch.nn.Module):
def __init__(self, vocab_size, hidden_size, pad_token_id):
super().__init__()
self.word_embeddings = nn.Embedding(vocab_size, hidden_size, padding_idx=pad_token_id)
if test_op == "LayerNorm":
if test_op == "LayerNormalization":
self.LayerNorm = nn.LayerNorm(hidden_size, eps=1e-05)
self.hidden_size = hidden_size
# test Add op for padding elimination
# in case 0, the shapes of inputs of Add are [batch_size, seqlen, hidden_size] and [hidden_size],
# the Add should be included in padding elimination subgraph and the GatherGrad should be added to
# output of Add.
# in case 1, the shapes of inputs of Add are [batch_size, seqlen, hidden_size] and [batch_size, 1, hidden_size],
# this case is not support in padding elimination, so the Add should not be included in padding
# elimination subgraph and the GatherGrad should be added before Add.
# in case 2, the shapes of inputs of Add are [batch_size, seqlen, hidden_size] and [batch_size, seqlen, hidden_size],
# the Add should be included in padding elimination subgraph and the GatherGrad should be added to
# output of Add. Besides, the other input of Add should be added 'Reshape + ShrunkenGather' to
# test elementwise ops (Add, Sub, Mul) for padding elimination
# in case 0, the shapes of inputs of test_op are [batch_size, seqlen, hidden_size] and [hidden_size],
# the test_op should be included in padding elimination subgraph and the GatherGrad should be added to
# output of test_op.
# in case 1, the shapes of inputs of test_op are [batch_size, seqlen, hidden_size] and [batch_size, 1, hidden_size],
# this case is not supported in padding elimination, so the test_op should not be included in padding
# elimination subgraph and the GatherGrad should be added before test_op.
# in case 2, the shapes of inputs of test_op are [batch_size, seqlen, hidden_size] and [batch_size, seqlen, hidden_size],
# the test_op should be included in padding elimination subgraph and the GatherGrad should be added to
# output of test_op. Besides, 'Reshape + ShrunkenGather' should be added to the other input of test_op to
# flatten it and eliminate padding.
def test_add(self, input_ids):
def test_elementwise(self, input_ids):
input_shape = input_ids.size()
add_input = None
one_input = None
if case == 0:
add_input = torch.ones(self.hidden_size, dtype=torch.long).to(device)
one_input = torch.ones(self.hidden_size, dtype=torch.long).to(device)
elif case == 1:
add_input = torch.ones((input_shape[0], 1, self.hidden_size), dtype=torch.long).to(device)
one_input = torch.ones((input_shape[0], 1, self.hidden_size), dtype=torch.long).to(device)
elif case == 2:
add_input = torch.ones(input_shape, dtype=torch.long).to(device)
add_input = add_input.unsqueeze(-1).expand(-1, -1, self.hidden_size)
one_input = torch.ones(input_shape, dtype=torch.long).to(device)
one_input = one_input.unsqueeze(-1).expand(-1, -1, self.hidden_size)
inputs_embeds = self.word_embeddings(input_ids)
output = add_input + inputs_embeds
if test_op == "Add":
output = one_input + inputs_embeds
elif test_op == "Sub":
output = one_input - inputs_embeds
elif test_op == "Mul":
output = one_input * inputs_embeds
else:
output = None
return output
# test MatMul op for padding elimination
@ -5824,7 +5839,7 @@ def test_ops_for_padding_elimination(test_cases):
output = None
if test_op == "Dropout":
output = torch.nn.functional.dropout(inputs_embeds, p=0.5, training=True)
elif test_op == "LayerNorm":
elif test_op == "LayerNormalization":
output = self.LayerNorm(inputs_embeds)
elif test_op == "Cast":
output = inputs_embeds.to(torch.float16)
@ -5834,8 +5849,8 @@ def test_ops_for_padding_elimination(test_cases):
return output
def forward(self, input_ids):
if test_op == "Add":
output = self.test_add(input_ids)
if test_op in ["Add", "Mul", "Sub"]:
output = self.test_elementwise(input_ids)
elif test_op == "MatMul":
output = self.test_matmul(input_ids)
else:
@ -5865,7 +5880,10 @@ def test_ops_for_padding_elimination(test_cases):
model(x)
training_model = model._torch_module._execution_manager(True)._onnx_models.optimized_model
assert len([node.op_type for node in training_model.graph.node if node.op_type == "Sub"]) == 1
if test_op == "Sub":
assert len([node.op_type for node in training_model.graph.node if node.op_type == "Sub"]) == 2
else:
assert len([node.op_type for node in training_model.graph.node if node.op_type == "Sub"]) == 1
assert len([node.op_type for node in training_model.graph.node if node.op_type == "NonZero"]) == 1
assert len([node.op_type for node in training_model.graph.node if node.op_type == "Squeeze"]) == 1
assert len([node.op_type for node in training_model.graph.node if node.op_type == "GatherGrad"]) == 1
@ -5874,25 +5892,27 @@ def test_ops_for_padding_elimination(test_cases):
else:
assert len([node.op_type for node in training_model.graph.node if node.op_type == "ShrunkenGather"]) == 1
gathergrad_node = [node for node in training_model.graph.node if node.op_type == "GatherGrad"][0]
if test_op == "Add":
# Return the op_type of the node that produces `arg`.
# Scans every node in model.graph.node and collects those that list `arg` among their outputs.
# The op_type is returned only when exactly one such node is found; if there is no producer,
# or more than one, the result is None.
def find_input_node_type(model, arg):
result = []
for node in model.graph.node:
if arg in node.output:
result.append(node)
# Exactly one producer is required; any other count maps to None.
return result[0].op_type if len(result) == 1 else None
gathergrad_input_optypes = [find_input_node_type(training_model, arg) for arg in gathergrad_node.input]
if test_op == "Add" or test_op == "Mul" or test_op == "Sub":
if case == 0:
assert "/_original_module/Add_output_0" in gathergrad_node.input
assert test_op in gathergrad_input_optypes
elif case == 1:
assert "/_original_module/word_embeddings/ATen_output_0" in gathergrad_node.input
assert "ATen" in gathergrad_input_optypes
elif test_op == "MatMul":
if case == 0:
assert "/_original_module/word_embeddings/ATen_output_0" in gathergrad_node.input
assert "ATen" in gathergrad_input_optypes
elif case == 1:
assert "/_original_module/MatMul_output_0" in gathergrad_node.input
assert "MatMul" in gathergrad_input_optypes
else:
if test_op == "Dropout":
assert "/_original_module/Dropout_output_0" in gathergrad_node.input
elif test_op == "LayerNorm":
assert "/_original_module/LayerNorm/Add_1_output_0" in gathergrad_node.input
elif test_op == "Cast":
assert "/_original_module/Cast_output_0" in gathergrad_node.input
elif test_op == "BiasGelu":
assert "/_original_module/Mul_1_output_0" in gathergrad_node.input
assert test_op in gathergrad_input_optypes
del os.environ["ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER"]