mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-16 21:00:14 +00:00
Add fusions for OpenAI CLIP (#20721)
### Description

This PR adds fusions for [OpenAI's CLIP model](https://huggingface.co/openai/clip-vit-large-patch14-336). Here is an example of how to run the ORT transformer optimizer for the linked CLIP model.

```
$ git clone https://github.com/microsoft/onnxruntime
$ cd onnxruntime/onnxruntime/python/tools/transformers
$ python3 optimizer.py --input /path/to/model.onnx --output /path/to/model_opt.onnx --model_type clip --num_heads 16 --hidden_size 1024 --use_external_data_format --opt_level 0
```

### Motivation and Context

This PR helps optimize multi-modal models that use CLIP for the vision encoder.
This commit is contained in:
parent
5d07291247
commit
ca22a5a9d0
6 changed files with 167 additions and 31 deletions
|
|
@ -97,7 +97,33 @@ class FusionAttentionClip(FusionAttention):
|
|||
else:
|
||||
# Deal with the first attention after the embedding layer.
|
||||
for i in [0, 1]:
|
||||
node_before_layer_norm = self.model.match_parent(normalize_node, "Add", i)
|
||||
node_before_layer_norm = None
|
||||
|
||||
node_before_layer_norm_1 = self.model.match_parent(normalize_node, "Add", i)
|
||||
node_before_layer_norm_2 = self.model.match_parent(normalize_node, "LayerNormalization", i)
|
||||
if node_before_layer_norm_1 is not None:
|
||||
# Add -----------+
|
||||
# | |
|
||||
# LayerNorm |
|
||||
# | |
|
||||
# LayerNorm |
|
||||
# | |
|
||||
# Attention subgraph |
|
||||
# | |
|
||||
# SkipLayerNorm ------+
|
||||
node_before_layer_norm = node_before_layer_norm_1
|
||||
elif node_before_layer_norm_2 is not None:
|
||||
# Add
|
||||
# |
|
||||
# LayerNorm --------+
|
||||
# | |
|
||||
# LayerNorm |
|
||||
# | |
|
||||
# Attention subgraph |
|
||||
# | |
|
||||
# SkipLayerNorm ------+
|
||||
node_before_layer_norm = node_before_layer_norm_2
|
||||
|
||||
if node_before_layer_norm is None:
|
||||
continue
|
||||
child = self.model.find_first_child_by_type(
|
||||
|
|
@ -130,20 +156,32 @@ class FusionAttentionClip(FusionAttention):
|
|||
return
|
||||
(_, _, reshape_v, add_v, matmul_v) = v_nodes
|
||||
|
||||
add_mask = None
|
||||
add_mask_indices = []
|
||||
qk_nodes = self.model.match_parent_path(
|
||||
qk_nodes = None
|
||||
qk_nodes_1 = self.model.match_parent_path(
|
||||
matmul_qkv,
|
||||
["Softmax", "Reshape", "Add", "Reshape", "MatMul"],
|
||||
[0, 0, 0, None, 0],
|
||||
return_indice=add_mask_indices,
|
||||
)
|
||||
if qk_nodes is None:
|
||||
qk_nodes_2 = self.model.match_parent_path(
|
||||
matmul_qkv,
|
||||
["Softmax", "MatMul"],
|
||||
[0, 0],
|
||||
)
|
||||
if qk_nodes_1 is not None:
|
||||
qk_nodes = qk_nodes_1
|
||||
assert len(add_mask_indices) == 1
|
||||
causal_mask_input_index = 1 - add_mask_indices[0]
|
||||
|
||||
(_softmax_qk, _, add_mask, _, matmul_qk) = qk_nodes
|
||||
elif qk_nodes_2 is not None:
|
||||
qk_nodes = qk_nodes_2
|
||||
(_softmax_qk, matmul_qk) = qk_nodes
|
||||
else:
|
||||
logger.debug("fuse_attention: failed to match qk path")
|
||||
return
|
||||
assert len(add_mask_indices) == 1
|
||||
causal_mask_input_index = 1 - add_mask_indices[0]
|
||||
|
||||
(_softmax_qk, _, add_mask, _, matmul_qk) = qk_nodes
|
||||
|
||||
q_nodes = self.model.match_parent_path(
|
||||
matmul_qk, ["Reshape", "Transpose", "Reshape", "Mul", "Add", "MatMul"], [0, 0, 0, 0, None, None]
|
||||
|
|
@ -172,23 +210,24 @@ class FusionAttentionClip(FusionAttention):
|
|||
|
||||
attention_last_node = reshape_qkv
|
||||
|
||||
# Here we do not match the whole subgraph since it is very complex. Instead, we just check whether a key path
|
||||
# of computing causal mask.
|
||||
causal_mask_nodes = self.model.match_parent_path(
|
||||
add_mask,
|
||||
["Concat", "Expand", "Unsqueeze", "Unsqueeze", "Where", "Less"],
|
||||
[causal_mask_input_index, 0, 0, 0, 0, 0],
|
||||
)
|
||||
if causal_mask_nodes is None:
|
||||
# If the model is exported with batch_size == 1, there is no Concat node
|
||||
if add_mask is not None:
|
||||
# Here we do not match the whole subgraph since it is very complex. Instead, we just check whether a key path
|
||||
# of computing causal mask.
|
||||
causal_mask_nodes = self.model.match_parent_path(
|
||||
add_mask,
|
||||
["Expand", "Unsqueeze", "Unsqueeze", "Where", "Less"],
|
||||
[causal_mask_input_index, 0, 0, 0, 0],
|
||||
["Concat", "Expand", "Unsqueeze", "Unsqueeze", "Where", "Less"],
|
||||
[causal_mask_input_index, 0, 0, 0, 0, 0],
|
||||
)
|
||||
if causal_mask_nodes is None:
|
||||
logger.debug("fuse_attention: failed to match causal mask subgraph")
|
||||
return
|
||||
# If the model is exported with batch_size == 1, there is no Concat node
|
||||
causal_mask_nodes = self.model.match_parent_path(
|
||||
add_mask,
|
||||
["Expand", "Unsqueeze", "Unsqueeze", "Where", "Less"],
|
||||
[causal_mask_input_index, 0, 0, 0, 0],
|
||||
)
|
||||
if causal_mask_nodes is None:
|
||||
logger.debug("fuse_attention: failed to match causal mask subgraph")
|
||||
return
|
||||
|
||||
new_node = self.create_attention_node(
|
||||
mask_index=None,
|
||||
|
|
@ -204,7 +243,7 @@ class FusionAttentionClip(FusionAttention):
|
|||
output=attention_last_node.output[0],
|
||||
add_qk_str=None,
|
||||
scale=None,
|
||||
causal=True,
|
||||
causal=(add_mask is not None),
|
||||
)
|
||||
if new_node is None:
|
||||
return
|
||||
|
|
|
|||
|
|
@ -38,6 +38,7 @@ class FusionLayerNormalization(Fusion):
|
|||
| |
|
||||
+----------------------+
|
||||
"""
|
||||
subgraph_nodes = []
|
||||
children = self.model.get_children(node, input_name_to_nodes)
|
||||
if len(children) == 0 or len(children) > 2:
|
||||
return
|
||||
|
|
@ -53,9 +54,16 @@ class FusionLayerNormalization(Fusion):
|
|||
|
||||
div_node = None
|
||||
for child in children:
|
||||
div_node = self.model.find_first_child_by_type(child, "Div", input_name_to_nodes, recursive=False)
|
||||
if div_node is not None:
|
||||
break
|
||||
# Check if Sub --> Div exists
|
||||
div_node_1 = self.model.find_first_child_by_type(child, "Div", input_name_to_nodes, recursive=False)
|
||||
|
||||
# Check if Sub --> Cast --> Div
|
||||
div_node_2 = self.model.match_child_path(child, ["Cast", "Div"], exclude=[])
|
||||
|
||||
if div_node_1 is not None:
|
||||
div_node = div_node_1
|
||||
elif div_node_2 is not None:
|
||||
div_node = div_node_2[-1]
|
||||
if div_node is None:
|
||||
return
|
||||
|
||||
|
|
@ -63,10 +71,7 @@ class FusionLayerNormalization(Fusion):
|
|||
div_node,
|
||||
[
|
||||
(["Sqrt", "Add", "ReduceMean", "Pow", "Sub"], [1, 0, 0, 0, 0]),
|
||||
(
|
||||
["Sqrt", "Add", "ReduceMean", "Pow", "Cast", "Sub"],
|
||||
[1, 0, 0, 0, 0, 0],
|
||||
),
|
||||
(["Sqrt", "Add", "ReduceMean", "Pow", "Cast", "Sub"], [1, 0, 0, 0, 0, 0]),
|
||||
],
|
||||
output_name_to_node,
|
||||
)
|
||||
|
|
@ -87,7 +92,14 @@ class FusionLayerNormalization(Fusion):
|
|||
if self.model.find_constant_input(pow_node, 2.0) != 1:
|
||||
return
|
||||
|
||||
mul_node = input_name_to_nodes[div_node.output[0]][0]
|
||||
temp_node = input_name_to_nodes[div_node.output[0]][0]
|
||||
if temp_node.op_type == "Cast":
|
||||
# Div --> Cast --> Mul
|
||||
subgraph_nodes.append(temp_node) # add Cast node to list of subgraph nodes
|
||||
mul_node = input_name_to_nodes[temp_node.output[0]][0]
|
||||
else:
|
||||
# Div --> Mul
|
||||
mul_node = temp_node
|
||||
if mul_node.op_type != "Mul":
|
||||
return
|
||||
|
||||
|
|
@ -95,7 +107,7 @@ class FusionLayerNormalization(Fusion):
|
|||
if last_add_node.op_type != "Add":
|
||||
return
|
||||
|
||||
subgraph_nodes = [node]
|
||||
subgraph_nodes.append(node)
|
||||
subgraph_nodes.extend(children)
|
||||
subgraph_nodes.extend(parent_nodes[:-1])
|
||||
|
||||
|
|
@ -109,7 +121,8 @@ class FusionLayerNormalization(Fusion):
|
|||
logger.debug("It is not safe to fuse LayerNormalization node. Skip")
|
||||
return
|
||||
|
||||
weight_input = mul_node.input[1 - self.model.input_index(div_node.output[0], mul_node)]
|
||||
node_before_weight = div_node if temp_node.op_type != "Cast" else temp_node
|
||||
weight_input = mul_node.input[1 - self.model.input_index(node_before_weight.output[0], mul_node)]
|
||||
if not self.model.is_constant_with_specified_dimension(weight_input, 1, "layernorm weight"):
|
||||
return
|
||||
|
||||
|
|
|
|||
74
onnxruntime/python/tools/transformers/fusion_quickgelu.py
Normal file
74
onnxruntime/python/tools/transformers/fusion_quickgelu.py
Normal file
|
|
@ -0,0 +1,74 @@
|
|||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
import logging
|
||||
|
||||
from fusion_base import Fusion
|
||||
from onnx import helper
|
||||
from onnx_model import OnnxModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FusionQuickGelu(Fusion):
    """Fuse the ``x * Sigmoid(alpha * x)`` subgraph into a single ``QuickGelu`` node.

    The fused node is created in the ``com.microsoft`` domain and carries the
    matched multiplier as its ``alpha`` attribute.
    """

    def __init__(self, model: OnnxModel):
        # "QuickGelu" is the op type produced by this fusion; "Mul" is the op
        # type whose nodes are handed to `fuse` as candidates.
        super().__init__(model, "QuickGelu", ["Mul"])

    def fuse(self, node, input_name_to_nodes, output_name_to_node):
        """Try to fuse the QuickGelu pattern that ends at ``node``.

        Args:
            node: Candidate node; must be the final ``Mul`` of the pattern.
            input_name_to_nodes: Mapping from a tensor name to the nodes that consume it.
            output_name_to_node: Mapping from a tensor name to the node that produces it.

        Returns:
            None. On a successful match the three matched nodes are queued for
            removal and the new ``QuickGelu`` node is queued for insertion; on any
            mismatch the method logs a debug message and leaves the graph untouched.
        """
        # Fuse the following subgraph to `QuickGelu`
        #
        #     root_input
        #     /        \
        #    |         Mul ----+
        #    |   (B = ~1.702)  |
        #     \         |      |
        #      \     Sigmoid   |---- `QuickGelu`
        #       \      /       |
        #        \    /        |
        #          Mul     ----+
        #           |
        #      root_output

        if node.op_type != "Mul":
            logger.debug("fuse_quickgelu: failed to match second Mul node")
            return

        second_mul_node = node
        root_input = second_mul_node.input[0]

        sigmoid_node = self.model.match_parent_path(second_mul_node, ["Sigmoid"], [1])
        if sigmoid_node is None:
            logger.debug("fuse_quickgelu: failed to match Sigmoid node")
            return
        sigmoid_node = sigmoid_node[0]

        first_mul_node = self.model.match_parent_path(sigmoid_node, ["Mul"], [0])
        if first_mul_node is None:
            logger.debug("fuse_quickgelu: failed to match first Mul node")
            return
        first_mul_node = first_mul_node[0]

        # The multiplier must be a single-element constant. A dynamic second input
        # (lookup returns None) or a multi-element constant is simply not this
        # pattern, so bail out instead of raising on `.item()`.
        constant = self.model.get_constant_value(first_mul_node.input[1])
        if constant is None or constant.size != 1:
            logger.debug("fuse_quickgelu: failed to match approximation value")
            return

        # 1.7021484375 is 1.702 rounded to float16; the tolerance accepts both
        # the float32 and the float16 spelling of the constant.
        approximation_value = constant.item()
        if abs(approximation_value - 1.7021484375) >= 1e-3:
            logger.debug("fuse_quickgelu: failed to match approximation value")
            return

        if first_mul_node.input[0] != root_input:
            logger.debug("fuse_quickgelu: failed to match root input with first Mul node's input")
            return

        # Only the final Mul output may stay visible outside the subgraph. If the
        # inner Mul or the Sigmoid feeds another consumer, removing it would leave
        # that consumer with a dangling input.
        subgraph_nodes = [first_mul_node, sigmoid_node, second_mul_node]
        if not self.model.is_safe_to_fuse_nodes(
            subgraph_nodes,
            [second_mul_node.output[0]],
            input_name_to_nodes,
            output_name_to_node,
        ):
            logger.debug("fuse_quickgelu: it is not safe to fuse QuickGelu subgraph")
            return

        new_node = helper.make_node(
            "QuickGelu",
            inputs=[root_input],
            outputs=[second_mul_node.output[0]],
            name=self.model.create_node_name("QuickGelu"),
        )
        new_node.domain = "com.microsoft"
        new_node.attribute.extend([helper.make_attribute("alpha", approximation_value)])

        self.nodes_to_remove.extend(subgraph_nodes)
        self.nodes_to_add.append(new_node)
        self.node_name_to_graph_name[new_node.name] = self.this_graph_name
        self.increase_counter("QuickGelu")
|
||||
|
|
@ -21,6 +21,7 @@ from fusion_qordered_attention import FusionQOrderedAttention
|
|||
from fusion_qordered_gelu import FusionQOrderedGelu
|
||||
from fusion_qordered_layernorm import FusionQOrderedLayerNormalization
|
||||
from fusion_qordered_matmul import FusionQOrderedMatMul
|
||||
from fusion_quickgelu import FusionQuickGelu
|
||||
from fusion_reshape import FusionReshape
|
||||
from fusion_rotary_attention import FusionRotaryEmbeddings
|
||||
from fusion_shape import FusionShape
|
||||
|
|
@ -65,6 +66,8 @@ class BertOnnxModel(OnnxModel):
|
|||
fusion.apply()
|
||||
fusion = FusionFastGelu(self)
|
||||
fusion.apply()
|
||||
fusion = FusionQuickGelu(self)
|
||||
fusion.apply()
|
||||
# Only relevant in models with Q-DQ nodes
|
||||
fusion = FusionQOrderedGelu(self)
|
||||
fusion.apply()
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ class ClipOnnxModel(BertOnnxModel):
|
|||
ops = [
|
||||
"Attention",
|
||||
"LayerNormalization",
|
||||
"QuickGelu",
|
||||
"SkipLayerNormalization",
|
||||
]
|
||||
for op in ops:
|
||||
|
|
|
|||
|
|
@ -21,6 +21,11 @@ class HuggingfaceFastGelu(torch.nn.Module):
|
|||
return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))
|
||||
|
||||
|
||||
class HuggingfaceQuickGelu(torch.nn.Module):
    """Sigmoid-based GELU approximation: ``x * sigmoid(1.702 * x)``."""

    def forward(self, x):
        # Keep `x` as the left operand of the outer product so the operand
        # order of the traced multiplication matches the one-line form.
        gate = torch.sigmoid(1.702 * x)
        return x * gate
|
||||
|
||||
|
||||
class MegatronGelu(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
# The original implementation using ones_like, which might cause problem for input with dynamic axes in onnx.
|
||||
|
|
@ -36,6 +41,7 @@ class MegatronFastGelu(torch.nn.Module):
|
|||
test_cases = [
|
||||
("huggingface", "Gelu", HuggingfaceGelu),
|
||||
("huggingface", "FastGelu", HuggingfaceFastGelu),
|
||||
("huggingface", "QuickGelu", HuggingfaceQuickGelu),
|
||||
("megatron", "Gelu", MegatronGelu),
|
||||
("megatron", "FastGelu", MegatronFastGelu),
|
||||
]
|
||||
|
|
|
|||
Loading…
Reference in a new issue