Add fusions for OpenAI CLIP (#20721)

### Description
This PR adds fusions for [OpenAI's CLIP
model](https://huggingface.co/openai/clip-vit-large-patch14-336). Here
is an example of how to run the ORT transformer optimizer for the linked
CLIP model.

```
$ git clone https://github.com/microsoft/onnxruntime
$ cd onnxruntime/onnxruntime/python/tools/transformers
$ python3 optimizer.py --input /path/to/model.onnx --output /path/to/model_opt.onnx --model_type clip --num_heads 16 --hidden_size 1024 --use_external_data_format --opt_level 0
```

### Motivation and Context
This PR helps optimize multi-modal models that use CLIP for the vision
encoder.
This commit is contained in:
kunal-vaishnavi 2024-05-18 08:27:16 -07:00 committed by GitHub
parent 5d07291247
commit ca22a5a9d0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 167 additions and 31 deletions

View file

@ -97,7 +97,33 @@ class FusionAttentionClip(FusionAttention):
else:
# Deal with the first attention after the embedding layer.
for i in [0, 1]:
node_before_layer_norm = self.model.match_parent(normalize_node, "Add", i)
node_before_layer_norm = None
node_before_layer_norm_1 = self.model.match_parent(normalize_node, "Add", i)
node_before_layer_norm_2 = self.model.match_parent(normalize_node, "LayerNormalization", i)
if node_before_layer_norm_1 is not None:
# Add -----------+
# | |
# LayerNorm |
# | |
# LayerNorm |
# | |
# Attention subgraph |
# | |
# SkipLayerNorm ------+
node_before_layer_norm = node_before_layer_norm_1
elif node_before_layer_norm_2 is not None:
# Add
# |
# LayerNorm --------+
# | |
# LayerNorm |
# | |
# Attention subgraph |
# | |
# SkipLayerNorm ------+
node_before_layer_norm = node_before_layer_norm_2
if node_before_layer_norm is None:
continue
child = self.model.find_first_child_by_type(
@ -130,20 +156,32 @@ class FusionAttentionClip(FusionAttention):
return
(_, _, reshape_v, add_v, matmul_v) = v_nodes
add_mask = None
add_mask_indices = []
qk_nodes = self.model.match_parent_path(
qk_nodes = None
qk_nodes_1 = self.model.match_parent_path(
matmul_qkv,
["Softmax", "Reshape", "Add", "Reshape", "MatMul"],
[0, 0, 0, None, 0],
return_indice=add_mask_indices,
)
if qk_nodes is None:
qk_nodes_2 = self.model.match_parent_path(
matmul_qkv,
["Softmax", "MatMul"],
[0, 0],
)
if qk_nodes_1 is not None:
qk_nodes = qk_nodes_1
assert len(add_mask_indices) == 1
causal_mask_input_index = 1 - add_mask_indices[0]
(_softmax_qk, _, add_mask, _, matmul_qk) = qk_nodes
elif qk_nodes_2 is not None:
qk_nodes = qk_nodes_2
(_softmax_qk, matmul_qk) = qk_nodes
else:
logger.debug("fuse_attention: failed to match qk path")
return
assert len(add_mask_indices) == 1
causal_mask_input_index = 1 - add_mask_indices[0]
(_softmax_qk, _, add_mask, _, matmul_qk) = qk_nodes
q_nodes = self.model.match_parent_path(
matmul_qk, ["Reshape", "Transpose", "Reshape", "Mul", "Add", "MatMul"], [0, 0, 0, 0, None, None]
@ -172,23 +210,24 @@ class FusionAttentionClip(FusionAttention):
attention_last_node = reshape_qkv
# Here we do not match the whole subgraph because it is very complex. Instead, we just check a key path
# of the causal-mask computation.
causal_mask_nodes = self.model.match_parent_path(
add_mask,
["Concat", "Expand", "Unsqueeze", "Unsqueeze", "Where", "Less"],
[causal_mask_input_index, 0, 0, 0, 0, 0],
)
if causal_mask_nodes is None:
# If the model is exported with batch_size == 1, there is no Concat node
if add_mask is not None:
# Here we do not match the whole subgraph because it is very complex. Instead, we just check a key path
# of the causal-mask computation.
causal_mask_nodes = self.model.match_parent_path(
add_mask,
["Expand", "Unsqueeze", "Unsqueeze", "Where", "Less"],
[causal_mask_input_index, 0, 0, 0, 0],
["Concat", "Expand", "Unsqueeze", "Unsqueeze", "Where", "Less"],
[causal_mask_input_index, 0, 0, 0, 0, 0],
)
if causal_mask_nodes is None:
logger.debug("fuse_attention: failed to match causal mask subgraph")
return
# If the model is exported with batch_size == 1, there is no Concat node
causal_mask_nodes = self.model.match_parent_path(
add_mask,
["Expand", "Unsqueeze", "Unsqueeze", "Where", "Less"],
[causal_mask_input_index, 0, 0, 0, 0],
)
if causal_mask_nodes is None:
logger.debug("fuse_attention: failed to match causal mask subgraph")
return
new_node = self.create_attention_node(
mask_index=None,
@ -204,7 +243,7 @@ class FusionAttentionClip(FusionAttention):
output=attention_last_node.output[0],
add_qk_str=None,
scale=None,
causal=True,
causal=(add_mask is not None),
)
if new_node is None:
return

View file

@ -38,6 +38,7 @@ class FusionLayerNormalization(Fusion):
| |
+----------------------+
"""
subgraph_nodes = []
children = self.model.get_children(node, input_name_to_nodes)
if len(children) == 0 or len(children) > 2:
return
@ -53,9 +54,16 @@ class FusionLayerNormalization(Fusion):
div_node = None
for child in children:
div_node = self.model.find_first_child_by_type(child, "Div", input_name_to_nodes, recursive=False)
if div_node is not None:
break
# Check if Sub --> Div exists
div_node_1 = self.model.find_first_child_by_type(child, "Div", input_name_to_nodes, recursive=False)
# Check if Sub --> Cast --> Div
div_node_2 = self.model.match_child_path(child, ["Cast", "Div"], exclude=[])
if div_node_1 is not None:
div_node = div_node_1
elif div_node_2 is not None:
div_node = div_node_2[-1]
if div_node is None:
return
@ -63,10 +71,7 @@ class FusionLayerNormalization(Fusion):
div_node,
[
(["Sqrt", "Add", "ReduceMean", "Pow", "Sub"], [1, 0, 0, 0, 0]),
(
["Sqrt", "Add", "ReduceMean", "Pow", "Cast", "Sub"],
[1, 0, 0, 0, 0, 0],
),
(["Sqrt", "Add", "ReduceMean", "Pow", "Cast", "Sub"], [1, 0, 0, 0, 0, 0]),
],
output_name_to_node,
)
@ -87,7 +92,14 @@ class FusionLayerNormalization(Fusion):
if self.model.find_constant_input(pow_node, 2.0) != 1:
return
mul_node = input_name_to_nodes[div_node.output[0]][0]
temp_node = input_name_to_nodes[div_node.output[0]][0]
if temp_node.op_type == "Cast":
# Div --> Cast --> Mul
subgraph_nodes.append(temp_node) # add Cast node to list of subgraph nodes
mul_node = input_name_to_nodes[temp_node.output[0]][0]
else:
# Div --> Mul
mul_node = temp_node
if mul_node.op_type != "Mul":
return
@ -95,7 +107,7 @@ class FusionLayerNormalization(Fusion):
if last_add_node.op_type != "Add":
return
subgraph_nodes = [node]
subgraph_nodes.append(node)
subgraph_nodes.extend(children)
subgraph_nodes.extend(parent_nodes[:-1])
@ -109,7 +121,8 @@ class FusionLayerNormalization(Fusion):
logger.debug("It is not safe to fuse LayerNormalization node. Skip")
return
weight_input = mul_node.input[1 - self.model.input_index(div_node.output[0], mul_node)]
node_before_weight = div_node if temp_node.op_type != "Cast" else temp_node
weight_input = mul_node.input[1 - self.model.input_index(node_before_weight.output[0], mul_node)]
if not self.model.is_constant_with_specified_dimension(weight_input, 1, "layernorm weight"):
return

View file

@ -0,0 +1,74 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging
from fusion_base import Fusion
from onnx import helper
from onnx_model import OnnxModel
logger = logging.getLogger(__name__)
class FusionQuickGelu(Fusion):
    """Fuse the subgraph `x * Sigmoid(alpha * x)` (alpha ~= 1.702) into a single
    `com.microsoft.QuickGelu` node, as used by OpenAI CLIP's activation."""

    def __init__(self, model: OnnxModel):
        super().__init__(model, "QuickGelu", ["Mul"])

    def fuse(self, node, input_name_to_nodes, output_name_to_node):
        # Fuse the following subgraph to `QuickGelu`
        #
        #    root_input
        #     /       \
        #    |         Mul ----+
        #    |    (B = ~1.702) |
        #     \        |       |
        #      \    Sigmoid    |---- `QuickGelu`
        #       \     /        |
        #        \   /         |
        #         Mul ---------+
        #          |
        #     root_output
        if node.op_type != "Mul":
            logger.debug("fuse_quickgelu: failed to match second Mul node")
            return

        second_mul_node = node
        root_input = second_mul_node.input[0]

        # The second Mul's input 1 must be produced by a Sigmoid.
        sigmoid_path = self.model.match_parent_path(second_mul_node, ["Sigmoid"], [1])
        if sigmoid_path is None:
            logger.debug("fuse_quickgelu: failed to match Sigmoid node")
            return
        sigmoid_node = sigmoid_path[0]

        # The Sigmoid's input must be produced by the first (scaling) Mul.
        first_mul_path = self.model.match_parent_path(sigmoid_node, ["Mul"], [0])
        if first_mul_path is None:
            logger.debug("fuse_quickgelu: failed to match first Mul node")
            return
        first_mul_node = first_mul_path[0]

        # Bug fix: get_constant_value returns None when the input is not a constant
        # initializer; calling .item() on it would raise AttributeError and abort
        # optimization of any graph containing a non-matching Mul->Sigmoid->Mul chain.
        constant_value = self.model.get_constant_value(first_mul_node.input[1])
        if constant_value is None:
            logger.debug("fuse_quickgelu: approximation constant is not an initializer")
            return
        approximation_value = constant_value.item()

        # 1.7021484375 is the fp16-representable value of CLIP's 1.702 constant.
        if abs(approximation_value - 1.7021484375) >= 1e-3:
            logger.debug("fuse_quickgelu: failed to match approximation value")
            return

        # Both Mul nodes must be fed by the same root input.
        if first_mul_node.input[0] != root_input:
            logger.debug("fuse_quickgelu: failed to match root input with first Mul node's input")
            return

        new_node = helper.make_node(
            "QuickGelu",
            inputs=[root_input],
            outputs=[second_mul_node.output[0]],
            name=self.model.create_node_name("QuickGelu"),
        )
        new_node.domain = "com.microsoft"
        new_node.attribute.extend([helper.make_attribute("alpha", approximation_value)])

        self.nodes_to_remove.extend([first_mul_node, sigmoid_node, second_mul_node])
        self.nodes_to_add.append(new_node)
        self.node_name_to_graph_name[new_node.name] = self.this_graph_name
        self.increase_counter("QuickGelu")

View file

@ -21,6 +21,7 @@ from fusion_qordered_attention import FusionQOrderedAttention
from fusion_qordered_gelu import FusionQOrderedGelu
from fusion_qordered_layernorm import FusionQOrderedLayerNormalization
from fusion_qordered_matmul import FusionQOrderedMatMul
from fusion_quickgelu import FusionQuickGelu
from fusion_reshape import FusionReshape
from fusion_rotary_attention import FusionRotaryEmbeddings
from fusion_shape import FusionShape
@ -65,6 +66,8 @@ class BertOnnxModel(OnnxModel):
fusion.apply()
fusion = FusionFastGelu(self)
fusion.apply()
fusion = FusionQuickGelu(self)
fusion.apply()
# Only relevant in models with Q-DQ nodes
fusion = FusionQOrderedGelu(self)
fusion.apply()

View file

@ -25,6 +25,7 @@ class ClipOnnxModel(BertOnnxModel):
ops = [
"Attention",
"LayerNormalization",
"QuickGelu",
"SkipLayerNormalization",
]
for op in ops:

View file

@ -21,6 +21,11 @@ class HuggingfaceFastGelu(torch.nn.Module):
return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))
class HuggingfaceQuickGelu(torch.nn.Module):
    """QuickGELU activation as used by Hugging Face CLIP: x * sigmoid(1.702 * x)."""

    def forward(self, x):
        # Scale, squash through a sigmoid, then gate the original input.
        gate = torch.sigmoid(1.702 * x)
        return x * gate
class MegatronGelu(torch.nn.Module):
def forward(self, x):
# The original implementation using ones_like, which might cause problem for input with dynamic axes in onnx.
@ -36,6 +41,7 @@ class MegatronFastGelu(torch.nn.Module):
test_cases = [
("huggingface", "Gelu", HuggingfaceGelu),
("huggingface", "FastGelu", HuggingfaceFastGelu),
("huggingface", "QuickGelu", HuggingfaceQuickGelu),
("megatron", "Gelu", MegatronGelu),
("megatron", "FastGelu", MegatronFastGelu),
]