mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-29 23:06:41 +00:00
Add fusion patterns for conformer-transducer model (#18461)
### Description Add the conformer-transducer model type to the optimizer. This PR adds pattern matching for the attention subgraphs shown below. Unfused attention: (image not captured in this mirror). Fused attention: (image not captured in this mirror).
This commit is contained in:
parent
53917a3353
commit
97cc40d75a
8 changed files with 802 additions and 3 deletions
|
|
@ -436,6 +436,9 @@ if (onnxruntime_BUILD_UNIT_TESTS)
|
|||
file(GLOB onnxruntime_python_transformers_testdata_whisper CONFIGURE_DEPENDS
|
||||
"${ONNXRUNTIME_ROOT}/test/python/transformers/test_data/models/whisper/*.onnx"
|
||||
)
|
||||
file(GLOB onnxruntime_python_transformers_testdata_conformer CONFIGURE_DEPENDS
|
||||
"${ONNXRUNTIME_ROOT}/test/python/transformers/test_data/models/conformer/*.onnx"
|
||||
)
|
||||
endif()
|
||||
|
||||
file(GLOB onnxruntime_python_tools_srcs CONFIGURE_DEPENDS
|
||||
|
|
@ -549,6 +552,7 @@ add_custom_command(
|
|||
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/transformers/test_data/models
|
||||
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/transformers/test_data/models/whisper
|
||||
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/eager_test
|
||||
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/transformers/test_data/models/conformer
|
||||
COMMAND ${CMAKE_COMMAND} -E copy
|
||||
${ONNXRUNTIME_ROOT}/__init__.py
|
||||
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/
|
||||
|
|
@ -701,6 +705,9 @@ if (onnxruntime_BUILD_UNIT_TESTS)
|
|||
COMMAND ${CMAKE_COMMAND} -E copy
|
||||
${onnxruntime_python_transformers_testdata_whisper}
|
||||
$<TARGET_FILE_DIR:${build_output_target}>/transformers/test_data/models/whisper/
|
||||
COMMAND ${CMAKE_COMMAND} -E copy
|
||||
${onnxruntime_python_transformers_testdata_conformer}
|
||||
$<TARGET_FILE_DIR:${build_output_target}>/transformers/test_data/models/conformer/
|
||||
)
|
||||
endif()
|
||||
|
||||
|
|
|
|||
|
|
@ -657,7 +657,6 @@ class FusionAttention(Fusion):
|
|||
return None
|
||||
|
||||
graph_input_names = set([node.name for node in self.model.graph().input])
|
||||
graph_output_names = set([node.name for node in self.model.graph().output])
|
||||
mha_node_name = self.model.create_node_name("Attention")
|
||||
|
||||
# Add initial Q/K/V inputs for MHA
|
||||
|
|
@ -693,12 +692,15 @@ class FusionAttention(Fusion):
|
|||
mha_inputs.append("")
|
||||
|
||||
# Add optional inputs for MHA
|
||||
if past_k and past_v and past_k in graph_input_names and past_v in graph_input_names:
|
||||
|
||||
if past_k and past_v:
|
||||
mha_inputs.extend([key_padding_mask, add_qk, past_k, past_v])
|
||||
elif key_padding_mask or add_qk:
|
||||
mha_inputs.extend([key_padding_mask, add_qk])
|
||||
|
||||
# Add outputs for MHA
|
||||
mha_outputs = [output]
|
||||
if present_k and present_v and present_k in graph_output_names and present_v in graph_output_names:
|
||||
if present_k and present_v:
|
||||
mha_outputs.extend([present_k, present_v])
|
||||
|
||||
mha_node = helper.make_node(
|
||||
|
|
|
|||
|
|
@ -0,0 +1,143 @@
|
|||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
import logging
|
||||
|
||||
from fusion_attention import AttentionMask, FusionAttention
|
||||
from onnx_model import OnnxModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FusionConformerAttention(FusionAttention):
    """
    Fuse Conformer Attention subgraph into one MultiHeadAttention node.
    """

    def __init__(
        self,
        model: OnnxModel,
        hidden_size: int,
        num_heads: int,
        attention_mask: AttentionMask,
    ):
        super().__init__(model, hidden_size, num_heads, attention_mask)

    def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
        """Match the conformer attention pattern above ``normalize_node`` and fuse it.

        The method walks up from ``normalize_node`` and matches, in order, the
        output (QKV) path, the V path, the QK path, the Q path and the K path.
        If any of them does not match, a debug message is logged and nothing is
        changed. On success a node built by ``create_multihead_attention_node``
        is queued in ``self.nodes_to_add`` and the replaced nodes are queued in
        ``self.nodes_to_remove``.

        Args:
            normalize_node: node whose parents are searched for the pattern.
            input_name_to_nodes: unused here; part of the fusion callback signature.
            output_name_to_node: unused here; part of the fusion callback signature.
        """
        # SkipLayerNormalization has two inputs, and one of them is the root input for attention.
        qkv_nodes = self.model.match_parent_path(
            normalize_node,
            ["Add", "MatMul", "Reshape", "Transpose", "MatMul"],
            [1, 1, 0, 0, 0],
        )
        if qkv_nodes is None:
            logger.debug("fuse_conformer_attention: failed to match qkv path")
            return
        _, _, reshape_qkv, transpose_qkv, matmul_qkv = qkv_nodes

        v_nodes = self.model.match_parent_path(
            matmul_qkv,
            ["Concat", "Transpose", "Reshape", "Add", "MatMul"],
            [1, 1, 0, 0, 1],
        )
        if v_nodes is None:
            logger.debug("fuse_conformer_attention: failed to match v path")
            return
        concat_v, _, _, add_v, matmul_v = v_nodes
        # Input 0 of the Concat is the past value tensor. Read the name straight
        # from the Concat so it is also correct when the tensor has no producer
        # node (for example a graph input) or its producer has several outputs.
        past_v = concat_v.input[0]
        present_v = concat_v.output[0]

        qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "MatMul"], [0, 0, 0])
        if qk_nodes is None:
            logger.debug("fuse_conformer_attention: failed to match qk path")
            return
        _, add_qk, matmul_qk = qk_nodes

        q_nodes = self.model.match_parent_path(
            matmul_qk,
            ["Div", "Transpose", "Reshape", "Add", "MatMul"],
            [0, 0, 0, 0, 1],
        )
        if q_nodes is None:
            logger.debug("fuse_conformer_attention: failed to match q path")
            return
        _, _, reshape_q, add_q, matmul_q = q_nodes

        k_nodes = self.model.match_parent_path(
            matmul_qk,
            ["Transpose", "Concat", "Transpose", "Reshape", "Add", "MatMul"],
            [1, 0, 1, 0, 0, 1],
        )
        if k_nodes is None:
            logger.debug("fuse_conformer_attention: failed to match k path")
            return
        _, concat_k, _, _, add_k, matmul_k = k_nodes
        # Same reasoning as for the V path: take the past key name from the Concat itself.
        past_k = concat_k.input[0]
        present_k = concat_k.output[0]

        attention_last_node = reshape_qkv
        num_heads, hidden_size = self.get_num_heads_and_hidden_size(reshape_q)

        if num_heads <= 0 or hidden_size <= 0 or (hidden_size % num_heads) != 0:
            logger.debug("fuse_conformer_attention: failed to detect num_heads or hidden_size")
            return

        new_node = self.create_multihead_attention_node(
            matmul_q,
            matmul_k,
            matmul_v,
            add_q,
            add_k,
            add_v,
            num_heads,
            hidden_size,
            attention_last_node.output[0],
            # The second input of the Add before Softmax is forwarded as add_qk.
            add_qk=add_qk.input[1],
            past_k=past_k,
            past_v=past_v,
            present_k=present_k,
            present_v=present_v,
        )

        if new_node is None:
            logger.debug("fuse_conformer_attention: MultiHeadAttention node creation failed")
            return

        self.nodes_to_add.append(new_node)
        self.node_name_to_graph_name[new_node.name] = self.this_graph_name

        self.nodes_to_remove.extend([attention_last_node, transpose_qkv, matmul_qkv])
        self.nodes_to_remove.extend(qk_nodes)

        # When using multihead attention, keep MatMul nodes in original graph.
        # Both matched paths end with the projection MatMul, so drop the last entry.
        # Q-path nodes are deliberately not queued for removal here.
        # NOTE(review): presumably unused ones are dropped by graph pruning - confirm.
        self.nodes_to_remove.extend(k_nodes[:-1])
        self.nodes_to_remove.extend(v_nodes[:-1])

        # Use prune graph to remove mask nodes since they are shared by all attention nodes.
        self.prune_graph = True
|
||||
|
|
@ -0,0 +1,33 @@
|
|||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from fusion_attention import AttentionMask
|
||||
from fusion_conformer_attention import FusionConformerAttention
|
||||
from fusion_options import FusionOptions
|
||||
from onnx_model_bert import BertOnnxModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ConformerOnnxModel(BertOnnxModel):
    """BertOnnxModel variant whose attention fusion is FusionConformerAttention."""

    def __init__(self, model, num_heads, hidden_size):
        """Initialise the base model, then install the conformer attention fusion."""
        super().__init__(model, num_heads, hidden_size)
        self.attention_mask = AttentionMask(self)
        self.attention_fusion = FusionConformerAttention(
            self,
            self.hidden_size,
            self.num_heads,
            self.attention_mask,
        )

    def optimize(self, options: Optional[FusionOptions] = None, add_dynamic_axes: bool = False):
        """Copy the multi-head-attention switches from ``options`` and run the base optimizer.

        Both switches are set to False when ``options`` is None.
        """
        fusion = self.attention_fusion
        # `options is not None and options.x` yields False for None, otherwise options.x itself.
        fusion.use_multi_head_attention = options is not None and options.use_multi_head_attention
        fusion.disable_multi_head_attention_bias = options is not None and options.disable_multi_head_attention_bias
        super().optimize(options, add_dynamic_axes)

    def fuse_attention(self):
        """Apply the conformer attention fusion."""
        self.attention_fusion.apply()

    def preprocess(self):
        """Run only the reshape/expand adjustment as preprocessing."""
        self.adjust_reshape_and_expand()
|
||||
|
|
@ -32,6 +32,7 @@ from onnx_model_bert import BertOnnxModel
|
|||
from onnx_model_bert_keras import BertOnnxModelKeras
|
||||
from onnx_model_bert_tf import BertOnnxModelTF
|
||||
from onnx_model_clip import ClipOnnxModel
|
||||
from onnx_model_conformer import ConformerOnnxModel
|
||||
from onnx_model_gpt2 import Gpt2OnnxModel
|
||||
from onnx_model_t5 import T5OnnxModel
|
||||
from onnx_model_tnlr import TnlrOnnxModel
|
||||
|
|
@ -56,6 +57,7 @@ MODEL_TYPES = {
|
|||
"unet": (UnetOnnxModel, "pytorch", 1), # UNet in Stable Diffusion
|
||||
"vae": (VaeOnnxModel, "pytorch", 1), # UAE in Stable Diffusion
|
||||
"vit": (BertOnnxModel, "pytorch", 1),
|
||||
"conformer": (ConformerOnnxModel, "pytorch", 1),
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,543 @@
|
|||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for
|
||||
# license information.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import onnx
|
||||
from bert_model_generator import float_tensor
|
||||
from onnx import TensorProto, helper, numpy_helper
|
||||
|
||||
|
||||
# Adapted from bert_model_generator.py
|
||||
def get_tensor_and_weight(name: str, shape: List[int], random=False, zeros=False):
    """Build a FLOAT tensor called ``name`` and return it with its flat value list.

    Values are drawn from ``np.random.uniform(0.0, 1.0)`` when ``random`` is
    true, are all 0.0 when ``zeros`` is true (and ``random`` is false), and
    are all 1.0 otherwise.

    Returns:
        Tuple of (tensor created by ``helper.make_tensor``, list of values).
    """
    element_count = 1
    for dim in shape:
        element_count *= dim

    if random:
        lower, upper = 0.0, 1.0
        # One draw per element, in order, so the random stream is consumed element by element.
        values = [np.random.uniform(lower, upper) for _ in range(element_count)]
    elif zeros:
        values = [0.0] * element_count
    else:
        values = [1.0] * element_count

    tensor = helper.make_tensor(name, TensorProto.FLOAT, shape, values)
    return tensor, values
|
||||
|
||||
|
||||
def _projection_matmul(prefix):
    """Return the MatMul that multiplies the layernorm output by ``<prefix>_weight``."""
    return helper.make_node(
        "MatMul",
        ["layernorm_add_output_to_matmul", f"{prefix}_weight"],
        [f"{prefix}_matmul_output"],
        f"{prefix}_path_matmul",
    )


def _q_path_nodes(shape_input, **reshape_attrs):
    """Return the Q path: MatMul, Add, Reshape (using ``shape_input``), Transpose, Div."""
    return [
        _projection_matmul("q"),
        helper.make_node("Add", ["q_bias", "q_matmul_output"], ["q_add_output"], "q_path_add"),
        helper.make_node("Reshape", ["q_add_output", shape_input], ["q_4d_bsnh"], "q_reshape_to_4d", **reshape_attrs),
        helper.make_node("Transpose", ["q_4d_bsnh"], ["q_4d_bnsh"], "q_transpose_to_bnsh", perm=[0, 2, 1, 3]),
        helper.make_node("Div", ["q_4d_bnsh", "q_scale"], ["q_div_output"], "q_div_by_sqrt_head_size"),
    ]


def _kv_path_nodes(prefix):
    """Return the unfused K or V path: MatMul, Add, Reshape, Transpose, Concat with the past cache."""
    return [
        _projection_matmul(prefix),
        helper.make_node(
            "Add", [f"{prefix}_bias", f"{prefix}_matmul_output"], [f"{prefix}_add_output"], f"{prefix}_path_add"
        ),
        helper.make_node(
            "Reshape",
            [f"{prefix}_add_output", f"{prefix}_attn_heads_output"],
            [f"{prefix}_4d_bsnh"],
            f"{prefix}_reshape_to_4d",
        ),
        helper.make_node(
            "Transpose", [f"{prefix}_4d_bsnh"], [f"{prefix}_4d_bnsh"], f"{prefix}_transpose_to_bnsh", perm=[0, 2, 1, 3]
        ),
        helper.make_node(
            "Concat",
            [f"gather_past_{prefix}_output", f"{prefix}_4d_bnsh"],
            [f"oup_cache_{prefix}"],
            f"concat_past_{prefix}_and_curr_{prefix}",
            axis=2,
        ),
    ]


def _position_embedding_nodes(**reshape_attrs):
    """Return the five nodes producing ``reshape_position_emb`` from ``q_div_output``."""
    return [
        helper.make_node(
            "Reshape", ["q_div_output", "position_embed_output"], ["reshape_pos_emb"], "r_pos_emb", **reshape_attrs
        ),
        helper.make_node(
            "Transpose", ["reshape_pos_emb"], ["transpose_reshape_pos_emb"], "p_transpose", perm=[1, 0, 2]
        ),
        helper.make_node(
            "MatMul",
            ["transpose_reshape_pos_emb", "transpose_reshape_pos_emb"],
            ["pos_matmul"],
            "pos_embed_matmul",
        ),
        helper.make_node(
            "Transpose", ["pos_matmul"], ["transpose_pos_matmul"], "p_matmul_transpose", perm=[1, 0, 2]
        ),
        helper.make_node(
            "Reshape",
            ["transpose_pos_matmul", "position_embed_output"],
            ["reshape_position_emb"],
            "final_reshape_pos_emb",
            **reshape_attrs,
        ),
    ]


def _shape_derived_nodes():
    """Return the Shape/Gather/Mul/Unsqueeze/Concat nodes common to both graph variants."""
    return [
        helper.make_node("Shape", ["layernorm_add_output_to_matmul"], ["shape_output"], "shape", start=0),
        helper.make_node("Gather", ["shape_output", "idx_0"], ["gather_0_output"], "gather_0", axis=0),
        helper.make_node("Mul", ["gather_0_output", "num_heads_int"], ["mul_attn_heads_output"], "mul_num_heads"),
        helper.make_node(
            "Unsqueeze",
            ["mul_attn_heads_output", "unsqueeze_axes_input"],
            ["unsqueeze_position_embed"],
            "unsqueeze_position_embed",
        ),
        helper.make_node(
            "Concat",
            ["unsqueeze_position_embed", "neg_one", "head_size"],
            ["position_embed_output"],
            "position_embed_concat_output",
            axis=0,
        ),
        helper.make_node(
            "Unsqueeze",
            ["gather_0_output", "unsqueeze_axes_input"],
            ["unsqueeze_attn_heads_output"],
            "unsqueeze_num_heads",
        ),
    ]


def _heads_shape_concat(prefix):
    """Return the Concat producing ``<prefix>_attn_heads_output``."""
    return helper.make_node(
        "Concat",
        ["unsqueeze_attn_heads_output", "neg_one", "head_size", "q_bsnh_reshape"],
        [f"{prefix}_attn_heads_output"],
        f"{prefix}_num_heads",
        axis=0,
    )


def _past_cache_gather_nodes():
    """Return the two Gather nodes selecting index ``idx_0`` of the V and K input caches."""
    return [
        helper.make_node("Gather", ["inp_cache_v", "idx_0"], ["gather_past_v_output"], "gather_past_v", axis=0),
        helper.make_node("Gather", ["inp_cache_k", "idx_0"], ["gather_past_k_output"], "gather_past_k", axis=0),
    ]


def create_conformer_attention(
    hidden_size=512,
    num_heads=8,
    epsilon=0.000009999999747378752,
    add_before_layernorm=False,
    fused=False,
):
    """Build a small ONNX model containing one conformer self-attention block.

    Args:
        hidden_size: width of the hidden dimension; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        epsilon: epsilon attribute of the (Skip)LayerNormalization nodes.
        add_before_layernorm: when true, start with Add + LayerNormalization
            instead of a single SkipLayerNormalization node.
        fused: when true, emit the graph with a com.microsoft MultiHeadAttention
            node; otherwise emit the unfused subgraph.

    Returns:
        The model created by ``helper.make_model``.
    """
    # Get head size and ensure head size is an integer
    assert hidden_size % num_heads == 0
    head_size = hidden_size // num_heads

    # Fixed sizes of the test graph. The previous version also hard-coded the
    # hidden size (512) and head size (64) in the value infos below; they now
    # follow the arguments and are unchanged for the default arguments.
    # NOTE(review): 8 is presumably the sequence length and 72 the cached length - confirm.
    seq_len = 8
    past_len = 72
    total_len = past_len + seq_len

    # Construct input and output nodes
    hidden_shape = ["batch_size", seq_len, hidden_size]
    # The caches are concatenated on axis 2 with the (batch, num_heads, seq, head_size) K/V tensors.
    past_shape = [24, "batch_size", num_heads, past_len, head_size]
    present_shape = ["batch_size", num_heads, total_len, head_size]
    inputs = [
        helper.make_tensor_value_info("input_0", TensorProto.FLOAT, hidden_shape),
        helper.make_tensor_value_info("input_1", TensorProto.FLOAT, hidden_shape),
        helper.make_tensor_value_info("inp_cache_k", TensorProto.FLOAT, past_shape),
        helper.make_tensor_value_info("inp_cache_v", TensorProto.FLOAT, past_shape),
    ]
    outputs = [
        helper.make_tensor_value_info("output_0", TensorProto.FLOAT, hidden_shape),
        helper.make_tensor_value_info("output_1", TensorProto.FLOAT, hidden_shape),
        helper.make_tensor_value_info("oup_cache_k", TensorProto.FLOAT, present_shape),
        helper.make_tensor_value_info("oup_cache_v", TensorProto.FLOAT, present_shape),
    ]
    nodes = []

    # Create layernorm (Add + LayerNorm or SkipLayerNorm)
    if add_before_layernorm:
        nodes.extend(
            [
                helper.make_node(
                    "Add", ["input_0", "input_1"], ["layernorm_output_to_skiplayernorm"], "add_before_layernorm"
                ),
                helper.make_node(
                    "LayerNormalization",
                    ["layernorm_output_to_skiplayernorm", "layernorm_weight", "layernorm_bias"],
                    ["layernorm_add_output_to_matmul"],
                    "layernorm",
                    epsilon=epsilon,
                ),
            ]
        )
    else:
        nodes.append(
            helper.make_node(
                "SkipLayerNormalization",
                ["input_0", "input_1", "layernorm_weight", "layernorm_bias"],
                ["layernorm_add_output_to_matmul", "", "", "layernorm_add_output_to_skiplayernorm"],
                "skiplayernorm",
                domain="com.microsoft",
                epsilon=epsilon,
            )
        )

    if fused:
        # The Q path stays in the fused graph because the position-embedding
        # nodes consume q_div_output.
        nodes.extend(_q_path_nodes("k_attn_heads_output", allowzero=0))
        nodes.extend([_projection_matmul("k"), _projection_matmul("v")])
        nodes.extend(_position_embedding_nodes(allowzero=0))
        nodes.append(
            helper.make_node(
                "MultiHeadAttention",
                [
                    "q_matmul_output",
                    "k_matmul_output",
                    "v_matmul_output",
                    "Attention_0_qkv_bias",
                    "",
                    "reshape_position_emb",
                    "gather_past_k_output",
                    "gather_past_v_output",
                ],
                ["attn_output", "oup_cache_k", "oup_cache_v"],
                "Attention_0",
                domain="com.microsoft",
                num_heads=num_heads,
            )
        )
        # Create nodes used with qkv concats, reshapes, and transposes
        nodes.extend(_shape_derived_nodes())
        nodes.append(_heads_shape_concat("k"))
        nodes.extend(_past_cache_gather_nodes())
    else:
        # Create nodes for Q/K/V paths
        nodes.extend(_q_path_nodes("q_attn_heads_output"))
        nodes.extend(_kv_path_nodes("k"))
        nodes.append(
            helper.make_node(
                "Transpose",
                ["oup_cache_k"],
                ["k_output_transpose"],
                "k_transpose_last_two_dims",
                perm=[0, 1, 3, 2],
            )
        )
        nodes.extend(_kv_path_nodes("v"))
        nodes.extend(_position_embedding_nodes())

        # Create nodes used with qkv concats, reshapes, and transposes
        nodes.extend(_shape_derived_nodes())
        nodes.extend(_heads_shape_concat(prefix) for prefix in ("q", "k", "v"))
        nodes.extend(
            [
                helper.make_node(
                    "Concat",
                    ["unsqueeze_attn_heads_output", "neg_one", "head_size"],
                    ["bsd_format"],
                    axis=0,
                ),
                helper.make_node(
                    "Constant",
                    inputs=[],
                    outputs=["q_bsnh_reshape"],
                    value=numpy_helper.from_array(
                        np.array([0, 0, num_heads, head_size], dtype="int64"), name="const_tensor"
                    ),
                ),
            ]
        )
        nodes.extend(_past_cache_gather_nodes())

        # Compute Q x K'
        nodes.append(helper.make_node("MatMul", ["q_div_output", "k_output_transpose"], ["qk_output"], "matmul_qk"))

        # Create nodes for computing softmax(Q x K') x V
        nodes.extend(
            [
                helper.make_node("Add", ["qk_output", "reshape_position_emb"], ["add_qk_output"], "add_qk"),
                helper.make_node("Softmax", ["add_qk_output"], ["softmax_output"], "softmax_qk", axis=2),
                helper.make_node(
                    "MatMul",
                    ["softmax_output", "oup_cache_v"],
                    ["qkv_output_(num_heads*batch_size,seq_len,head_size)"],
                    "matmul_qkv",
                ),
                helper.make_node(
                    "Transpose",
                    ["qkv_output_(num_heads*batch_size,seq_len,head_size)"],
                    ["qkv_bsnh"],
                    "transpose_bnsh_to_bsnh",
                    perm=[0, 2, 1, 3],
                ),
                helper.make_node("Reshape", ["qkv_bsnh", "bsd_format"], ["attn_output"], "qkv_bsd"),
            ]
        )

    # Create final nodes to conclude attention
    nodes.append(
        helper.make_node(
            "MatMul",
            ["attn_output", "matmul_after_attn_initializer"],
            ["matmul_after_attn_output"],
            "matmul_after_attn",
        ),
    )
    if not fused:
        next_sln_inputs = [
            "layernorm_add_output_to_skiplayernorm",
            "add_after_attn_output",
            "layernorm_weight",
            "layernorm_bias",
        ]
        nodes.extend(
            [
                helper.make_node(
                    "Add",
                    ["add_after_attn_initializer", "matmul_after_attn_output"],
                    ["add_after_attn_output"],
                    "add_after_attn",
                ),
                helper.make_node(
                    "SkipLayerNormalization",
                    next_sln_inputs,
                    ["output_0", "", "", "output_1"],
                    "next_skiplayernorm",
                    domain="com.microsoft",
                    epsilon=epsilon,
                ),
            ]
        )
    else:
        # In the fused graph the bias is passed as a fifth SkipLayerNormalization input
        # instead of going through a separate Add node.
        next_sln_inputs = [
            "matmul_after_attn_output",
            "layernorm_add_output_to_skiplayernorm",
            "layernorm_weight",
            "layernorm_bias",
            "add_after_attn_initializer",
        ]
        nodes.append(
            helper.make_node(
                "SkipLayerNormalization",
                next_sln_inputs,
                ["output_0", "", "", "output_1"],
                "SkipLayerNorm_AddBias_0",
                domain="com.microsoft",
                epsilon=epsilon,
            )
        )

    # Create initializers
    v_weight, v_weight_data = get_tensor_and_weight("v_weight", [hidden_size, hidden_size])
    v_bias, v_bias_data = get_tensor_and_weight("v_bias", [hidden_size])
    q_weight, q_weight_data = get_tensor_and_weight("q_weight", [hidden_size, hidden_size])
    q_bias, q_bias_data = get_tensor_and_weight("q_bias", [hidden_size])
    k_weight, k_weight_data = get_tensor_and_weight("k_weight", [hidden_size, hidden_size])
    k_bias, k_bias_data = get_tensor_and_weight("k_bias", [hidden_size])

    qkv_bias = helper.make_tensor(
        "Attention_0_qkv_bias",
        TensorProto.FLOAT,
        [3 * hidden_size],
        q_bias_data + k_bias_data + v_bias_data,
    )
    initializers = [
        float_tensor("layernorm_weight", [hidden_size]),
        float_tensor("layernorm_bias", [hidden_size]),
        float_tensor("matmul_after_attn_initializer", [hidden_size, hidden_size]),
        float_tensor("add_after_attn_initializer", [hidden_size]),
    ]

    def int64_initializer(name, value):
        # Wrap a Python int (scalar) or list of ints (1-D) as a named int64 initializer.
        return numpy_helper.from_array(np.array(value, dtype="int64"), name=name)

    q_scale = numpy_helper.from_array(np.array(1 / np.sqrt(head_size), dtype="float32"), name="q_scale")

    # Add Q/K/V weight tensors as initializers
    if fused:
        initializers.extend([q_weight, k_weight, v_weight])
        initializers.extend([q_bias])
        initializers.append(qkv_bias)
        initializers.extend(
            [
                int64_initializer("num_heads_int", num_heads),
                int64_initializer("head_size", [head_size]),
                q_scale,
                int64_initializer("idx_0", 0),
                int64_initializer("neg_one", [-1]),
                int64_initializer("unsqueeze_axes_input", [0]),
                int64_initializer("q_bsnh_reshape", [0, 0, num_heads, head_size]),
            ]
        )
    else:
        initializers.extend([q_weight, k_weight, v_weight])

        initializers.extend([q_bias, k_bias, v_bias])

        initializers.extend(
            [
                int64_initializer("num_heads_int", num_heads),
                int64_initializer("num_heads", [num_heads]),
                int64_initializer("head_size", [head_size]),
                int64_initializer("hidden_size", [hidden_size]),
                q_scale,
                int64_initializer("idx_0", 0),
                int64_initializer("idx_1", 1),
                int64_initializer("neg_one", [-1]),
                int64_initializer("unsqueeze_axes_input", [0]),
            ]
        )

    # Construct graph
    graph = helper.make_graph(nodes, "conformer_self_mha_graph", inputs, outputs, initializers, doc_string="conformer")
    opsetid = helper.make_opsetid("ai.onnx", min(onnx.defs.onnx_opset_version(), 16))
    return helper.make_model(graph, opset_imports=(opsetid,))
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: write the unfused model to the current directory and
    # the fused reference model to ./test_data/models/conformer/.
    import os

    np.random.seed(2)
    num_heads = 8
    hidden_size = 512

    model = create_conformer_attention(num_heads=num_heads, hidden_size=hidden_size)
    onnx.save(model, "conformer_self_mha.onnx")

    model = create_conformer_attention(num_heads=num_heads, hidden_size=hidden_size, fused=True)
    fused_model_path = "./test_data/models/conformer/conformer_self_mha_fused.onnx"
    # The target directory may not exist yet; create it before saving.
    os.makedirs(os.path.dirname(fused_model_path), exist_ok=True)
    onnx.save(model, fused_model_path)
|
||||
69
onnxruntime/test/python/transformers/test_conformer.py
Normal file
69
onnxruntime/test/python/transformers/test_conformer.py
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for
|
||||
# license information.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
import os
|
||||
import unittest
|
||||
|
||||
import onnx
|
||||
from conformer_model_generator import create_conformer_attention
|
||||
from parity_utilities import find_transformers_source
|
||||
|
||||
if find_transformers_source():
|
||||
from fusion_options import FusionOptions
|
||||
from onnx_model import OnnxModel
|
||||
from optimizer import optimize_model
|
||||
else:
|
||||
from onnxruntime.transformers.fusion_options import FusionOptions
|
||||
from onnxruntime.transformers.onnx_model import OnnxModel
|
||||
from onnxruntime.transformers.optimizer import optimize_model
|
||||
|
||||
|
||||
class TestFusion(unittest.TestCase):
    """Checks that the conformer optimizer output matches the stored fused reference model."""

    def verify_fusion(self, optimized_model, expected_model_filename):
        """Compare ``optimized_model`` with the reference model node by node.

        Both models are topologically sorted deterministically first. Every node
        must be equal, and every initializer of the reference model must have the
        same value in the optimized model.
        """
        optimized_model.topological_sort(is_deterministic=True)

        expected_model_path = os.path.join(
            os.path.dirname(__file__), "test_data", "models", "conformer", expected_model_filename
        )
        print("Expected model path = ", expected_model_path)
        expected_model = OnnxModel(onnx.load(expected_model_path))
        expected_model.topological_sort(is_deterministic=True)

        nodes = optimized_model.model.graph.node
        expected_nodes = expected_model.model.graph.node
        # Check the counts first so the pairwise comparison below covers every node.
        self.assertEqual(len(nodes), len(expected_nodes))

        for actual_node, expected_node in zip(nodes, expected_nodes):
            self.assertEqual(actual_node, expected_node)

        for expected_initializer in expected_model.model.graph.initializer:
            print("Expected initializer initial = ", expected_initializer.name)
            self.assertTrue(
                OnnxModel.has_same_value(
                    optimized_model.get_initializer(expected_initializer.name), expected_initializer
                )
            )

    def test_ct_mha_fusion(self):
        """Optimize the generated unfused model and compare it with the fused reference."""
        import tempfile

        num_heads = 8
        hidden_size = 512
        model = create_conformer_attention(num_heads=num_heads, hidden_size=hidden_size, add_before_layernorm=False)
        options = FusionOptions("conformer")

        # Save into a temporary directory so the file is removed even when
        # optimization raises, and nothing is written to the working directory.
        with tempfile.TemporaryDirectory() as temp_dir:
            model_path = os.path.join(temp_dir, "conformer_self_mha.onnx")
            onnx.save(model, model_path)
            optimized_model = optimize_model(
                model_path,
                model_type="conformer",
                num_heads=num_heads,
                hidden_size=hidden_size,
                optimization_options=options,
            )

        self.verify_fusion(optimized_model, "conformer_self_mha_fused.onnx")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Binary file not shown.
Loading…
Reference in a new issue