mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-24 02:47:54 +00:00
[T5 optimization] script fusions and fixes (#14967)
### Description <!-- Describe your changes. --> 1. added script for t5 encoder self attention and t5 decoder self/cross attention fusions. 2. added simplified layernorm fusion for --external_data_format senario. (otherwise relying on ORT optimizer) 3. added rel_pos_bias shape inference code, modified attention/mha shape inference script. 4. reworked graph_topologic_sort() because the currently implementation is not functioning correctly. also added an option to topo-sort the graph in a deterministic way to let tests pass. note: 1. the t5-beamsearch export code is slightly modified. specifically, encoder_hidden_states(ehs) is no longer an input to the t5 decoder since the ehs is not actually used in the graph execution. 2. recent PRs do not add optimizations to t5 on cpu. 3. the fp32 model(encoder and decoder) for t5-small, t5-base and t5-large can get a parity of e-5 and the corresponding beam search models generate same results as pytorch. 4. fp16(mixed-precision) models, however, get a parity around 3e-2 and some has maximum diff a bit over 3e-2. But the beam search models still generate same results as pytorch (based on limited input data) 5. mt-5 model has a parity issue at the moment, even before any optimization. will investigate later. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> --------- Co-authored-by: Ubuntu <wy@v100-2.0cdb2e52twzevn1i4fi45bylyg.jx.internal.cloudapp.net>
This commit is contained in:
parent
59dfcfdce7
commit
0fa00429d5
8 changed files with 730 additions and 145 deletions
|
|
@ -163,6 +163,7 @@ class SymbolicShapeInference:
|
|||
"Reciprocal": self._pass_on_shape_and_type,
|
||||
"ReduceSum": self._infer_ReduceSum,
|
||||
"ReduceProd": self._infer_ReduceProd,
|
||||
"RelativePositionBias": self._infer_RelativePositionBias,
|
||||
"Reshape": self._infer_Reshape,
|
||||
"Resize": self._infer_Resize,
|
||||
"Round": self._pass_on_shape_and_type,
|
||||
|
|
@ -378,6 +379,17 @@ class SymbolicShapeInference:
|
|||
assert name in self.initializers_
|
||||
return list(self.initializers_[name].dims)
|
||||
|
||||
def _try_get_shape(self, node, idx):
|
||||
if idx > len(node.input) - 1:
|
||||
return None
|
||||
name = node.input[idx]
|
||||
if name in self.known_vi_:
|
||||
vi = self.known_vi_[name]
|
||||
return get_shape_from_value_info(vi)
|
||||
if name in self.initializers_:
|
||||
return list(self.initializers_[name].dims)
|
||||
return None
|
||||
|
||||
def _get_shape_rank(self, node, idx):
|
||||
return len(self._get_shape(node, idx))
|
||||
|
||||
|
|
@ -437,6 +449,7 @@ class SymbolicShapeInference:
|
|||
"GemmFastGelu",
|
||||
"LayerNormalization",
|
||||
"LongformerAttention",
|
||||
"RelativePositionBias",
|
||||
"SimplifiedLayerNormalization",
|
||||
"SkipLayerNormalization",
|
||||
"SkipSimplifiedLayerNormalization",
|
||||
|
|
@ -1495,6 +1508,19 @@ class SymbolicShapeInference:
|
|||
if data is not None:
|
||||
self.sympy_data_[node.output[0]] = sympy_reduce_product(data)
|
||||
|
||||
def _infer_RelativePositionBias(self, node):
|
||||
seq_len = self._try_get_value(node, 1)
|
||||
real_seq_len = self._try_get_value(node, 2)
|
||||
if seq_len is None or real_seq_len is None:
|
||||
return
|
||||
num_heads = self._get_sympy_shape(node, 0)[1]
|
||||
|
||||
new_shape = [1, num_heads, str(seq_len), str(real_seq_len)]
|
||||
|
||||
output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
|
||||
vi = self.known_vi_[node.output[0]]
|
||||
vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, new_shape))
|
||||
|
||||
def _infer_Reshape(self, node):
|
||||
shape_value = self._try_get_value(node, 1)
|
||||
vi = self.known_vi_[node.output[0]]
|
||||
|
|
@ -2030,14 +2056,18 @@ class SymbolicShapeInference:
|
|||
|
||||
def _infer_Attention(self, node):
|
||||
shape = self._get_shape(node, 0)
|
||||
shape_bias = self._get_shape(node, 2)
|
||||
if shape and len(shape) == 3 and shape_bias and len(shape_bias) == 1:
|
||||
shape_weights = self._get_shape(node, 1)
|
||||
shape_bias = self._try_get_shape(node, 2)
|
||||
if shape_bias is not None:
|
||||
assert len(shape_bias) == 1
|
||||
tripled_hidden_size = shape_bias[0] if shape_bias is not None else shape_weights[1]
|
||||
if shape and len(shape) == 3:
|
||||
qkv_hidden_sizes_attr = get_attribute(node, "qkv_hidden_sizes")
|
||||
if qkv_hidden_sizes_attr is not None:
|
||||
assert len(qkv_hidden_sizes_attr) == 3
|
||||
shape[2] = int(qkv_hidden_sizes_attr[2])
|
||||
elif isinstance(shape_bias[0], int):
|
||||
shape[2] = int(shape_bias[0] / 3)
|
||||
elif isinstance(tripled_hidden_size, int):
|
||||
shape[2] = int(tripled_hidden_size / 3)
|
||||
output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
|
||||
vi = self.known_vi_[node.output[0]]
|
||||
vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, shape))
|
||||
|
|
@ -2068,8 +2098,8 @@ class SymbolicShapeInference:
|
|||
# Output 0 has shape (batch_size, sequence_length, v_hidden_size)
|
||||
# Q, K and V without packing:
|
||||
# Input 0 (query) has shape (batch_size, sequence_length, hidden_size)
|
||||
# Input 1 (key) has shape (batch_size, kv_sequence_length, hidden_size)
|
||||
# Input 2 (value) has shape (batch_size, kv_sequence_length, v_hidden_size)
|
||||
# Input 1 (key) has shape (batch_size, kv_sequence_length, hidden_size) or (batch_size, num_heads, kv_sequence_length, head_size)
|
||||
# Input 2 (value) has shape (batch_size, kv_sequence_length, v_hidden_size) or (batch_size, num_heads, kv_sequence_length, head_size)
|
||||
# Packed KV:
|
||||
# Input 0 (query) has shape (batch_size, sequence_length, hidden_size)
|
||||
# Input 1 (batch_size, kv_sequence_length, num_heads, 2, head_size)
|
||||
|
|
@ -2080,29 +2110,65 @@ class SymbolicShapeInference:
|
|||
# Input 2 nullptr
|
||||
|
||||
query_shape = self._get_shape(node, 0)
|
||||
total_sequence_length = None
|
||||
output_dtype = None
|
||||
if query_shape is not None:
|
||||
if len(query_shape) == 3:
|
||||
key_shape = self._get_shape(node, 1)
|
||||
key_shape = self._try_get_shape(node, 1)
|
||||
# By default, hidden size is same for Q/K/V. Only need check v_hidden_size when value is provided.
|
||||
output_shape = query_shape
|
||||
if key_shape and len(key_shape) == 3:
|
||||
value_shape = self._get_shape(node, 2)
|
||||
if value_shape and len(value_shape) == 3:
|
||||
if key_shape is not None and len(key_shape) == 3:
|
||||
value_shape = self._try_get_shape(node, 2)
|
||||
if value_shape is not None and len(value_shape) == 3:
|
||||
output_shape[2] = value_shape[2]
|
||||
total_sequence_length = key_shape[1]
|
||||
|
||||
output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
|
||||
vi = self.known_vi_[node.output[0]]
|
||||
vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape))
|
||||
|
||||
elif len(query_shape) == 5:
|
||||
if isinstance(query_shape[2], int) and isinstance(query_shape[4], int):
|
||||
output_shape = [query_shape[0], query_shape[1], query_shape[2] * query_shape[4]]
|
||||
else:
|
||||
output_shape = [query_shape[0], query_shape[1], f"{query_shape[2]}*{query_shape[4]}"]
|
||||
|
||||
total_sequence_length = query_shape[1]
|
||||
|
||||
output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
|
||||
vi = self.known_vi_[node.output[0]]
|
||||
vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape))
|
||||
|
||||
if len(node.output) > 1:
|
||||
batch_size = query_shape[0]
|
||||
num_heads = get_attribute(node, "num_heads")
|
||||
|
||||
head_size = None
|
||||
if len(query_shape) == 3:
|
||||
head_size = (
|
||||
int(query_shape[2] / num_heads)
|
||||
if isinstance(query_shape[2], int)
|
||||
else f"{query_shape[2]}/{num_heads}"
|
||||
)
|
||||
else:
|
||||
head_size = query_shape[4]
|
||||
|
||||
past_shape = self._try_get_shape(node, 6)
|
||||
|
||||
if past_shape is not None:
|
||||
if isinstance(past_shape[2], int) and isinstance(total_sequence_length, int):
|
||||
total_sequence_length = past_shape[2] + total_sequence_length
|
||||
else:
|
||||
total_sequence_length = f"{past_shape[2]}+{total_sequence_length}"
|
||||
|
||||
present_shape = [batch_size, num_heads, total_sequence_length, head_size]
|
||||
|
||||
assert output_dtype is not None
|
||||
vi = self.known_vi_[node.output[1]]
|
||||
vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, present_shape))
|
||||
vi = self.known_vi_[node.output[2]]
|
||||
vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, present_shape))
|
||||
|
||||
def _infer_FastGelu(self, node):
|
||||
self._propagate_shape_and_type(node)
|
||||
|
||||
|
|
@ -2140,8 +2206,6 @@ class SymbolicShapeInference:
|
|||
|
||||
def _infer_SkipLayerNormalization(self, node):
|
||||
self._propagate_shape_and_type(node)
|
||||
if len(node.output) > 3:
|
||||
self._propagate_shape_and_type(node, 0, 3)
|
||||
|
||||
# If the SkipLayerNormalization node contains the optional
|
||||
# output for inference, infer the shape and type for it too
|
||||
|
|
@ -2348,7 +2412,9 @@ class SymbolicShapeInference:
|
|||
for i_o in range(len(node.output)):
|
||||
# Special case: We do not care about the training related
|
||||
# outputs of SkipLayerNormalization
|
||||
if node.op_type == "SkipLayerNormalization" and i_o in [1, 2]:
|
||||
if (
|
||||
node.op_type == "SkipLayerNormalization" or node.op_type == "SkipSimplifiedLayerNormalization"
|
||||
) and i_o in [1, 2]:
|
||||
continue
|
||||
|
||||
vi = self.known_vi_[node.output[i_o]]
|
||||
|
|
|
|||
|
|
@ -706,13 +706,12 @@ def verify_t5_decoder_subgraph(graph: onnx.GraphProto, precision: Precision):
|
|||
float_type = TensorProto.FLOAT16 if is_float16 else TensorProto.FLOAT
|
||||
|
||||
input_count = len(graph.input)
|
||||
layer_count = (input_count - 3) // 4
|
||||
layer_count = (input_count - 2) // 4
|
||||
assert layer_count >= 1
|
||||
|
||||
# Expect inputs:
|
||||
# input_ids: int32 (B, 1)
|
||||
# encoder_attention_mask: int32 (B, encode_sequence_length)
|
||||
# encoder_hidden_states: (B, encode_sequence_length, encoder_hidden_size)
|
||||
|
||||
# past_key_self_0: (B, num_heads, past_decode_sequence_length, head_size)
|
||||
# past_value_self_0: (B, num_heads, past_decode_sequence_length, head_size)
|
||||
|
|
@ -723,7 +722,7 @@ def verify_t5_decoder_subgraph(graph: onnx.GraphProto, precision: Precision):
|
|||
# ... (for each cross attention layer)
|
||||
|
||||
# TODO: encoder_hidden_states is optional
|
||||
expected_inputs = ["input_ids", "encoder_attention_mask", "encoder_hidden_states"]
|
||||
expected_inputs = ["input_ids", "encoder_attention_mask"]
|
||||
for i in range(layer_count):
|
||||
expected_inputs.append(f"past_key_self_{i}")
|
||||
expected_inputs.append(f"past_value_self_{i}")
|
||||
|
|
@ -1406,6 +1405,43 @@ def generate_gpt2_init_decoder(
|
|||
return True
|
||||
|
||||
|
||||
def make_dim_proto_numeric_t5(model, config):
|
||||
"""Make dim_proto numeric.
|
||||
|
||||
Args:
|
||||
model: T5 encoder and decoder model.
|
||||
config: T5 config.
|
||||
"""
|
||||
sequence_length = str(1)
|
||||
num_heads = str(config.num_heads)
|
||||
hidden_size = str(config.d_model)
|
||||
head_size = str(config.d_kv)
|
||||
|
||||
for tensor in model.graph.output:
|
||||
for dim_proto in tensor.type.tensor_type.shape.dim:
|
||||
if dim_proto.HasField("dim_param") and dim_proto.dim_param in [
|
||||
sequence_length,
|
||||
num_heads,
|
||||
hidden_size,
|
||||
head_size,
|
||||
]:
|
||||
dim_value = int(dim_proto.dim_param)
|
||||
dim_proto.Clear()
|
||||
dim_proto.dim_value = dim_value
|
||||
|
||||
for tensor in model.graph.input:
|
||||
for dim_proto in tensor.type.tensor_type.shape.dim:
|
||||
if dim_proto.HasField("dim_param") and dim_proto.dim_param in [
|
||||
sequence_length,
|
||||
num_heads,
|
||||
hidden_size,
|
||||
head_size,
|
||||
]:
|
||||
dim_value = int(dim_proto.dim_param)
|
||||
dim_proto.Clear()
|
||||
dim_proto.dim_value = dim_value
|
||||
|
||||
|
||||
def convert_generation_model(args: argparse.Namespace, generation_type: GenerationType = GenerationType.BEAMSEARCH):
|
||||
"""Convert model according to command line arguments.
|
||||
|
||||
|
|
@ -1686,6 +1722,9 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati
|
|||
# )
|
||||
# initializers.extend(moved_initializers)
|
||||
|
||||
make_dim_proto_numeric_t5(encoder_model, config)
|
||||
make_dim_proto_numeric_t5(decoder_model, config)
|
||||
|
||||
node.attribute.extend(
|
||||
[
|
||||
onnx.helper.make_attribute("encoder", encoder_model.graph),
|
||||
|
|
|
|||
|
|
@ -2,11 +2,8 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
from enum import Enum
|
||||
from logging import getLogger
|
||||
from os import name
|
||||
from sys import path
|
||||
from typing import Tuple, Union
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
from fusion_base import Fusion
|
||||
|
|
@ -14,7 +11,6 @@ from fusion_options import AttentionMaskFormat
|
|||
from fusion_utils import FusionUtils, NumpyHelper
|
||||
from onnx import NodeProto, TensorProto, helper, numpy_helper
|
||||
from onnx_model import OnnxModel
|
||||
from shape_infer_helper import SymbolicShapeInferenceHelper, get_shape_from_type_proto
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
|
@ -94,9 +90,10 @@ class FusionAttention(Fusion):
|
|||
num_heads: int,
|
||||
attention_mask: AttentionMask,
|
||||
use_multi_head_attention: bool = False,
|
||||
search_op_types: List[str] = ["SkipLayerNormalization", "LayerNormalization"],
|
||||
):
|
||||
attention_op_name = "MultiHeadAttention" if use_multi_head_attention else "Attention"
|
||||
super().__init__(model, attention_op_name, ["SkipLayerNormalization", "LayerNormalization"])
|
||||
super().__init__(model, attention_op_name, search_op_types)
|
||||
self.hidden_size = hidden_size
|
||||
self.num_heads = num_heads
|
||||
self.attention_mask = attention_mask
|
||||
|
|
@ -211,6 +208,7 @@ class FusionAttention(Fusion):
|
|||
input: str,
|
||||
output: str,
|
||||
add_qk_str: str,
|
||||
scale: Optional[float] = None,
|
||||
) -> Union[NodeProto, None]:
|
||||
"""Create an Attention node.
|
||||
|
||||
|
|
@ -236,12 +234,22 @@ class FusionAttention(Fusion):
|
|||
logger.debug(f"input hidden size {hidden_size} is not a multiple of num of heads {num_heads}")
|
||||
return None
|
||||
|
||||
has_bias = True
|
||||
if q_add is None and k_add is None and v_add is None:
|
||||
has_bias = False
|
||||
|
||||
q_weight = self.model.get_initializer(q_matmul.input[1])
|
||||
k_weight = self.model.get_initializer(k_matmul.input[1])
|
||||
v_weight = self.model.get_initializer(v_matmul.input[1])
|
||||
q_bias = self.model.get_initializer(q_add.input[1]) or self.model.get_initializer(q_add.input[0])
|
||||
k_bias = self.model.get_initializer(k_add.input[1]) or self.model.get_initializer(k_add.input[0])
|
||||
v_bias = self.model.get_initializer(v_add.input[1]) or self.model.get_initializer(v_add.input[0])
|
||||
|
||||
q_bias, k_bias, v_bias = None, None, None
|
||||
if has_bias:
|
||||
q_bias = self.model.get_initializer(q_add.input[1]) or self.model.get_initializer(q_add.input[0])
|
||||
k_bias = self.model.get_initializer(k_add.input[1]) or self.model.get_initializer(k_add.input[0])
|
||||
v_bias = self.model.get_initializer(v_add.input[1]) or self.model.get_initializer(v_add.input[0])
|
||||
|
||||
if not (k_weight and v_weight and q_bias and k_bias):
|
||||
return None
|
||||
|
||||
if q_weight is None:
|
||||
print(
|
||||
|
|
@ -249,8 +257,6 @@ class FusionAttention(Fusion):
|
|||
"Please set do_constant_folding=True in torch.onnx.export to unblock attention fusion"
|
||||
)
|
||||
return None
|
||||
if not (k_weight and v_weight and q_bias and k_bias):
|
||||
return None
|
||||
|
||||
qw = NumpyHelper.to_array(q_weight)
|
||||
kw = NumpyHelper.to_array(k_weight)
|
||||
|
|
@ -290,24 +296,25 @@ class FusionAttention(Fusion):
|
|||
qkv_weight = np.stack((qw, kw, vw), axis=1)
|
||||
qkv_weight_dim = 3 * qw_out_size
|
||||
|
||||
qb = NumpyHelper.to_array(q_bias)
|
||||
kb = NumpyHelper.to_array(k_bias)
|
||||
vb = NumpyHelper.to_array(v_bias)
|
||||
if has_bias:
|
||||
qb = NumpyHelper.to_array(q_bias)
|
||||
kb = NumpyHelper.to_array(k_bias)
|
||||
vb = NumpyHelper.to_array(v_bias)
|
||||
|
||||
q_bias_shape = np.prod(qb.shape)
|
||||
k_bias_shape = np.prod(kb.shape)
|
||||
v_bias_shape = np.prod(vb.shape)
|
||||
q_bias_shape = np.prod(qb.shape)
|
||||
k_bias_shape = np.prod(kb.shape)
|
||||
v_bias_shape = np.prod(vb.shape)
|
||||
|
||||
assert q_bias_shape == k_bias_shape == qw_out_size
|
||||
assert v_bias_shape == vw_out_size
|
||||
assert q_bias_shape == k_bias_shape == qw_out_size
|
||||
assert v_bias_shape == vw_out_size
|
||||
|
||||
qkv_bias_dim = 0
|
||||
if is_qkv_diff_dims:
|
||||
qkv_bias = np.concatenate((qb, kb, vb), axis=0)
|
||||
qkv_bias_dim = q_bias_shape + k_bias_shape + v_bias_shape
|
||||
else:
|
||||
qkv_bias = np.stack((qb, kb, vb), axis=0)
|
||||
qkv_bias_dim = 3 * q_bias_shape
|
||||
qkv_bias_dim = 0
|
||||
if is_qkv_diff_dims:
|
||||
qkv_bias = np.concatenate((qb, kb, vb), axis=0)
|
||||
qkv_bias_dim = q_bias_shape + k_bias_shape + v_bias_shape
|
||||
else:
|
||||
qkv_bias = np.stack((qb, kb, vb), axis=0)
|
||||
qkv_bias_dim = 3 * q_bias_shape
|
||||
|
||||
attention_node_name = self.model.create_node_name("Attention")
|
||||
|
||||
|
|
@ -324,15 +331,17 @@ class FusionAttention(Fusion):
|
|||
weight.CopyFrom(numpy_helper.from_array(NumpyHelper.to_array(weight).astype(np.float16), weight.name))
|
||||
self.model.add_initializer(weight, self.this_graph_name)
|
||||
|
||||
bias = helper.make_tensor(
|
||||
name=attention_node_name + "_qkv_bias",
|
||||
data_type=TensorProto.FLOAT,
|
||||
dims=[qkv_bias_dim],
|
||||
vals=qkv_bias.flatten().tolist(),
|
||||
)
|
||||
if q_bias.data_type == 10:
|
||||
bias.CopyFrom(numpy_helper.from_array(NumpyHelper.to_array(bias).astype(np.float16), bias.name))
|
||||
self.model.add_initializer(bias, self.this_graph_name)
|
||||
bias = None
|
||||
if has_bias:
|
||||
bias = helper.make_tensor(
|
||||
name=attention_node_name + "_qkv_bias",
|
||||
data_type=TensorProto.FLOAT,
|
||||
dims=[qkv_bias_dim],
|
||||
vals=qkv_bias.flatten().tolist(),
|
||||
)
|
||||
if q_bias.data_type == 10:
|
||||
bias.CopyFrom(numpy_helper.from_array(NumpyHelper.to_array(bias).astype(np.float16), bias.name))
|
||||
self.model.add_initializer(bias, self.this_graph_name)
|
||||
|
||||
# For MultiHeadAttention operator, use separated inputs for query, key and value, and no weights.
|
||||
if self.use_multi_head_attention:
|
||||
|
|
@ -359,7 +368,7 @@ class FusionAttention(Fusion):
|
|||
attention_inputs = [
|
||||
input,
|
||||
attention_node_name + "_qkv_weight",
|
||||
attention_node_name + "_qkv_bias",
|
||||
attention_node_name + "_qkv_bias" if has_bias else "",
|
||||
]
|
||||
if mask_index is not None:
|
||||
attention_inputs.append(mask_index)
|
||||
|
|
@ -379,6 +388,9 @@ class FusionAttention(Fusion):
|
|||
attention_node.domain = "com.microsoft"
|
||||
attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)])
|
||||
|
||||
if scale is not None:
|
||||
attention_node.attribute.extend([helper.make_attribute("scale", scale)])
|
||||
|
||||
if is_qkv_diff_dims:
|
||||
attention_node.attribute.extend(
|
||||
[helper.make_attribute("qkv_hidden_sizes", [qw_out_size, kw_out_size, vw_out_size])]
|
||||
|
|
|
|||
|
|
@ -92,14 +92,16 @@ class T5Decoder(torch.nn.Module):
|
|||
self.lm_head = lm_head
|
||||
self.config = config
|
||||
|
||||
def forward(self, decoder_input_ids, encoder_attention_mask, encoder_hidden_states, *past):
|
||||
def forward(self, decoder_input_ids, encoder_attention_mask, *past):
|
||||
|
||||
past_key_values = PastKeyValuesHelper.group_by_layer(past, self.config.num_layers)
|
||||
|
||||
# This is a hack since only the third dimension of encoder_hidden_states is used here
|
||||
dummy_encoder_hidden_states = encoder_attention_mask.unsqueeze(2)
|
||||
decoder_outputs = self.decoder(
|
||||
input_ids=decoder_input_ids,
|
||||
past_key_values=past_key_values,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_hidden_states=dummy_encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
use_cache=True,
|
||||
return_dict=True,
|
||||
|
|
@ -122,12 +124,10 @@ class T5DecoderInputs:
|
|||
self,
|
||||
decoder_input_ids,
|
||||
encoder_attention_mask,
|
||||
encoder_hidden_states,
|
||||
past_key_values=None,
|
||||
):
|
||||
self.decoder_input_ids: torch.LongTensor = decoder_input_ids
|
||||
self.encoder_attention_mask: torch.LongTensor = encoder_attention_mask
|
||||
self.encoder_hidden_states: Union[torch.FloatTensor, torch.HalfTensor] = encoder_hidden_states
|
||||
self.past_key_values: Union[List[torch.FloatTensor], List[torch.HalfTensor], None] = past_key_values
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -181,13 +181,6 @@ class T5DecoderInputs:
|
|||
)
|
||||
|
||||
float_type = torch.float16 if float16 else torch.float32
|
||||
encoder_hidden_state = torch.rand(
|
||||
batch_size,
|
||||
encode_sequence_length,
|
||||
hidden_size,
|
||||
dtype=float_type,
|
||||
device=device,
|
||||
)
|
||||
|
||||
if past_decode_sequence_length > 0:
|
||||
self_attention_past_shape = [
|
||||
|
|
@ -212,25 +205,22 @@ class T5DecoderInputs:
|
|||
else:
|
||||
past = None
|
||||
|
||||
return T5DecoderInputs(decoder_input_ids, encoder_inputs.attention_mask, encoder_hidden_state, past)
|
||||
return T5DecoderInputs(decoder_input_ids, encoder_inputs.attention_mask, past)
|
||||
|
||||
def to_list(self) -> List:
|
||||
input_list = [
|
||||
self.decoder_input_ids,
|
||||
self.encoder_attention_mask,
|
||||
self.encoder_hidden_states,
|
||||
]
|
||||
if self.past_key_values:
|
||||
input_list.extend(self.past_key_values)
|
||||
return input_list
|
||||
|
||||
def to_fp32(self):
|
||||
encoder_hidden_state = self.encoder_hidden_states.to(dtype=torch.float32)
|
||||
past = [p.to(dtype=torch.float32) for p in self.past_key_values] if self.past_key_values else None
|
||||
return T5DecoderInputs(
|
||||
self.decoder_input_ids.clone(),
|
||||
self.encoder_attention_mask.clone(),
|
||||
encoder_hidden_state,
|
||||
past,
|
||||
)
|
||||
|
||||
|
|
@ -278,7 +268,6 @@ class T5DecoderHelper:
|
|||
# Shape of input tensors (sequence_length==1):
|
||||
# input_ids: (batch_size, sequence_length)
|
||||
# encoder_attention_mask: (batch_size, encode_sequence_length)
|
||||
# encoder_hidden_states: (batch_size, encode_sequence_length, hidden_size)
|
||||
# past_self_*: (batch_size, num_heads, past_decode_sequence_length, head_size)
|
||||
# past_cross_*: (batch_size, num_heads, encode_sequence_length, head_size)
|
||||
|
||||
|
|
@ -289,7 +278,6 @@ class T5DecoderHelper:
|
|||
|
||||
input_names = ["input_ids"]
|
||||
input_names.append("encoder_attention_mask")
|
||||
input_names.append("encoder_hidden_states")
|
||||
input_names.extend(input_past_names)
|
||||
|
||||
dynamic_axes = {
|
||||
|
|
@ -362,7 +350,6 @@ class T5DecoderHelper:
|
|||
ort_inputs = {
|
||||
"input_ids": numpy.ascontiguousarray(inputs.decoder_input_ids.cpu().numpy()),
|
||||
"encoder_attention_mask": numpy.ascontiguousarray(inputs.encoder_attention_mask.cpu().numpy()),
|
||||
"encoder_hidden_states": numpy.ascontiguousarray(inputs.encoder_hidden_states.cpu().numpy()),
|
||||
}
|
||||
|
||||
if inputs.past_key_values:
|
||||
|
|
@ -384,7 +371,7 @@ class T5DecoderHelper:
|
|||
max_cases: int = 4,
|
||||
):
|
||||
"""Compare the result from PyTorch and OnnxRuntime to verify the ONNX model is good."""
|
||||
float16: bool = TypeHelper.get_input_type(ort_session, "encoder_hidden_states") == "tensor(float16)"
|
||||
float16: bool = TypeHelper.get_input_type(ort_session, "past_key_self_0") == "tensor(float16)"
|
||||
|
||||
test_cases = [(4, 11, 3), (1, 2, 5), (3, 1, 1), (8, 5, 2)]
|
||||
test_cases_max_diff = []
|
||||
|
|
|
|||
|
|
@ -151,21 +151,17 @@ class T5Helper:
|
|||
def auto_mixed_precision(
|
||||
onnx_model: OnnxModel,
|
||||
op_block_list: List[str] = [
|
||||
"Pow",
|
||||
"ReduceMean",
|
||||
"Add",
|
||||
"Sqrt",
|
||||
"Div",
|
||||
"Mul",
|
||||
"Softmax",
|
||||
"SimplifiedLayerNormalization",
|
||||
"SkipSimplifiedLayerNormalization",
|
||||
"Relu",
|
||||
"Add",
|
||||
],
|
||||
):
|
||||
"""Convert model to mixed precision.
|
||||
It detects whether original model has fp16 precision weights, and set parameters for float16 conversion automatically.
|
||||
Args:
|
||||
onnx_model (OnnxModel): optimized ONNX model
|
||||
op_block_list (List[str], optional): . Defaults to ["Pow", "ReduceMean", "Add", "Sqrt", "Div", "Mul", "Softmax", "Relu"]
|
||||
op_block_list (List[str], optional): . Defaults to ["SimplifiedLayerNormalization", "SkipSimplifiedLayerNormalization", "Relu", "Add"]
|
||||
Returns:
|
||||
parameters(dict): a dictionary of parameters used in float16 conversion
|
||||
"""
|
||||
|
|
@ -235,8 +231,7 @@ class T5Helper:
|
|||
from fusion_options import FusionOptions
|
||||
|
||||
optimization_options = None
|
||||
if not use_gpu:
|
||||
# Currently there is no SkipSimplifiedLayerNorm cpu kernel
|
||||
if is_float16:
|
||||
optimization_options = FusionOptions("t5")
|
||||
optimization_options.enable_skip_layer_norm = False
|
||||
|
||||
|
|
@ -245,10 +240,12 @@ class T5Helper:
|
|||
model_type="t5",
|
||||
num_heads=num_attention_heads,
|
||||
hidden_size=hidden_size,
|
||||
opt_level=2 if not is_float16 and not use_external_data_format else 0,
|
||||
opt_level=2 if not use_external_data_format else 0,
|
||||
optimization_options=optimization_options,
|
||||
use_gpu=False,
|
||||
only_onnxruntime=not use_gpu,
|
||||
)
|
||||
|
||||
if is_float16:
|
||||
if auto_mixed_precision:
|
||||
T5Helper.auto_mixed_precision(m)
|
||||
|
|
|
|||
|
|
@ -615,7 +615,7 @@ class OnnxModel:
|
|||
if use_symbolic_shape_infer:
|
||||
# Use symbolic shape inference since custom operators (like Gelu, SkipLayerNormalization etc)
|
||||
# are not recognized by onnx shape inference.
|
||||
shape_infer_helper = SymbolicShapeInferenceHelper(model)
|
||||
shape_infer_helper = SymbolicShapeInferenceHelper(model, verbose=0)
|
||||
model = shape_infer_helper.infer_shapes(model, auto_merge=True, guess_output_rank=False)
|
||||
|
||||
parameters = {"disable_shape_infer": use_symbolic_shape_infer}
|
||||
|
|
@ -876,66 +876,64 @@ class OnnxModel:
|
|||
return True
|
||||
|
||||
@staticmethod
|
||||
def graph_topological_sort(graph):
|
||||
deps_count = [0] * len(graph.node) # dependency count of each node
|
||||
deps_to_nodes = {} # input to node indice
|
||||
def graph_topological_sort(graph, is_deterministic=False):
|
||||
deps_set = set() # dependency set of all node
|
||||
sorted_node_set = set() # sorted node set
|
||||
sorted_nodes = [] # initialize sorted_nodes
|
||||
for node_idx, node in enumerate(graph.node):
|
||||
# CANNOT use len(node.input) directly because input can be optional
|
||||
deps_count[node_idx] = sum(1 for _ in node.input if _)
|
||||
if deps_count[node_idx] == 0: # Constant doesn't depend on any inputs
|
||||
sorted_nodes.append(graph.node[node_idx])
|
||||
continue
|
||||
|
||||
for input_name in node.input:
|
||||
if input_name not in deps_to_nodes:
|
||||
deps_to_nodes[input_name] = [node_idx]
|
||||
else:
|
||||
deps_to_nodes[input_name].append(node_idx)
|
||||
|
||||
# Note: this logic only applies to top level graph since a sub graph could use intializer from parent graph
|
||||
initializer_names = [init.name for init in graph.initializer]
|
||||
graph_input_names = [input.name for input in graph.input]
|
||||
input_names = initializer_names + graph_input_names
|
||||
input_names.sort()
|
||||
prev_input_name = None
|
||||
|
||||
if is_deterministic:
|
||||
input_names.sort()
|
||||
|
||||
for input_name in input_names:
|
||||
if prev_input_name == input_name:
|
||||
continue
|
||||
deps_set.add(input_name)
|
||||
|
||||
prev_input_name = input_name
|
||||
if input_name in deps_to_nodes:
|
||||
for node_idx in deps_to_nodes[input_name]:
|
||||
deps_count[node_idx] = deps_count[node_idx] - 1
|
||||
if deps_count[node_idx] == 0:
|
||||
sorted_nodes.append(graph.node[node_idx])
|
||||
sorted_node_set_len = -1
|
||||
graph_nodes = graph.node if not is_deterministic else sorted(graph.node, key=lambda x: x.name)
|
||||
last_node_name = None
|
||||
while len(sorted_node_set) != len(graph_nodes):
|
||||
if len(sorted_node_set) == sorted_node_set_len:
|
||||
break
|
||||
sorted_node_set_len = len(sorted_node_set)
|
||||
for node_idx, node in enumerate(graph_nodes):
|
||||
if node_idx in sorted_node_set:
|
||||
continue
|
||||
input_count = sum(1 for _ in node.input if _)
|
||||
if input_count == 0:
|
||||
sorted_nodes.append(node)
|
||||
sorted_node_set.add(node_idx)
|
||||
for output in node.output:
|
||||
deps_set.add(output)
|
||||
continue
|
||||
failed = False
|
||||
for input_name in node.input:
|
||||
if input_name != "" and input_name not in deps_set:
|
||||
failed = True
|
||||
last_node_name = node.name
|
||||
if not failed:
|
||||
sorted_nodes.append(node)
|
||||
sorted_node_set.add(node_idx)
|
||||
for output in node.output:
|
||||
deps_set.add(output)
|
||||
else:
|
||||
continue
|
||||
|
||||
start = 0
|
||||
end = len(sorted_nodes)
|
||||
|
||||
while start < end:
|
||||
for output in sorted_nodes[start].output:
|
||||
if output in deps_to_nodes:
|
||||
for node_idx in deps_to_nodes[output]:
|
||||
deps_count[node_idx] = deps_count[node_idx] - 1
|
||||
if deps_count[node_idx] == 0:
|
||||
sorted_nodes.append(graph.node[node_idx])
|
||||
end = end + 1
|
||||
start = start + 1
|
||||
|
||||
if end != len(graph.node):
|
||||
if len(sorted_node_set) != len(graph.node):
|
||||
raise RuntimeError(
|
||||
f"Graph is not a DAG: end={end}, len(graph.node)={len(graph.node)}, graph.node[end]={graph.node[end]}"
|
||||
f"Graph is not a DAG: len(sorted_node_set)={len(sorted_node_set)}, len(graph.node)={len(graph.node)}, failed at node {last_node_name}"
|
||||
)
|
||||
|
||||
graph.ClearField("node")
|
||||
graph.node.extend(sorted_nodes)
|
||||
|
||||
def topological_sort(self):
|
||||
def topological_sort(self, is_deterministic=False):
|
||||
# TODO: support graph_topological_sort() in subgraphs
|
||||
# for graph in self.graphs():
|
||||
# self.graph_topological_sort(graph)
|
||||
OnnxModel.graph_topological_sort(self.model.graph)
|
||||
OnnxModel.graph_topological_sort(self.model.graph, is_deterministic)
|
||||
|
||||
@staticmethod
|
||||
def save(
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
import logging
|
||||
from typing import Union
|
||||
from typing import Dict, Union
|
||||
|
||||
import numpy as np
|
||||
from fusion_attention import AttentionMask, FusionAttention
|
||||
|
|
@ -16,7 +16,7 @@ from onnx_model_bert import BertOnnxModel
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# TODO: Support decoder self/cross attention fusion and encoder self attention fusion
|
||||
|
||||
class FusionT5Attention(FusionAttention):
|
||||
"""
|
||||
Fuse T5 Attention subgraph into one Attention node.
|
||||
|
|
@ -29,25 +29,460 @@ class FusionT5Attention(FusionAttention):
|
|||
num_heads: int,
|
||||
attention_mask: AttentionMask,
|
||||
):
|
||||
super().__init__(model, hidden_size, num_heads, attention_mask)
|
||||
super().__init__(
|
||||
model,
|
||||
hidden_size,
|
||||
num_heads,
|
||||
attention_mask,
|
||||
use_multi_head_attention=False,
|
||||
search_op_types=["SkipSimplifiedLayerNormalization", "Add"],
|
||||
)
|
||||
self.static_kv = 1
|
||||
|
||||
def create_attention_node(
|
||||
def create_mha_node(
|
||||
self,
|
||||
query: str,
|
||||
key: str,
|
||||
value: str,
|
||||
mask_index: str,
|
||||
matmul: NodeProto,
|
||||
add: NodeProto,
|
||||
res_pos_bias: str,
|
||||
past_key: str,
|
||||
past_value: str,
|
||||
output: str,
|
||||
present_key: str,
|
||||
present_value: str,
|
||||
num_heads: int,
|
||||
hidden_size: int,
|
||||
input: str,
|
||||
output: str,
|
||||
add_qk_str: str,
|
||||
) -> Union[NodeProto, None]:
|
||||
# Not implemented yet
|
||||
return None
|
||||
|
||||
assert num_heads > 0
|
||||
|
||||
if hidden_size > 0 and (hidden_size % num_heads) != 0:
|
||||
logger.debug(f"input hidden size {hidden_size} is not a multiple of num of heads {num_heads}")
|
||||
return None
|
||||
|
||||
attention_node_name = self.model.create_node_name("MultiHeadAttention")
|
||||
attention_inputs = [
|
||||
query,
|
||||
"" if key is None else key, # key
|
||||
"" if value is None else value, # value
|
||||
"", # bias
|
||||
]
|
||||
if mask_index is not None:
|
||||
attention_inputs.append(mask_index)
|
||||
else:
|
||||
attention_inputs.append("")
|
||||
|
||||
if res_pos_bias is not None:
|
||||
attention_inputs.append(res_pos_bias)
|
||||
else:
|
||||
attention_inputs.append("")
|
||||
|
||||
if past_key is not None:
|
||||
assert past_value is not None
|
||||
attention_inputs.append(past_key)
|
||||
attention_inputs.append(past_value)
|
||||
|
||||
attention_outputs = [output]
|
||||
if present_key is not None:
|
||||
assert present_value is not None
|
||||
attention_outputs.append(present_key)
|
||||
attention_outputs.append(present_value)
|
||||
|
||||
attention_node = helper.make_node(
|
||||
"MultiHeadAttention",
|
||||
inputs=attention_inputs,
|
||||
outputs=attention_outputs,
|
||||
name=attention_node_name,
|
||||
)
|
||||
|
||||
attention_node.domain = "com.microsoft"
|
||||
attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)])
|
||||
attention_node.attribute.extend([helper.make_attribute("scale", 1.0)])
|
||||
if self.mask_filter_value is not None:
|
||||
attention_node.attribute.extend([helper.make_attribute("mask_filter_value", float(self.mask_filter_value))])
|
||||
|
||||
self.increase_counter("MultiHeadAttention")
|
||||
return attention_node
|
||||
|
||||
def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
|
||||
# Not implemented yet
|
||||
return
|
||||
self.fuse_t5_encoder(normalize_node, input_name_to_nodes, output_name_to_node)
|
||||
self.fuse_t5_decoder(normalize_node, input_name_to_nodes, output_name_to_node)
|
||||
|
||||
def fuse_t5_encoder(self, normalize_node, input_name_to_nodes, output_name_to_node):
|
||||
if normalize_node.op_type != "SkipSimplifiedLayerNormalization" and normalize_node.op_type != "Add":
|
||||
return
|
||||
|
||||
qkv_nodes = self.model.match_parent_path(
|
||||
normalize_node,
|
||||
["MatMul", "Reshape", "Transpose", "MatMul"],
|
||||
[1, 0, 0, 0],
|
||||
)
|
||||
if qkv_nodes is None:
|
||||
return
|
||||
|
||||
_, reshape_qkv, transpose_qkv, matmul_qkv = qkv_nodes
|
||||
|
||||
qkv_shape_nodes = self.model.match_parent_path(
|
||||
reshape_qkv,
|
||||
["Concat", "Unsqueeze", "Gather", "Shape"],
|
||||
[1, 0, 0, 0],
|
||||
)
|
||||
if qkv_shape_nodes is None:
|
||||
return
|
||||
input_shape_node = qkv_shape_nodes[-1]
|
||||
|
||||
v_nodes = self.model.match_parent_path(
|
||||
matmul_qkv,
|
||||
["Transpose", "Reshape", "MatMul"],
|
||||
[1, 0, 0],
|
||||
)
|
||||
if v_nodes is None:
|
||||
return
|
||||
_, reshape_v, matmul_v = v_nodes
|
||||
# todo: check reshape_v parent nodes
|
||||
|
||||
qk_nodes = self.model.match_parent_path(
|
||||
matmul_qkv,
|
||||
["Softmax", "Add", "MatMul"],
|
||||
[0, 0, 0],
|
||||
)
|
||||
if qk_nodes is None:
|
||||
return
|
||||
_, add_qk, matmul_qk = qk_nodes
|
||||
|
||||
mask_index = None
|
||||
mask_nodes = self.model.match_parent_path(
|
||||
add_qk,
|
||||
["Add", "Mul", "Sub", "Cast", "Unsqueeze", "Unsqueeze"],
|
||||
[1, 1, 0, 1, 0, 0],
|
||||
)
|
||||
if mask_nodes is None:
|
||||
return
|
||||
mul_node = mask_nodes[1]
|
||||
if mask_nodes[1].op_type != "Mul":
|
||||
return
|
||||
|
||||
_, mul_val = self.model.get_constant_input(mul_node)
|
||||
if mul_val != -10000:
|
||||
self.mask_filter_value = mul_val
|
||||
|
||||
mask_index = self.attention_mask.process_mask(mask_nodes[-1].input[0])
|
||||
|
||||
res_pos_bias = None
|
||||
rpb_nodes = self.model.match_parent_path(
|
||||
add_qk,
|
||||
["Add", "RelativePositionBias"],
|
||||
[1, 0],
|
||||
)
|
||||
if rpb_nodes is None:
|
||||
return
|
||||
rpb_add_node = rpb_nodes[0]
|
||||
res_pos_bias = rpb_add_node.input[0]
|
||||
|
||||
k_nodes = self.model.match_parent_path(
|
||||
matmul_qk,
|
||||
["Transpose", "Reshape", "MatMul"],
|
||||
[1, 0, 0],
|
||||
)
|
||||
if k_nodes is None:
|
||||
return
|
||||
_, reshape_k, matmul_k = k_nodes
|
||||
# todo: check reshape_k parent nodes
|
||||
|
||||
q_nodes = self.model.match_parent_path(
|
||||
matmul_qk,
|
||||
["Transpose", "Reshape", "MatMul"],
|
||||
[0, 0, 0],
|
||||
)
|
||||
if q_nodes is None:
|
||||
return
|
||||
|
||||
transpose_q, reshape_q, matmul_q = q_nodes
|
||||
# todo: check reshape_q parent nodes
|
||||
|
||||
if matmul_q.input[0] != input_shape_node.input[0]:
|
||||
return
|
||||
|
||||
q_num_heads, q_hidden_size = self.get_num_heads_and_hidden_size(reshape_q)
|
||||
|
||||
new_node = self.create_attention_node(
|
||||
mask_index,
|
||||
matmul_q,
|
||||
matmul_k,
|
||||
matmul_v,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
q_num_heads,
|
||||
q_hidden_size,
|
||||
input_shape_node.input[0],
|
||||
reshape_qkv.output[0],
|
||||
res_pos_bias,
|
||||
1.0,
|
||||
)
|
||||
if new_node is None:
|
||||
return
|
||||
|
||||
self.nodes_to_add.append(new_node)
|
||||
self.node_name_to_graph_name[new_node.name] = self.this_graph_name
|
||||
|
||||
self.nodes_to_remove.extend(qkv_nodes[1:])
|
||||
self.nodes_to_remove.extend(qk_nodes)
|
||||
self.nodes_to_remove.extend(k_nodes[:-1])
|
||||
if v_nodes is not None:
|
||||
self.nodes_to_remove.extend(v_nodes[:-1])
|
||||
self.nodes_to_remove.extend(q_nodes[:-1])
|
||||
|
||||
self.prune_graph = True
|
||||
|
||||
def fuse_t5_decoder(self, normalize_node, input_name_to_nodes, output_name_to_node):
|
||||
if normalize_node.op_type != "SkipSimplifiedLayerNormalization" and normalize_node.op_type != "Add":
|
||||
return
|
||||
|
||||
qkv_nodes = self.model.match_parent_path(
|
||||
normalize_node,
|
||||
["MatMul", "Reshape", "Transpose", "MatMul"],
|
||||
[1, 0, 0, 0],
|
||||
)
|
||||
if qkv_nodes is None:
|
||||
return
|
||||
|
||||
_, reshape_qkv, transpose_qkv, matmul_qkv = qkv_nodes
|
||||
|
||||
qkv_shape_nodes = self.model.match_parent_path(
|
||||
reshape_qkv,
|
||||
["Concat", "Unsqueeze", "Gather", "Shape"],
|
||||
[1, 0, 0, 0],
|
||||
)
|
||||
if qkv_shape_nodes is None:
|
||||
return
|
||||
input_shape_node = qkv_shape_nodes[-1]
|
||||
|
||||
value = None
|
||||
past_value = None
|
||||
present_value = None
|
||||
v_nodes = self.model.match_parent_path(
|
||||
matmul_qkv,
|
||||
["Concat", "Transpose", "Reshape", "MatMul"],
|
||||
[1, 1, 0, 0],
|
||||
)
|
||||
if v_nodes is None:
|
||||
v_nodes = self.model.match_parent_path(
|
||||
matmul_qkv,
|
||||
["Transpose", "Reshape", "MatMul"],
|
||||
[1, 0, 0],
|
||||
)
|
||||
if v_nodes is not None:
|
||||
transpose_v, reshape_v, matmul_v = v_nodes
|
||||
value = reshape_v.input[0]
|
||||
present_value = transpose_v.output[0]
|
||||
if "present_value" not in present_value:
|
||||
return
|
||||
if matmul_v.input[0] != input_shape_node.input[0]:
|
||||
self.static_kv = 1
|
||||
else:
|
||||
self.static_kv = 0
|
||||
else:
|
||||
past_value = matmul_qkv.input[1]
|
||||
if past_value in output_name_to_node:
|
||||
return
|
||||
if "past_value_cross" not in past_value:
|
||||
return
|
||||
self.static_kv = 1
|
||||
else:
|
||||
concat_v, _, reshape_v, _ = v_nodes
|
||||
past_value = concat_v.input[0]
|
||||
if past_value in output_name_to_node:
|
||||
return
|
||||
if "past_value_self" not in past_value:
|
||||
return
|
||||
present_value = concat_v.output[0]
|
||||
if "present_value_self" not in present_value:
|
||||
return
|
||||
value = reshape_v.input[0]
|
||||
self.static_kv = 0
|
||||
|
||||
qk_nodes = self.model.match_parent_path(
|
||||
matmul_qkv,
|
||||
["Softmax", "Add", "MatMul"],
|
||||
[0, 0, 0],
|
||||
)
|
||||
if qk_nodes is None:
|
||||
return
|
||||
_, add_qk, matmul_qk = qk_nodes
|
||||
|
||||
mask_index = None
|
||||
res_pos_bias = None
|
||||
if self.static_kv == 1:
|
||||
mask_nodes = self.model.match_parent_path(
|
||||
add_qk,
|
||||
["Add", "Mul", "Sub", "Cast", "Unsqueeze", "Unsqueeze"],
|
||||
[1, 1, 0, 1, 0, 0],
|
||||
)
|
||||
if mask_nodes is None:
|
||||
return
|
||||
mul_node = mask_nodes[1]
|
||||
if mask_nodes[1].op_type != "Mul":
|
||||
return
|
||||
|
||||
_, mul_val = self.model.get_constant_input(mul_node)
|
||||
if mul_val != -10000:
|
||||
self.mask_filter_value = mul_val
|
||||
|
||||
mask_index = self.attention_mask.process_mask(mask_nodes[-1].input[0])
|
||||
else:
|
||||
rpb_nodes = self.model.match_parent_path(
|
||||
add_qk,
|
||||
["Add", "Slice"],
|
||||
[1, 0],
|
||||
)
|
||||
if rpb_nodes is not None:
|
||||
res_pos_bias = add_qk.input[1]
|
||||
else:
|
||||
rpb_nodes = self.model.match_parent_path(
|
||||
add_qk,
|
||||
["Add", "RelativePositionBias"],
|
||||
[1, 0],
|
||||
)
|
||||
if rpb_nodes is None:
|
||||
return
|
||||
res_pos_bias = add_qk.input[1]
|
||||
|
||||
key = None
|
||||
past_key = None
|
||||
present_key = None
|
||||
if self.static_kv == 1:
|
||||
k_nodes = self.model.match_parent_path(
|
||||
matmul_qk,
|
||||
["Transpose", "Reshape", "MatMul"],
|
||||
[1, 0, 0],
|
||||
)
|
||||
if k_nodes is not None:
|
||||
transpose_k, reshape_k, _ = k_nodes
|
||||
key = reshape_k.input[0]
|
||||
present_key_transpose_nodes = input_name_to_nodes[reshape_k.output[0]]
|
||||
for present_key_transpose_node in present_key_transpose_nodes:
|
||||
present_key_candidate = self.model.find_graph_output(present_key_transpose_node.output[0])
|
||||
if present_key_candidate is not None:
|
||||
present_key = present_key_candidate.name
|
||||
break
|
||||
if present_key is None:
|
||||
return
|
||||
if "present_key_cross" not in present_key:
|
||||
return
|
||||
else:
|
||||
k_nodes = self.model.match_parent_path(
|
||||
matmul_qk,
|
||||
["Transpose"],
|
||||
[1],
|
||||
)
|
||||
if k_nodes is None:
|
||||
return
|
||||
transpose_k = k_nodes[0]
|
||||
|
||||
past_key = transpose_k.input[0]
|
||||
if past_key in output_name_to_node:
|
||||
return
|
||||
if "past_key_cross" not in past_key:
|
||||
return
|
||||
else:
|
||||
k_nodes = self.model.match_parent_path(
|
||||
matmul_qk,
|
||||
["Transpose", "Concat", "Reshape", "MatMul"],
|
||||
[1, 0, 1, 0],
|
||||
)
|
||||
if k_nodes is not None:
|
||||
_, concat_k, reshape_k, _ = k_nodes
|
||||
key = reshape_k.input[0]
|
||||
past_key_transpose_node = output_name_to_node[concat_k.input[0]]
|
||||
past_key = past_key_transpose_node.input[0]
|
||||
if past_key in output_name_to_node:
|
||||
return
|
||||
if "past_key_self" not in past_key:
|
||||
return
|
||||
present_key_transpose_nodes = input_name_to_nodes[concat_k.output[0]]
|
||||
for present_key_transpose_node in present_key_transpose_nodes:
|
||||
# print("present_key_transpose_node:", present_key_transpose_node)
|
||||
present_key_candidate = self.model.find_graph_output(present_key_transpose_node.output[0])
|
||||
# print("present_key_candidate:", present_key_candidate)
|
||||
if present_key_candidate is not None:
|
||||
present_key = present_key_candidate.name
|
||||
break
|
||||
if present_key is None:
|
||||
return
|
||||
if "present_key_self" not in present_key:
|
||||
return
|
||||
else:
|
||||
k_nodes = self.model.match_parent_path(
|
||||
matmul_qk,
|
||||
["Transpose", "Reshape", "MatMul"],
|
||||
[1, 0, 0],
|
||||
)
|
||||
if k_nodes is None:
|
||||
return
|
||||
_, reshape_k, _ = k_nodes
|
||||
key = reshape_k.input[0]
|
||||
present_key_transpose_nodes = input_name_to_nodes[reshape_k.output[0]]
|
||||
for present_key_transpose_node in present_key_transpose_nodes:
|
||||
present_key_candidate = self.model.find_graph_output(present_key_transpose_node.output[0])
|
||||
if present_key_candidate is not None:
|
||||
present_key = present_key_candidate.name
|
||||
break
|
||||
if present_key is None:
|
||||
return
|
||||
if "present_key_self" not in present_key:
|
||||
return
|
||||
|
||||
q_nodes = self.model.match_parent_path(
|
||||
matmul_qk,
|
||||
["Transpose", "Reshape", "MatMul"],
|
||||
[0, 0, 0],
|
||||
)
|
||||
if q_nodes is None:
|
||||
return
|
||||
|
||||
transpose_q, reshape_q, matmul_q = q_nodes
|
||||
|
||||
if matmul_q.input[0] != input_shape_node.input[0]:
|
||||
return
|
||||
|
||||
q_num_heads, q_hidden_size = self.get_num_heads_and_hidden_size(reshape_q)
|
||||
|
||||
if self.static_kv == 1 and past_key is not None:
|
||||
key = past_key
|
||||
value = past_value
|
||||
past_key = None
|
||||
past_value = None
|
||||
|
||||
new_node = self.create_mha_node(
|
||||
matmul_q.output[0],
|
||||
key,
|
||||
value,
|
||||
mask_index,
|
||||
res_pos_bias,
|
||||
past_key,
|
||||
past_value,
|
||||
reshape_qkv.output[0],
|
||||
present_key,
|
||||
present_value,
|
||||
q_num_heads,
|
||||
q_hidden_size,
|
||||
)
|
||||
if new_node is None:
|
||||
return
|
||||
|
||||
self.nodes_to_add.append(new_node)
|
||||
self.node_name_to_graph_name[new_node.name] = self.this_graph_name
|
||||
|
||||
self.nodes_to_remove.extend(qkv_nodes[1:])
|
||||
self.nodes_to_remove.extend(qk_nodes)
|
||||
self.nodes_to_remove.extend(k_nodes[:-1])
|
||||
if v_nodes is not None:
|
||||
self.nodes_to_remove.extend(v_nodes[:-1])
|
||||
self.nodes_to_remove.extend(q_nodes[:-1])
|
||||
|
||||
self.prune_graph = True
|
||||
|
||||
|
||||
class FusionRelativePositionBiasBlock(Fusion):
|
||||
|
|
@ -135,12 +570,56 @@ class FusionRelativePositionBiasBlock(Fusion):
|
|||
self.node_name_to_graph_name[rpb_node.name] = self.this_graph_name
|
||||
|
||||
|
||||
class FusionSimplifiedLayerNormalization(Fusion):
|
||||
def __init__(self, model: OnnxModel):
|
||||
super().__init__(model, "SimplifiedLayerNormalization", "Mul")
|
||||
|
||||
def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict):
|
||||
if node.op_type != "Mul":
|
||||
return
|
||||
|
||||
sim_ln_nodes = self.model.match_parent_path(
|
||||
node,
|
||||
["Mul", "Div", "Sqrt", "Add", "ReduceMean", "Pow", "Add"],
|
||||
[1, 1, 1, 0, 0, 0, 0],
|
||||
)
|
||||
if sim_ln_nodes is None:
|
||||
return
|
||||
|
||||
pow_node = sim_ln_nodes[-2]
|
||||
if not self.model.find_constant_input(pow_node, 2.0) == 1:
|
||||
return
|
||||
|
||||
root_input = pow_node.input[0]
|
||||
|
||||
mul_node_1 = sim_ln_nodes[0]
|
||||
if root_input != mul_node_1.input[0]:
|
||||
return
|
||||
|
||||
second_add_node = sim_ln_nodes[3]
|
||||
i, add_weight = self.model.get_constant_input(second_add_node)
|
||||
if add_weight is None or add_weight <= 0 or add_weight > 1.0e-4:
|
||||
logger.warning(f"epsilon value is not expeced: {add_weight}")
|
||||
return
|
||||
|
||||
self.nodes_to_remove.extend(sim_ln_nodes[:-1])
|
||||
|
||||
normalize_node = helper.make_node(
|
||||
"SimplifiedLayerNormalization",
|
||||
inputs=[root_input, node.input[0]],
|
||||
outputs=[node.output[0]],
|
||||
name=self.model.create_node_name("SimplifiedLayerNormalization", name_prefix="LayerNorm"),
|
||||
)
|
||||
normalize_node.attribute.extend([helper.make_attribute("epsilon", float(add_weight))])
|
||||
normalize_node.attribute.extend([helper.make_attribute("axis", int(-1))])
|
||||
normalize_node.attribute.extend([helper.make_attribute("stash_type", int(1))])
|
||||
self.nodes_to_add.append(normalize_node)
|
||||
self.node_name_to_graph_name[normalize_node.name] = self.this_graph_name
|
||||
|
||||
|
||||
class FusionSkipSimplifiedLayerNormalization(FusionSkipLayerNormalization):
|
||||
def __init__(self, model: OnnxModel):
|
||||
super().__init__(model, "SkipSimplifiedLayerNormalization", "SimplifiedLayerNormalization")
|
||||
self.shape_infer_helper = self.model.infer_runtime_shape(
|
||||
{"batch_size": 2, "seq_len": 1, "encode_sequence_length": 8, "past_decode_sequence_length": 4}, update=True
|
||||
)
|
||||
|
||||
def fuse(self, node, input_name_to_nodes, output_name_to_node):
|
||||
super().fuse(node, input_name_to_nodes, output_name_to_node)
|
||||
|
|
@ -151,6 +630,7 @@ class T5OnnxModel(BertOnnxModel):
|
|||
super().__init__(model, num_heads, hidden_size)
|
||||
self.attention_mask = AttentionMask(self)
|
||||
self.attention_fusion = FusionT5Attention(self, self.hidden_size, self.num_heads, self.attention_mask)
|
||||
self.layer_norm_fusion = FusionSimplifiedLayerNormalization(self)
|
||||
self.skip_layer_norm_fusion = FusionSkipSimplifiedLayerNormalization(self)
|
||||
# TODO: consider retrive max_distance from model.
|
||||
# math.log(max_distance / (num_buckets // 2))
|
||||
|
|
@ -159,6 +639,9 @@ class T5OnnxModel(BertOnnxModel):
|
|||
def fuse_attention(self):
|
||||
self.attention_fusion.apply()
|
||||
|
||||
def fuse_layer_norm(self):
|
||||
self.layer_norm_fusion.apply()
|
||||
|
||||
def fuse_skip_layer_norm(self):
|
||||
self.skip_layer_norm_fusion.apply()
|
||||
|
||||
|
|
@ -234,8 +717,11 @@ class T5OnnxModel(BertOnnxModel):
|
|||
nodes_to_remove.append(node)
|
||||
self.remove_nodes(nodes_to_remove)
|
||||
|
||||
def postprocess(self):
|
||||
def preprocess(self):
|
||||
self.adjust_reshape_and_expand()
|
||||
self.rpb_fusion.apply()
|
||||
|
||||
def postprocess(self):
|
||||
# remove get_extended_attention_mask() since it generates all zeros.
|
||||
self.remove_extended_mask_decoder_init()
|
||||
self.remove_extended_mask_decoder()
|
||||
|
|
|
|||
|
|
@ -25,11 +25,11 @@ else:
|
|||
|
||||
class TestFusion(unittest.TestCase):
|
||||
def verify_fusion(self, optimized_model, expected_model_filename):
|
||||
optimized_model.topological_sort()
|
||||
optimized_model.topological_sort(is_deterministic=True)
|
||||
|
||||
expected_model_path = os.path.join(os.path.dirname(__file__), "test_data", "models", expected_model_filename)
|
||||
expected_model = OnnxModel(onnx.load(expected_model_path))
|
||||
expected_model.topological_sort()
|
||||
expected_model.topological_sort(is_deterministic=True)
|
||||
|
||||
self.assertEqual(str(optimized_model.model.graph), str(expected_model.model.graph))
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue