Add fusion patterns for conformer-transducer model (#18461)

### Description
Add the conformer-transducer model type to the optimizer. This PR adds pattern
matching for the attention subgraph shown below:
Unfused attention:

![ct_unfused](https://github.com/microsoft/onnxruntime/assets/111780983/46c71ed8-67e0-4607-85b1-bcadba5a2956)

Fused attention:

![ct_fused](https://github.com/microsoft/onnxruntime/assets/111780983/fbb91c96-0d4b-4f0b-8674-1ae3b9b9a92e)
This commit is contained in:
Akshay Sonawane 2023-11-18 23:39:04 -08:00 committed by GitHub
parent 53917a3353
commit 97cc40d75a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 802 additions and 3 deletions

View file

@ -436,6 +436,9 @@ if (onnxruntime_BUILD_UNIT_TESTS)
file(GLOB onnxruntime_python_transformers_testdata_whisper CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/test/python/transformers/test_data/models/whisper/*.onnx"
)
file(GLOB onnxruntime_python_transformers_testdata_conformer CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/test/python/transformers/test_data/models/conformer/*.onnx"
)
endif()
file(GLOB onnxruntime_python_tools_srcs CONFIGURE_DEPENDS
@ -549,6 +552,7 @@ add_custom_command(
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/transformers/test_data/models
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/transformers/test_data/models/whisper
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/eager_test
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/transformers/test_data/models/conformer
COMMAND ${CMAKE_COMMAND} -E copy
${ONNXRUNTIME_ROOT}/__init__.py
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/
@ -701,6 +705,9 @@ if (onnxruntime_BUILD_UNIT_TESTS)
COMMAND ${CMAKE_COMMAND} -E copy
${onnxruntime_python_transformers_testdata_whisper}
$<TARGET_FILE_DIR:${build_output_target}>/transformers/test_data/models/whisper/
COMMAND ${CMAKE_COMMAND} -E copy
${onnxruntime_python_transformers_testdata_conformer}
$<TARGET_FILE_DIR:${build_output_target}>/transformers/test_data/models/conformer/
)
endif()

View file

@ -657,7 +657,6 @@ class FusionAttention(Fusion):
return None
graph_input_names = set([node.name for node in self.model.graph().input])
graph_output_names = set([node.name for node in self.model.graph().output])
mha_node_name = self.model.create_node_name("Attention")
# Add initial Q/K/V inputs for MHA
@ -693,12 +692,15 @@ class FusionAttention(Fusion):
mha_inputs.append("")
# Add optional inputs for MHA
if past_k and past_v and past_k in graph_input_names and past_v in graph_input_names:
if past_k and past_v:
mha_inputs.extend([key_padding_mask, add_qk, past_k, past_v])
elif key_padding_mask or add_qk:
mha_inputs.extend([key_padding_mask, add_qk])
# Add outputs for MHA
mha_outputs = [output]
if present_k and present_v and present_k in graph_output_names and present_v in graph_output_names:
if present_k and present_v:
mha_outputs.extend([present_k, present_v])
mha_node = helper.make_node(

View file

@ -0,0 +1,143 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging
from fusion_attention import AttentionMask, FusionAttention
from onnx_model import OnnxModel
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class FusionConformerAttention(FusionAttention):
    """
    Fuse Conformer Attention subgraph into one MultiHeadAttention node.

    The pattern matched by ``fuse`` is an attention block whose K and V paths
    each end in a ``Concat`` (first Concat input = past state, Concat output =
    present state), and whose Q*K' product has an extra tensor added to it
    before the Softmax.
    """

    def __init__(
        self,
        model: OnnxModel,
        hidden_size: int,
        num_heads: int,
        attention_mask: AttentionMask,
    ):
        super().__init__(model, hidden_size, num_heads, attention_mask)

    def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
        """Try to fuse the attention subgraph that feeds ``normalize_node``.

        Walks upward from ``normalize_node`` matching the QKV, V, QK, Q and K
        paths in turn. If any path fails to match, or the head count / hidden
        size cannot be determined, a debug message is logged and the graph is
        left untouched. On success the new node is queued in
        ``self.nodes_to_add`` and the replaced nodes in ``self.nodes_to_remove``.

        ``input_name_to_nodes`` and ``output_name_to_node`` are accepted for
        signature compatibility; this method does not read them.
        """
        # SkipLayerNormalization has two inputs, and one of them is the root input for attention.
        qkv_nodes = self.model.match_parent_path(
            normalize_node,
            ["Add", "MatMul", "Reshape", "Transpose", "MatMul"],
            [1, 1, 0, 0, 0],
        )
        if qkv_nodes is None:
            logger.debug("fuse_conformer_attention: failed to match qkv path")
            return
        _, _, reshape_qkv, transpose_qkv, matmul_qkv = qkv_nodes

        v_nodes = self.model.match_parent_path(
            matmul_qkv,
            ["Concat", "Transpose", "Reshape", "Add", "MatMul"],
            [1, 1, 0, 0, 1],
        )
        if v_nodes is None:
            logger.debug("fuse_conformer_attention: failed to match v path")
            return
        concat_v, _, _, add_v, matmul_v = v_nodes
        # Take the past-state tensor name straight from the Concat input. Looking
        # up the producing node and reading its first output fails when the
        # tensor has no producer (e.g. it is a graph input) and would name the
        # wrong tensor if the producer had several outputs.
        past_v = concat_v.input[0]
        present_v = concat_v.output[0]

        qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "MatMul"], [0, 0, 0])
        if qk_nodes is None:
            logger.debug("fuse_conformer_attention: failed to match qk path")
            return
        _, add_qk, matmul_qk = qk_nodes

        q_nodes = self.model.match_parent_path(
            matmul_qk,
            ["Div", "Transpose", "Reshape", "Add", "MatMul"],
            [0, 0, 0, 0, 1],
        )
        if q_nodes is None:
            logger.debug("fuse_conformer_attention: failed to match q path")
            return
        _, _, reshape_q, add_q, matmul_q = q_nodes

        k_nodes = self.model.match_parent_path(
            matmul_qk,
            ["Transpose", "Concat", "Transpose", "Reshape", "Add", "MatMul"],
            [1, 0, 1, 0, 0, 1],
        )
        if k_nodes is None:
            logger.debug("fuse_conformer_attention: failed to match k path")
            return
        _, concat_k, _, _, add_k, matmul_k = k_nodes
        # Same reasoning as for past_v above.
        past_k = concat_k.input[0]
        present_k = concat_k.output[0]

        attention_last_node = reshape_qkv
        num_heads, hidden_size = self.get_num_heads_and_hidden_size(reshape_q)

        if num_heads <= 0 or hidden_size <= 0 or (hidden_size % num_heads) != 0:
            logger.debug("fuse_conformer_attention: failed to detect num_heads or hidden_size")
            return

        new_node = self.create_multihead_attention_node(
            matmul_q,
            matmul_k,
            matmul_v,
            add_q,
            add_k,
            add_v,
            num_heads,
            hidden_size,
            attention_last_node.output[0],
            # The second Add operand is the tensor added to Q*K' before Softmax.
            add_qk=add_qk.input[1],
            past_k=past_k,
            past_v=past_v,
            present_k=present_k,
            present_v=present_v,
        )

        if new_node is None:
            logger.debug("fuse_conformer_attention: MultiHeadAttention node creation failed")
            return

        self.nodes_to_add.append(new_node)
        self.node_name_to_graph_name[new_node.name] = self.this_graph_name

        self.nodes_to_remove.extend([attention_last_node, transpose_qkv, matmul_qkv])
        self.nodes_to_remove.extend(qk_nodes)

        # When using multihead attention, keep MatMul nodes in original graph:
        # the matched K and V paths always end in their projection MatMul, so
        # everything except that last node is queued for removal.
        # The Q-path nodes are deliberately not queued here; any of them left
        # without a consumer is expected to be dropped by the prune pass below.
        self.nodes_to_remove.extend(k_nodes[:-1])
        self.nodes_to_remove.extend(v_nodes[:-1])

        # Use prune graph to remove mask nodes since they are shared by all attention nodes.
        self.prune_graph = True

View file

@ -0,0 +1,33 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging
from typing import Optional
from fusion_attention import AttentionMask
from fusion_conformer_attention import FusionConformerAttention
from fusion_options import FusionOptions
from onnx_model_bert import BertOnnxModel
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class ConformerOnnxModel(BertOnnxModel):
    """BERT-style ONNX model wrapper that uses the conformer attention fusion."""

    def __init__(self, model, num_heads, hidden_size):
        super().__init__(model, num_heads, hidden_size)
        self.attention_mask = AttentionMask(self)
        self.attention_fusion = FusionConformerAttention(
            self,
            self.hidden_size,
            self.num_heads,
            self.attention_mask,
        )

    def optimize(self, options: Optional[FusionOptions] = None, add_dynamic_axes: bool = False):
        """Copy the multi-head-attention switches from ``options`` and run the base optimizer.

        Both switches are set to False when no options are supplied.
        """
        fusion = self.attention_fusion
        if options is None:
            fusion.use_multi_head_attention = False
            fusion.disable_multi_head_attention_bias = False
        else:
            fusion.use_multi_head_attention = options.use_multi_head_attention
            fusion.disable_multi_head_attention_bias = options.disable_multi_head_attention_bias
        super().optimize(options, add_dynamic_axes)

    def fuse_attention(self):
        """Apply the conformer attention fusion to this model."""
        self.attention_fusion.apply()

    def preprocess(self):
        """Run only the reshape/expand adjustment as preprocessing."""
        self.adjust_reshape_and_expand()

View file

@ -32,6 +32,7 @@ from onnx_model_bert import BertOnnxModel
from onnx_model_bert_keras import BertOnnxModelKeras
from onnx_model_bert_tf import BertOnnxModelTF
from onnx_model_clip import ClipOnnxModel
from onnx_model_conformer import ConformerOnnxModel
from onnx_model_gpt2 import Gpt2OnnxModel
from onnx_model_t5 import T5OnnxModel
from onnx_model_tnlr import TnlrOnnxModel
@ -56,6 +57,7 @@ MODEL_TYPES = {
"unet": (UnetOnnxModel, "pytorch", 1), # UNet in Stable Diffusion
"vae": (VaeOnnxModel, "pytorch", 1), # UAE in Stable Diffusion
"vit": (BertOnnxModel, "pytorch", 1),
"conformer": (ConformerOnnxModel, "pytorch", 1),
}

View file

@ -0,0 +1,543 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
from typing import List
import numpy as np
import onnx
from bert_model_generator import float_tensor
from onnx import TensorProto, helper, numpy_helper
# Adapted from bert_model_generator.py
def get_tensor_and_weight(name: str, shape: List[int], random=False, zeros=False):
    """Create a FLOAT tensor called ``name`` with ``shape`` and return it with its flat data.

    The fill is chosen in priority order: uniform random values in [0.0, 1.0)
    when ``random`` is set, otherwise all zeros when ``zeros`` is set,
    otherwise all ones.

    Returns:
        A ``(tensor, weights)`` tuple, where ``weights`` is the flat list of
        values that was stored in the tensor.
    """
    lower, upper = 0.0, 1.0

    # Number of scalars needed to fill the tensor.
    count = 1
    for dim in shape:
        count *= dim

    if random:
        weights = [np.random.uniform(lower, upper) for _ in range(count)]
    elif zeros:
        weights = [0.0] * count
    else:
        weights = [1.0] * count

    tensor = helper.make_tensor(name, TensorProto.FLOAT, shape, weights)
    return tensor, weights
def create_conformer_attention(
    hidden_size=512,
    num_heads=8,
    epsilon=0.000009999999747378752,
    add_before_layernorm=False,
    fused=False,
):
    """Build an ONNX model containing one conformer self-attention block.

    Args:
        hidden_size: Size of the hidden dimension; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        epsilon: Epsilon attribute for the (Skip)LayerNormalization nodes.
        add_before_layernorm: When True, start the graph with Add + LayerNormalization
            instead of a single SkipLayerNormalization node.
        fused: When False, emit the unfused attention pattern (separate Q/K/V
            paths, Softmax, two MatMuls). When True, emit the graph with a
            ``com.microsoft`` MultiHeadAttention node in its place.

    Returns:
        The ``ModelProto`` produced by ``helper.make_model``.
    """
    # Get head size and ensure head size is an integer
    assert hidden_size % num_heads == 0
    head_size = hidden_size // num_heads

    # Fixed sizes used by this test graph. The K/V cache inputs carry an extra
    # leading axis that the Gather(idx_0, axis=0) nodes index away.
    seq_len = 8
    past_seq_len = 72
    cache_leading_dim = 24
    # The present K/V are Concat(past, current) along axis 2.
    total_seq_len = past_seq_len + seq_len

    # Construct input and output nodes
    inputs = [
        helper.make_tensor_value_info("input_0", TensorProto.FLOAT, ["batch_size", seq_len, hidden_size]),
        helper.make_tensor_value_info("input_1", TensorProto.FLOAT, ["batch_size", seq_len, hidden_size]),
        helper.make_tensor_value_info(
            "inp_cache_k",
            TensorProto.FLOAT,
            [cache_leading_dim, "batch_size", num_heads, past_seq_len, head_size],
        ),
        helper.make_tensor_value_info(
            "inp_cache_v",
            TensorProto.FLOAT,
            [cache_leading_dim, "batch_size", num_heads, past_seq_len, head_size],
        ),
    ]
    outputs = [
        helper.make_tensor_value_info("output_0", TensorProto.FLOAT, ["batch_size", seq_len, hidden_size]),
        helper.make_tensor_value_info("output_1", TensorProto.FLOAT, ["batch_size", seq_len, hidden_size]),
        helper.make_tensor_value_info(
            "oup_cache_k", TensorProto.FLOAT, ["batch_size", num_heads, total_seq_len, head_size]
        ),
        helper.make_tensor_value_info(
            "oup_cache_v", TensorProto.FLOAT, ["batch_size", num_heads, total_seq_len, head_size]
        ),
    ]
    nodes = []

    # Create layernorm (Add + LayerNorm or SkipLayerNorm)
    if add_before_layernorm:
        # The Add output must carry the same name the SkipLayerNormalization
        # path gives its input+skip sum, because the SkipLayerNormalization at
        # the end of the graph consumes that name.
        nodes.extend(
            [
                helper.make_node(
                    "Add",
                    ["input_0", "input_1"],
                    ["layernorm_add_output_to_skiplayernorm"],
                    "add_before_layernorm",
                ),
                helper.make_node(
                    "LayerNormalization",
                    ["layernorm_add_output_to_skiplayernorm", "layernorm_weight", "layernorm_bias"],
                    ["layernorm_add_output_to_matmul"],
                    "layernorm",
                    epsilon=epsilon,
                ),
            ]
        )
    else:
        nodes.append(
            helper.make_node(
                "SkipLayerNormalization",
                ["input_0", "input_1", "layernorm_weight", "layernorm_bias"],
                ["layernorm_add_output_to_matmul", "", "", "layernorm_add_output_to_skiplayernorm"],
                "skiplayernorm",
                domain="com.microsoft",
                epsilon=epsilon,
            )
        )

    if fused:
        # The Q path is kept in the fused graph: its Div output still feeds the
        # position-embedding branch below.
        fused_q_nodes = [
            helper.make_node(
                "MatMul",
                ["layernorm_add_output_to_matmul", "q_weight"],
                ["q_matmul_output"],
                "q_path_matmul",
            ),
            helper.make_node("Add", ["q_bias", "q_matmul_output"], ["q_add_output"], "q_path_add"),
            helper.make_node(
                "Reshape", ["q_add_output", "k_attn_heads_output"], ["q_4d_bsnh"], "q_reshape_to_4d", allowzero=0
            ),
            helper.make_node("Transpose", ["q_4d_bsnh"], ["q_4d_bnsh"], "q_transpose_to_bnsh", perm=[0, 2, 1, 3]),
            helper.make_node(
                "Div",
                ["q_4d_bnsh", "q_scale"],
                ["q_div_output"],
                "q_div_by_sqrt_head_size",
            ),
        ]
        nodes.extend(fused_q_nodes)
        nodes.extend(
            [
                helper.make_node(
                    "MatMul",
                    ["layernorm_add_output_to_matmul", "k_weight"],
                    ["k_matmul_output"],
                    "k_path_matmul",
                ),
                helper.make_node(
                    "MatMul",
                    ["layernorm_add_output_to_matmul", "v_weight"],
                    ["v_matmul_output"],
                    "v_path_matmul",
                ),
                helper.make_node(
                    "Reshape", ["q_div_output", "position_embed_output"], ["reshape_pos_emb"], "r_pos_emb", allowzero=0
                ),
                helper.make_node(
                    "Transpose", ["reshape_pos_emb"], ["transpose_reshape_pos_emb"], "p_transpose", perm=[1, 0, 2]
                ),
                helper.make_node(
                    "MatMul",
                    ["transpose_reshape_pos_emb", "transpose_reshape_pos_emb"],
                    ["pos_matmul"],
                    "pos_embed_matmul",
                ),
                helper.make_node(
                    "Transpose", ["pos_matmul"], ["transpose_pos_matmul"], "p_matmul_transpose", perm=[1, 0, 2]
                ),
                helper.make_node(
                    "Reshape",
                    ["transpose_pos_matmul", "position_embed_output"],
                    ["reshape_position_emb"],
                    "final_reshape_pos_emb",
                    allowzero=0,
                ),
                helper.make_node(
                    "MultiHeadAttention",
                    [
                        "q_matmul_output",
                        "k_matmul_output",
                        "v_matmul_output",
                        "Attention_0_qkv_bias",
                        "",
                        "reshape_position_emb",
                        "gather_past_k_output",
                        "gather_past_v_output",
                    ],
                    ["attn_output", "oup_cache_k", "oup_cache_v"],
                    "Attention_0",
                    domain="com.microsoft",
                    num_heads=num_heads,
                ),
            ]
        )
        # Create nodes used with qkv concats, reshapes, and transposes
        nodes.extend(
            [
                helper.make_node("Shape", ["layernorm_add_output_to_matmul"], ["shape_output"], "shape", start=0),
                helper.make_node("Gather", ["shape_output", "idx_0"], ["gather_0_output"], "gather_0", axis=0),
                helper.make_node(
                    "Mul",
                    ["gather_0_output", "num_heads_int"],
                    ["mul_attn_heads_output"],
                    "mul_num_heads",
                ),
                helper.make_node(
                    "Unsqueeze",
                    ["mul_attn_heads_output", "unsqueeze_axes_input"],
                    ["unsqueeze_position_embed"],
                    "unsqueeze_position_embed",
                ),
                helper.make_node(
                    "Concat",
                    ["unsqueeze_position_embed", "neg_one", "head_size"],
                    ["position_embed_output"],
                    "position_embed_concat_output",
                    axis=0,
                ),
                helper.make_node(
                    "Unsqueeze",
                    ["gather_0_output", "unsqueeze_axes_input"],
                    ["unsqueeze_attn_heads_output"],
                    "unsqueeze_num_heads",
                ),
                helper.make_node(
                    "Concat",
                    ["unsqueeze_attn_heads_output", "neg_one", "head_size", "q_bsnh_reshape"],
                    ["k_attn_heads_output"],
                    "k_num_heads",
                    axis=0,
                ),
            ]
        )
        nodes.extend(
            [
                helper.make_node("Gather", ["inp_cache_v", "idx_0"], ["gather_past_v_output"], "gather_past_v", axis=0),
                helper.make_node("Gather", ["inp_cache_k", "idx_0"], ["gather_past_k_output"], "gather_past_k", axis=0),
            ]
        )
    else:
        # Create nodes for Q/K/V paths
        q_nodes = [
            helper.make_node(
                "MatMul", ["layernorm_add_output_to_matmul", "q_weight"], ["q_matmul_output"], "q_path_matmul"
            ),
            helper.make_node("Add", ["q_bias", "q_matmul_output"], ["q_add_output"], "q_path_add"),
            helper.make_node("Reshape", ["q_add_output", "q_attn_heads_output"], ["q_4d_bsnh"], "q_reshape_to_4d"),
            helper.make_node("Transpose", ["q_4d_bsnh"], ["q_4d_bnsh"], "q_transpose_to_bnsh", perm=[0, 2, 1, 3]),
            helper.make_node(
                "Div",
                ["q_4d_bnsh", "q_scale"],
                ["q_div_output"],
                "q_div_by_sqrt_head_size",
            ),
        ]
        k_nodes = [
            helper.make_node(
                "MatMul",
                ["layernorm_add_output_to_matmul", "k_weight"],
                ["k_matmul_output"],
                "k_path_matmul",
            ),
            helper.make_node("Add", ["k_bias", "k_matmul_output"], ["k_add_output"], "k_path_add"),
            helper.make_node("Reshape", ["k_add_output", "k_attn_heads_output"], ["k_4d_bsnh"], "k_reshape_to_4d"),
            helper.make_node("Transpose", ["k_4d_bsnh"], ["k_4d_bnsh"], "k_transpose_to_bnsh", perm=[0, 2, 1, 3]),
            helper.make_node(
                "Concat",
                ["gather_past_k_output", "k_4d_bnsh"],
                ["oup_cache_k"],
                "concat_past_k_and_curr_k",
                axis=2,
            ),
            helper.make_node(
                "Transpose",
                ["oup_cache_k"],
                ["k_output_transpose"],
                "k_transpose_last_two_dims",
                perm=[0, 1, 3, 2],
            ),
        ]
        v_nodes = [
            helper.make_node(
                "MatMul",
                ["layernorm_add_output_to_matmul", "v_weight"],
                ["v_matmul_output"],
                "v_path_matmul",
            ),
            helper.make_node("Add", ["v_bias", "v_matmul_output"], ["v_add_output"], "v_path_add"),
            helper.make_node("Reshape", ["v_add_output", "v_attn_heads_output"], ["v_4d_bsnh"], "v_reshape_to_4d"),
            helper.make_node("Transpose", ["v_4d_bsnh"], ["v_4d_bnsh"], "v_transpose_to_bnsh", perm=[0, 2, 1, 3]),
            helper.make_node(
                "Concat",
                ["gather_past_v_output", "v_4d_bnsh"],
                ["oup_cache_v"],
                "concat_past_v_and_curr_v",
                axis=2,
            ),
        ]
        pos_embed = [
            helper.make_node("Reshape", ["q_div_output", "position_embed_output"], ["reshape_pos_emb"], "r_pos_emb"),
            helper.make_node(
                "Transpose", ["reshape_pos_emb"], ["transpose_reshape_pos_emb"], "p_transpose", perm=[1, 0, 2]
            ),
            helper.make_node(
                "MatMul",
                ["transpose_reshape_pos_emb", "transpose_reshape_pos_emb"],
                ["pos_matmul"],
                "pos_embed_matmul",
            ),
            helper.make_node(
                "Transpose", ["pos_matmul"], ["transpose_pos_matmul"], "p_matmul_transpose", perm=[1, 0, 2]
            ),
            helper.make_node(
                "Reshape",
                ["transpose_pos_matmul", "position_embed_output"],
                ["reshape_position_emb"],
                "final_reshape_pos_emb",
            ),
        ]
        nodes.extend(q_nodes)
        nodes.extend(k_nodes)
        nodes.extend(v_nodes)
        nodes.extend(pos_embed)
        # Create nodes used with qkv concats, reshapes, and transposes
        nodes.extend(
            [
                helper.make_node("Shape", ["layernorm_add_output_to_matmul"], ["shape_output"], "shape", start=0),
                helper.make_node("Gather", ["shape_output", "idx_0"], ["gather_0_output"], "gather_0", axis=0),
                helper.make_node(
                    "Mul",
                    ["gather_0_output", "num_heads_int"],
                    ["mul_attn_heads_output"],
                    "mul_num_heads",
                ),
                helper.make_node(
                    "Unsqueeze",
                    ["mul_attn_heads_output", "unsqueeze_axes_input"],
                    ["unsqueeze_position_embed"],
                    "unsqueeze_position_embed",
                ),
                helper.make_node(
                    "Concat",
                    ["unsqueeze_position_embed", "neg_one", "head_size"],
                    ["position_embed_output"],
                    "position_embed_concat_output",
                    axis=0,
                ),
                helper.make_node(
                    "Unsqueeze",
                    ["gather_0_output", "unsqueeze_axes_input"],
                    ["unsqueeze_attn_heads_output"],
                    "unsqueeze_num_heads",
                ),
                helper.make_node(
                    "Concat",
                    ["unsqueeze_attn_heads_output", "neg_one", "head_size", "q_bsnh_reshape"],
                    ["q_attn_heads_output"],
                    "q_num_heads",
                    axis=0,
                ),
                helper.make_node(
                    "Concat",
                    ["unsqueeze_attn_heads_output", "neg_one", "head_size", "q_bsnh_reshape"],
                    ["k_attn_heads_output"],
                    "k_num_heads",
                    axis=0,
                ),
                helper.make_node(
                    "Concat",
                    ["unsqueeze_attn_heads_output", "neg_one", "head_size", "q_bsnh_reshape"],
                    ["v_attn_heads_output"],
                    "v_num_heads",
                    axis=0,
                ),
                helper.make_node(
                    "Concat",
                    ["unsqueeze_attn_heads_output", "neg_one", "head_size"],
                    ["bsd_format"],
                    axis=0,
                ),
                helper.make_node(
                    "Constant",
                    inputs=[],
                    outputs=["q_bsnh_reshape"],
                    value=numpy_helper.from_array(
                        np.array([0, 0, num_heads, head_size], dtype="int64"), name="const_tensor"
                    ),
                ),
            ]
        )
        nodes.extend(
            [
                helper.make_node("Gather", ["inp_cache_v", "idx_0"], ["gather_past_v_output"], "gather_past_v", axis=0),
                helper.make_node("Gather", ["inp_cache_k", "idx_0"], ["gather_past_k_output"], "gather_past_k", axis=0),
            ]
        )
        # Compute Q x K'
        nodes.extend(
            [
                helper.make_node(
                    "MatMul",
                    [
                        "q_div_output",
                        "k_output_transpose",
                    ],
                    ["qk_output"],
                    "matmul_qk",
                )
            ]
        )
        # Create nodes for computing softmax(Q x K') x V
        nodes.extend(
            [
                helper.make_node(
                    "Add",
                    [
                        "qk_output",
                        "reshape_position_emb",
                    ],
                    ["add_qk_output"],
                    "add_qk",
                ),
                helper.make_node(
                    "Softmax",
                    ["add_qk_output"],
                    ["softmax_output"],
                    "softmax_qk",
                    axis=2,
                ),
                helper.make_node(
                    "MatMul",
                    ["softmax_output", "oup_cache_v"],
                    ["qkv_output_(num_heads*batch_size,seq_len,head_size)"],
                    "matmul_qkv",
                ),
                helper.make_node(
                    "Transpose",
                    ["qkv_output_(num_heads*batch_size,seq_len,head_size)"],
                    ["qkv_bsnh"],
                    "transpose_bnsh_to_bsnh",
                    perm=[0, 2, 1, 3],
                ),
                helper.make_node("Reshape", ["qkv_bsnh", "bsd_format"], ["attn_output"], "qkv_bsd"),
            ]
        )

    # Create final nodes to conclude attention
    nodes.append(
        helper.make_node(
            "MatMul",
            ["attn_output", "matmul_after_attn_initializer"],
            ["matmul_after_attn_output"],
            "matmul_after_attn",
        ),
    )
    if not fused:
        next_sln_inputs = [
            "layernorm_add_output_to_skiplayernorm",
            "add_after_attn_output",
            "layernorm_weight",
            "layernorm_bias",
        ]
        nodes.extend(
            [
                helper.make_node(
                    "Add",
                    ["add_after_attn_initializer", "matmul_after_attn_output"],
                    ["add_after_attn_output"],
                    "add_after_attn",
                ),
                helper.make_node(
                    "SkipLayerNormalization",
                    next_sln_inputs,
                    ["output_0", "", "", "output_1"],
                    "next_skiplayernorm",
                    domain="com.microsoft",
                    epsilon=epsilon,
                ),
            ]
        )
    else:
        # In the fused graph the bias Add is folded into SkipLayerNormalization
        # as its fifth input.
        next_sln_inputs = [
            "matmul_after_attn_output",
            "layernorm_add_output_to_skiplayernorm",
            "layernorm_weight",
            "layernorm_bias",
            "add_after_attn_initializer",
        ]
        nodes.append(
            helper.make_node(
                "SkipLayerNormalization",
                next_sln_inputs,
                ["output_0", "", "", "output_1"],
                "SkipLayerNorm_AddBias_0",
                domain="com.microsoft",
                epsilon=epsilon,
            )
        )

    # Create initializers
    v_weight, _ = get_tensor_and_weight("v_weight", [hidden_size, hidden_size])
    v_bias, v_bias_data = get_tensor_and_weight("v_bias", [hidden_size])
    q_weight, _ = get_tensor_and_weight("q_weight", [hidden_size, hidden_size])
    q_bias, q_bias_data = get_tensor_and_weight("q_bias", [hidden_size])
    k_weight, _ = get_tensor_and_weight("k_weight", [hidden_size, hidden_size])
    k_bias, k_bias_data = get_tensor_and_weight("k_bias", [hidden_size])

    # Combined bias for the MultiHeadAttention node: Q, K and V biases back to back.
    qkv_bias = helper.make_tensor(
        "Attention_0_qkv_bias",
        TensorProto.FLOAT,
        [3 * hidden_size],
        q_bias_data + k_bias_data + v_bias_data,
    )
    initializers = [
        float_tensor("layernorm_weight", [hidden_size]),
        float_tensor("layernorm_bias", [hidden_size]),
        float_tensor("matmul_after_attn_initializer", [hidden_size, hidden_size]),
        float_tensor("add_after_attn_initializer", [hidden_size]),
    ]

    # Add Q/K/V weight tensors as initializers
    if fused:
        initializers.extend([q_weight, k_weight, v_weight])
        # q_bias is still consumed by the retained Q-path Add node.
        initializers.extend([q_bias])
        initializers.append(qkv_bias)
        initializers.extend(
            [
                numpy_helper.from_array(np.array(num_heads, dtype="int64"), name="num_heads_int"),
                numpy_helper.from_array(np.array([head_size], dtype="int64"), name="head_size"),
                numpy_helper.from_array(np.array(1 / np.sqrt(head_size), dtype="float32"), name="q_scale"),
                numpy_helper.from_array(np.array(0, dtype="int64"), name="idx_0"),
                numpy_helper.from_array(np.array([-1], dtype="int64"), name="neg_one"),
                numpy_helper.from_array(np.array([0], dtype="int64"), name="unsqueeze_axes_input"),
                numpy_helper.from_array(np.array([0, 0, num_heads, head_size], dtype="int64"), name="q_bsnh_reshape"),
            ]
        )
    else:
        initializers.extend([q_weight, k_weight, v_weight])
        initializers.extend([q_bias, k_bias, v_bias])
        initializers.extend(
            [
                numpy_helper.from_array(np.array(num_heads, dtype="int64"), name="num_heads_int"),
                numpy_helper.from_array(np.array([num_heads], dtype="int64"), name="num_heads"),
                numpy_helper.from_array(np.array([head_size], dtype="int64"), name="head_size"),
                numpy_helper.from_array(np.array([hidden_size], dtype="int64"), name="hidden_size"),
                numpy_helper.from_array(np.array(1 / np.sqrt(head_size), dtype="float32"), name="q_scale"),
                numpy_helper.from_array(np.array(0, dtype="int64"), name="idx_0"),
                numpy_helper.from_array(np.array(1, dtype="int64"), name="idx_1"),
                numpy_helper.from_array(np.array([-1], dtype="int64"), name="neg_one"),
                numpy_helper.from_array(np.array([0], dtype="int64"), name="unsqueeze_axes_input"),
            ]
        )

    # Construct graph
    graph = helper.make_graph(nodes, "conformer_self_mha_graph", inputs, outputs, initializers, doc_string="conformer")
    # Cap the opset at 16 even when the installed onnx supports a newer one.
    opsetid = helper.make_opsetid("ai.onnx", min(onnx.defs.onnx_opset_version(), 16))
    return helper.make_model(graph, opset_imports=(opsetid,))
# Script entry point: write the unfused and the fused model to disk.
if __name__ == "__main__":
    np.random.seed(2)
    num_heads, hidden_size = 8, 512

    # (fused flag, destination path) for each model file to generate.
    targets = (
        (False, "conformer_self_mha.onnx"),
        (True, "./test_data/models/conformer/conformer_self_mha_fused.onnx"),
    )
    for fused, path in targets:
        model = create_conformer_attention(num_heads=num_heads, hidden_size=hidden_size, fused=fused)
        onnx.save(model, path)

View file

@ -0,0 +1,69 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import os
import unittest
import onnx
from conformer_model_generator import create_conformer_attention
from parity_utilities import find_transformers_source
if find_transformers_source():
from fusion_options import FusionOptions
from onnx_model import OnnxModel
from optimizer import optimize_model
else:
from onnxruntime.transformers.fusion_options import FusionOptions
from onnxruntime.transformers.onnx_model import OnnxModel
from onnxruntime.transformers.optimizer import optimize_model
class TestFusion(unittest.TestCase):
    """Check that the conformer optimizer output matches the stored fused model."""

    def verify_fusion(self, optimized_model, expected_model_filename):
        """Assert that ``optimized_model`` matches the expected model file.

        Both models are topologically sorted deterministically, then compared
        node by node. Every initializer of the expected model must also have
        the same value in the optimized model.
        """
        optimized_model.topological_sort(is_deterministic=True)

        expected_model_path = os.path.join(
            os.path.dirname(__file__), "test_data", "models", "conformer", expected_model_filename
        )
        print("Expected model path = ", expected_model_path)
        expected_model = OnnxModel(onnx.load(expected_model_path))
        expected_model.topological_sort(is_deterministic=True)

        nodes = optimized_model.model.graph.node
        expected_nodes = expected_model.model.graph.node
        # Check the counts first: zip below would silently stop at the shorter list.
        self.assertEqual(len(nodes), len(expected_nodes))
        for node, expected_node in zip(nodes, expected_nodes):
            self.assertEqual(node, expected_node)

        for expected_initializer in expected_model.model.graph.initializer:
            print("Expected initializer initial = ", expected_initializer.name)
            self.assertTrue(
                OnnxModel.has_same_value(
                    optimized_model.get_initializer(expected_initializer.name), expected_initializer
                )
            )

    def test_ct_mha_fusion(self):
        """Optimize the generated unfused model and compare it with the fused reference."""
        num_heads = 8
        hidden_size = 512
        model = create_conformer_attention(num_heads=num_heads, hidden_size=hidden_size, add_before_layernorm=False)
        output_dir = "."
        model_path = os.path.join(output_dir, "conformer_self_mha.onnx")
        onnx.save(model, model_path)
        try:
            options = FusionOptions("conformer")
            optimized_model = optimize_model(
                model_path,
                model_type="conformer",
                num_heads=num_heads,
                hidden_size=hidden_size,
                optimization_options=options,
            )
        finally:
            # Remove the temporary model even when optimization raises.
            os.remove(model_path)
        self.verify_fusion(optimized_model, "conformer_self_mha_fused.onnx")
# Run the tests when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()