mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
Support Mul and Sub in padding elimination (#16478)
### Description Support Mul and Sub in padding elimination
This commit is contained in:
parent
4a331ef667
commit
eb4e6d2062
2 changed files with 59 additions and 37 deletions
|
|
@ -292,7 +292,9 @@ void IterateSubgraphFromNode(Graph& graph,
|
|||
to_visit.pop();
|
||||
visited.insert(cur);
|
||||
if (graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "Add", {7, 13, 14}) ||
|
||||
graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "BiasGelu", {1}, kMSDomain)) {
|
||||
graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "BiasGelu", {1}, kMSDomain) ||
|
||||
graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "Sub", {7, 13, 14}) ||
|
||||
graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "Mul", {7, 13, 14})) {
|
||||
ORT_ENFORCE(subgraph.find(cur->MutableInputDefs()[0]) != subgraph.end() ||
|
||||
subgraph.find(cur->MutableInputDefs()[1]) != subgraph.end());
|
||||
NodeArg* arg_in_subgraph = nullptr;
|
||||
|
|
|
|||
|
|
@ -5752,10 +5752,16 @@ def test_runtime_inspector_label_and_embed_sparsity_detection(embed_is_sparse, l
|
|||
("Add", 0),
|
||||
("Add", 1),
|
||||
("Add", 2),
|
||||
("Sub", 0),
|
||||
("Sub", 1),
|
||||
("Sub", 2),
|
||||
("Mul", 0),
|
||||
("Mul", 1),
|
||||
("Mul", 2),
|
||||
("MatMul", 0),
|
||||
("MatMul", 1),
|
||||
("Dropout", 0),
|
||||
("LayerNorm", 0),
|
||||
("LayerNormalization", 0),
|
||||
("Cast", 0),
|
||||
("BiasGelu", 0),
|
||||
],
|
||||
|
|
@ -5764,38 +5770,47 @@ def test_ops_for_padding_elimination(test_cases):
|
|||
os.environ["ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER"] = "1"
|
||||
test_op = test_cases[0]
|
||||
case = test_cases[1]
|
||||
# test_op = "Sub"
|
||||
# case = 2
|
||||
|
||||
class ToyModel(torch.nn.Module):
|
||||
def __init__(self, vocab_size, hidden_size, pad_token_id):
|
||||
super().__init__()
|
||||
self.word_embeddings = nn.Embedding(vocab_size, hidden_size, padding_idx=pad_token_id)
|
||||
if test_op == "LayerNorm":
|
||||
if test_op == "LayerNormalization":
|
||||
self.LayerNorm = nn.LayerNorm(hidden_size, eps=1e-05)
|
||||
self.hidden_size = hidden_size
|
||||
|
||||
# test Add op for padding elimination
|
||||
# in case 0, the shapes of inputs of Add are [batch_size, seqlen, hidden_size] and [hidden_size],
|
||||
# the Add should be included in padding elimination subgraph and the GatherGrad should be added to
|
||||
# output of Add.
|
||||
# in case 1, the shapes of inputs of Add are [batch_size, seqlen, hidden_size] and [batch_size, 1, hidden_size],
|
||||
# this case is not support in padding elimination, so the Add should not be included in padding
|
||||
# elimination subgraph and the GatherGrad should be added before Add.
|
||||
# in case 2, the shapes of inputs of Add are [batch_size, seqlen, hidden_size] and [batch_size, seqlen, hidden_size],
|
||||
# the Add should be included in padding elimination subgraph and the GatherGrad should be added to
|
||||
# output of Add. Besides, the other input of Add should be added 'Reshape + ShrunkenGather' to
|
||||
# test test_elementwise op for padding elimination
|
||||
# in case 0, the shapes of inputs of test_op are [batch_size, seqlen, hidden_size] and [hidden_size],
|
||||
# the test_op should be included in padding elimination subgraph and the GatherGrad should be added to
|
||||
# output of test_op.
|
||||
# in case 1, the shapes of inputs of test_op are [batch_size, seqlen, hidden_size] and [batch_size, 1, hidden_size],
|
||||
# this case is not support in padding elimination, so the test_op should not be included in padding
|
||||
# elimination subgraph and the GatherGrad should be added before test_op.
|
||||
# in case 2, the shapes of inputs of test_op are [batch_size, seqlen, hidden_size] and [batch_size, seqlen, hidden_size],
|
||||
# the test_op should be included in padding elimination subgraph and the GatherGrad should be added to
|
||||
# output of test_op. Besides, the other input of Add should be added 'Reshape + ShrunkenGather' to
|
||||
# flatten and elimination padding.
|
||||
def test_add(self, input_ids):
|
||||
def test_elementwise(self, input_ids):
|
||||
input_shape = input_ids.size()
|
||||
add_input = None
|
||||
one_input = None
|
||||
if case == 0:
|
||||
add_input = torch.ones(self.hidden_size, dtype=torch.long).to(device)
|
||||
one_input = torch.ones(self.hidden_size, dtype=torch.long).to(device)
|
||||
elif case == 1:
|
||||
add_input = torch.ones((input_shape[0], 1, self.hidden_size), dtype=torch.long).to(device)
|
||||
one_input = torch.ones((input_shape[0], 1, self.hidden_size), dtype=torch.long).to(device)
|
||||
elif case == 2:
|
||||
add_input = torch.ones(input_shape, dtype=torch.long).to(device)
|
||||
add_input = add_input.unsqueeze(-1).expand(-1, -1, self.hidden_size)
|
||||
one_input = torch.ones(input_shape, dtype=torch.long).to(device)
|
||||
one_input = one_input.unsqueeze(-1).expand(-1, -1, self.hidden_size)
|
||||
inputs_embeds = self.word_embeddings(input_ids)
|
||||
output = add_input + inputs_embeds
|
||||
if test_op == "Add":
|
||||
output = one_input + inputs_embeds
|
||||
elif test_op == "Sub":
|
||||
output = one_input - inputs_embeds
|
||||
elif test_op == "Mul":
|
||||
output = one_input * inputs_embeds
|
||||
else:
|
||||
output = None
|
||||
return output
|
||||
|
||||
# test MatMul op for padding elimination
|
||||
|
|
@ -5824,7 +5839,7 @@ def test_ops_for_padding_elimination(test_cases):
|
|||
output = None
|
||||
if test_op == "Dropout":
|
||||
output = torch.nn.functional.dropout(inputs_embeds, p=0.5, training=True)
|
||||
elif test_op == "LayerNorm":
|
||||
elif test_op == "LayerNormalization":
|
||||
output = self.LayerNorm(inputs_embeds)
|
||||
elif test_op == "Cast":
|
||||
output = inputs_embeds.to(torch.float16)
|
||||
|
|
@ -5834,8 +5849,8 @@ def test_ops_for_padding_elimination(test_cases):
|
|||
return output
|
||||
|
||||
def forward(self, input_ids):
|
||||
if test_op == "Add":
|
||||
output = self.test_add(input_ids)
|
||||
if test_op in ["Add", "Mul", "Sub"]:
|
||||
output = self.test_elementwise(input_ids)
|
||||
elif test_op == "MatMul":
|
||||
output = self.test_matmul(input_ids)
|
||||
else:
|
||||
|
|
@ -5865,7 +5880,10 @@ def test_ops_for_padding_elimination(test_cases):
|
|||
model(x)
|
||||
|
||||
training_model = model._torch_module._execution_manager(True)._onnx_models.optimized_model
|
||||
assert len([node.op_type for node in training_model.graph.node if node.op_type == "Sub"]) == 1
|
||||
if test_op == "Sub":
|
||||
assert len([node.op_type for node in training_model.graph.node if node.op_type == "Sub"]) == 2
|
||||
else:
|
||||
assert len([node.op_type for node in training_model.graph.node if node.op_type == "Sub"]) == 1
|
||||
assert len([node.op_type for node in training_model.graph.node if node.op_type == "NonZero"]) == 1
|
||||
assert len([node.op_type for node in training_model.graph.node if node.op_type == "Squeeze"]) == 1
|
||||
assert len([node.op_type for node in training_model.graph.node if node.op_type == "GatherGrad"]) == 1
|
||||
|
|
@ -5874,25 +5892,27 @@ def test_ops_for_padding_elimination(test_cases):
|
|||
else:
|
||||
assert len([node.op_type for node in training_model.graph.node if node.op_type == "ShrunkenGather"]) == 1
|
||||
gathergrad_node = [node for node in training_model.graph.node if node.op_type == "GatherGrad"][0]
|
||||
if test_op == "Add":
|
||||
|
||||
def find_input_node_type(model, arg):
|
||||
result = []
|
||||
for node in model.graph.node:
|
||||
if arg in node.output:
|
||||
result.append(node)
|
||||
return result[0].op_type if len(result) == 1 else None
|
||||
|
||||
gathergrad_input_optypes = [find_input_node_type(training_model, arg) for arg in gathergrad_node.input]
|
||||
if test_op == "Add" or test_op == "Mul" or test_op == "Sub":
|
||||
if case == 0:
|
||||
assert "/_original_module/Add_output_0" in gathergrad_node.input
|
||||
assert test_op in gathergrad_input_optypes
|
||||
elif case == 1:
|
||||
assert "/_original_module/word_embeddings/ATen_output_0" in gathergrad_node.input
|
||||
assert "ATen" in gathergrad_input_optypes
|
||||
elif test_op == "MatMul":
|
||||
if case == 0:
|
||||
assert "/_original_module/word_embeddings/ATen_output_0" in gathergrad_node.input
|
||||
assert "ATen" in gathergrad_input_optypes
|
||||
elif case == 1:
|
||||
assert "/_original_module/MatMul_output_0" in gathergrad_node.input
|
||||
assert "MatMul" in gathergrad_input_optypes
|
||||
else:
|
||||
if test_op == "Dropout":
|
||||
assert "/_original_module/Dropout_output_0" in gathergrad_node.input
|
||||
elif test_op == "LayerNorm":
|
||||
assert "/_original_module/LayerNorm/Add_1_output_0" in gathergrad_node.input
|
||||
elif test_op == "Cast":
|
||||
assert "/_original_module/Cast_output_0" in gathergrad_node.input
|
||||
elif test_op == "BiasGelu":
|
||||
assert "/_original_module/Mul_1_output_0" in gathergrad_node.input
|
||||
assert test_op in gathergrad_input_optypes
|
||||
|
||||
del os.environ["ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER"]
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue