[T5 optimization] script fusions and fixes (#14967)

### Description
<!-- Describe your changes. -->

1. added script for t5 encoder self attention and t5 decoder self/cross
attention fusions.
2. added simplified layernorm fusion for --external_data_format senario.
(otherwise relying on ORT optimizer)
3. added rel_pos_bias shape inference code, modified attention/mha shape
inference script.
4. reworked graph_topologic_sort() because the currently implementation
is not functioning correctly. also added an option to topo-sort the
graph in a deterministic way to let tests pass.

note:
1. the t5-beamsearch export code is slightly modified. specifically,
encoder_hidden_states(ehs) is no longer an input to the t5 decoder since
the ehs is not actually used in the graph execution.
2. recent PRs do not add optimizations to t5 on cpu. 
3. the fp32 model(encoder and decoder) for t5-small, t5-base and
t5-large can get a parity of e-5 and the corresponding beam search
models generate same results as pytorch.
4. fp16(mixed-precision) models, however, get a parity around 3e-2 and
some has maximum diff a bit over 3e-2. But the beam search models still
generate same results as pytorch (based on limited input data)
5. mt-5 model has a parity issue at the moment, even before any
optimization. will investigate later.

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->

---------

Co-authored-by: Ubuntu <wy@v100-2.0cdb2e52twzevn1i4fi45bylyg.jx.internal.cloudapp.net>
This commit is contained in:
Ye Wang 2023-03-13 23:35:56 -07:00 committed by GitHub
parent 59dfcfdce7
commit 0fa00429d5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 730 additions and 145 deletions

View file

@ -163,6 +163,7 @@ class SymbolicShapeInference:
"Reciprocal": self._pass_on_shape_and_type,
"ReduceSum": self._infer_ReduceSum,
"ReduceProd": self._infer_ReduceProd,
"RelativePositionBias": self._infer_RelativePositionBias,
"Reshape": self._infer_Reshape,
"Resize": self._infer_Resize,
"Round": self._pass_on_shape_and_type,
@ -378,6 +379,17 @@ class SymbolicShapeInference:
assert name in self.initializers_
return list(self.initializers_[name].dims)
def _try_get_shape(self, node, idx):
if idx > len(node.input) - 1:
return None
name = node.input[idx]
if name in self.known_vi_:
vi = self.known_vi_[name]
return get_shape_from_value_info(vi)
if name in self.initializers_:
return list(self.initializers_[name].dims)
return None
def _get_shape_rank(self, node, idx):
return len(self._get_shape(node, idx))
@ -437,6 +449,7 @@ class SymbolicShapeInference:
"GemmFastGelu",
"LayerNormalization",
"LongformerAttention",
"RelativePositionBias",
"SimplifiedLayerNormalization",
"SkipLayerNormalization",
"SkipSimplifiedLayerNormalization",
@ -1495,6 +1508,19 @@ class SymbolicShapeInference:
if data is not None:
self.sympy_data_[node.output[0]] = sympy_reduce_product(data)
def _infer_RelativePositionBias(self, node):
seq_len = self._try_get_value(node, 1)
real_seq_len = self._try_get_value(node, 2)
if seq_len is None or real_seq_len is None:
return
num_heads = self._get_sympy_shape(node, 0)[1]
new_shape = [1, num_heads, str(seq_len), str(real_seq_len)]
output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, new_shape))
def _infer_Reshape(self, node):
shape_value = self._try_get_value(node, 1)
vi = self.known_vi_[node.output[0]]
@ -2030,14 +2056,18 @@ class SymbolicShapeInference:
def _infer_Attention(self, node):
shape = self._get_shape(node, 0)
shape_bias = self._get_shape(node, 2)
if shape and len(shape) == 3 and shape_bias and len(shape_bias) == 1:
shape_weights = self._get_shape(node, 1)
shape_bias = self._try_get_shape(node, 2)
if shape_bias is not None:
assert len(shape_bias) == 1
tripled_hidden_size = shape_bias[0] if shape_bias is not None else shape_weights[1]
if shape and len(shape) == 3:
qkv_hidden_sizes_attr = get_attribute(node, "qkv_hidden_sizes")
if qkv_hidden_sizes_attr is not None:
assert len(qkv_hidden_sizes_attr) == 3
shape[2] = int(qkv_hidden_sizes_attr[2])
elif isinstance(shape_bias[0], int):
shape[2] = int(shape_bias[0] / 3)
elif isinstance(tripled_hidden_size, int):
shape[2] = int(tripled_hidden_size / 3)
output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, shape))
@ -2068,8 +2098,8 @@ class SymbolicShapeInference:
# Output 0 has shape (batch_size, sequence_length, v_hidden_size)
# Q, K and V without packing:
# Input 0 (query) has shape (batch_size, sequence_length, hidden_size)
# Input 1 (key) has shape (batch_size, kv_sequence_length, hidden_size)
# Input 2 (value) has shape (batch_size, kv_sequence_length, v_hidden_size)
# Input 1 (key) has shape (batch_size, kv_sequence_length, hidden_size) or (batch_size, num_heads, kv_sequence_length, head_size)
# Input 2 (value) has shape (batch_size, kv_sequence_length, v_hidden_size) or (batch_size, num_heads, kv_sequence_length, head_size)
# Packed KV:
# Input 0 (query) has shape (batch_size, sequence_length, hidden_size)
# Input 1 (batch_size, kv_sequence_length, num_heads, 2, head_size)
@ -2080,29 +2110,65 @@ class SymbolicShapeInference:
# Input 2 nullptr
query_shape = self._get_shape(node, 0)
total_sequence_length = None
output_dtype = None
if query_shape is not None:
if len(query_shape) == 3:
key_shape = self._get_shape(node, 1)
key_shape = self._try_get_shape(node, 1)
# By default, hidden size is same for Q/K/V. Only need check v_hidden_size when value is provided.
output_shape = query_shape
if key_shape and len(key_shape) == 3:
value_shape = self._get_shape(node, 2)
if value_shape and len(value_shape) == 3:
if key_shape is not None and len(key_shape) == 3:
value_shape = self._try_get_shape(node, 2)
if value_shape is not None and len(value_shape) == 3:
output_shape[2] = value_shape[2]
total_sequence_length = key_shape[1]
output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape))
elif len(query_shape) == 5:
if isinstance(query_shape[2], int) and isinstance(query_shape[4], int):
output_shape = [query_shape[0], query_shape[1], query_shape[2] * query_shape[4]]
else:
output_shape = [query_shape[0], query_shape[1], f"{query_shape[2]}*{query_shape[4]}"]
total_sequence_length = query_shape[1]
output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape))
if len(node.output) > 1:
batch_size = query_shape[0]
num_heads = get_attribute(node, "num_heads")
head_size = None
if len(query_shape) == 3:
head_size = (
int(query_shape[2] / num_heads)
if isinstance(query_shape[2], int)
else f"{query_shape[2]}/{num_heads}"
)
else:
head_size = query_shape[4]
past_shape = self._try_get_shape(node, 6)
if past_shape is not None:
if isinstance(past_shape[2], int) and isinstance(total_sequence_length, int):
total_sequence_length = past_shape[2] + total_sequence_length
else:
total_sequence_length = f"{past_shape[2]}+{total_sequence_length}"
present_shape = [batch_size, num_heads, total_sequence_length, head_size]
assert output_dtype is not None
vi = self.known_vi_[node.output[1]]
vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, present_shape))
vi = self.known_vi_[node.output[2]]
vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, present_shape))
def _infer_FastGelu(self, node):
self._propagate_shape_and_type(node)
@ -2140,8 +2206,6 @@ class SymbolicShapeInference:
def _infer_SkipLayerNormalization(self, node):
self._propagate_shape_and_type(node)
if len(node.output) > 3:
self._propagate_shape_and_type(node, 0, 3)
# If the SkipLayerNormalization node contains the optional
# output for inference, infer the shape and type for it too
@ -2348,7 +2412,9 @@ class SymbolicShapeInference:
for i_o in range(len(node.output)):
# Special case: We do not care about the training related
# outputs of SkipLayerNormalization
if node.op_type == "SkipLayerNormalization" and i_o in [1, 2]:
if (
node.op_type == "SkipLayerNormalization" or node.op_type == "SkipSimplifiedLayerNormalization"
) and i_o in [1, 2]:
continue
vi = self.known_vi_[node.output[i_o]]

View file

@ -706,13 +706,12 @@ def verify_t5_decoder_subgraph(graph: onnx.GraphProto, precision: Precision):
float_type = TensorProto.FLOAT16 if is_float16 else TensorProto.FLOAT
input_count = len(graph.input)
layer_count = (input_count - 3) // 4
layer_count = (input_count - 2) // 4
assert layer_count >= 1
# Expect inputs:
# input_ids: int32 (B, 1)
# encoder_attention_mask: int32 (B, encode_sequence_length)
# encoder_hidden_states: (B, encode_sequence_length, encoder_hidden_size)
# past_key_self_0: (B, num_heads, past_decode_sequence_length, head_size)
# past_value_self_0: (B, num_heads, past_decode_sequence_length, head_size)
@ -723,7 +722,7 @@ def verify_t5_decoder_subgraph(graph: onnx.GraphProto, precision: Precision):
# ... (for each cross attention layer)
# TODO: encoder_hidden_states is optional
expected_inputs = ["input_ids", "encoder_attention_mask", "encoder_hidden_states"]
expected_inputs = ["input_ids", "encoder_attention_mask"]
for i in range(layer_count):
expected_inputs.append(f"past_key_self_{i}")
expected_inputs.append(f"past_value_self_{i}")
@ -1406,6 +1405,43 @@ def generate_gpt2_init_decoder(
return True
def make_dim_proto_numeric_t5(model, config):
"""Make dim_proto numeric.
Args:
model: T5 encoder and decoder model.
config: T5 config.
"""
sequence_length = str(1)
num_heads = str(config.num_heads)
hidden_size = str(config.d_model)
head_size = str(config.d_kv)
for tensor in model.graph.output:
for dim_proto in tensor.type.tensor_type.shape.dim:
if dim_proto.HasField("dim_param") and dim_proto.dim_param in [
sequence_length,
num_heads,
hidden_size,
head_size,
]:
dim_value = int(dim_proto.dim_param)
dim_proto.Clear()
dim_proto.dim_value = dim_value
for tensor in model.graph.input:
for dim_proto in tensor.type.tensor_type.shape.dim:
if dim_proto.HasField("dim_param") and dim_proto.dim_param in [
sequence_length,
num_heads,
hidden_size,
head_size,
]:
dim_value = int(dim_proto.dim_param)
dim_proto.Clear()
dim_proto.dim_value = dim_value
def convert_generation_model(args: argparse.Namespace, generation_type: GenerationType = GenerationType.BEAMSEARCH):
"""Convert model according to command line arguments.
@ -1686,6 +1722,9 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati
# )
# initializers.extend(moved_initializers)
make_dim_proto_numeric_t5(encoder_model, config)
make_dim_proto_numeric_t5(decoder_model, config)
node.attribute.extend(
[
onnx.helper.make_attribute("encoder", encoder_model.graph),

View file

@ -2,11 +2,8 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from enum import Enum
from logging import getLogger
from os import name
from sys import path
from typing import Tuple, Union
from typing import List, Optional, Tuple, Union
import numpy as np
from fusion_base import Fusion
@ -14,7 +11,6 @@ from fusion_options import AttentionMaskFormat
from fusion_utils import FusionUtils, NumpyHelper
from onnx import NodeProto, TensorProto, helper, numpy_helper
from onnx_model import OnnxModel
from shape_infer_helper import SymbolicShapeInferenceHelper, get_shape_from_type_proto
logger = getLogger(__name__)
@ -94,9 +90,10 @@ class FusionAttention(Fusion):
num_heads: int,
attention_mask: AttentionMask,
use_multi_head_attention: bool = False,
search_op_types: List[str] = ["SkipLayerNormalization", "LayerNormalization"],
):
attention_op_name = "MultiHeadAttention" if use_multi_head_attention else "Attention"
super().__init__(model, attention_op_name, ["SkipLayerNormalization", "LayerNormalization"])
super().__init__(model, attention_op_name, search_op_types)
self.hidden_size = hidden_size
self.num_heads = num_heads
self.attention_mask = attention_mask
@ -211,6 +208,7 @@ class FusionAttention(Fusion):
input: str,
output: str,
add_qk_str: str,
scale: Optional[float] = None,
) -> Union[NodeProto, None]:
"""Create an Attention node.
@ -236,12 +234,22 @@ class FusionAttention(Fusion):
logger.debug(f"input hidden size {hidden_size} is not a multiple of num of heads {num_heads}")
return None
has_bias = True
if q_add is None and k_add is None and v_add is None:
has_bias = False
q_weight = self.model.get_initializer(q_matmul.input[1])
k_weight = self.model.get_initializer(k_matmul.input[1])
v_weight = self.model.get_initializer(v_matmul.input[1])
q_bias = self.model.get_initializer(q_add.input[1]) or self.model.get_initializer(q_add.input[0])
k_bias = self.model.get_initializer(k_add.input[1]) or self.model.get_initializer(k_add.input[0])
v_bias = self.model.get_initializer(v_add.input[1]) or self.model.get_initializer(v_add.input[0])
q_bias, k_bias, v_bias = None, None, None
if has_bias:
q_bias = self.model.get_initializer(q_add.input[1]) or self.model.get_initializer(q_add.input[0])
k_bias = self.model.get_initializer(k_add.input[1]) or self.model.get_initializer(k_add.input[0])
v_bias = self.model.get_initializer(v_add.input[1]) or self.model.get_initializer(v_add.input[0])
if not (k_weight and v_weight and q_bias and k_bias):
return None
if q_weight is None:
print(
@ -249,8 +257,6 @@ class FusionAttention(Fusion):
"Please set do_constant_folding=True in torch.onnx.export to unblock attention fusion"
)
return None
if not (k_weight and v_weight and q_bias and k_bias):
return None
qw = NumpyHelper.to_array(q_weight)
kw = NumpyHelper.to_array(k_weight)
@ -290,24 +296,25 @@ class FusionAttention(Fusion):
qkv_weight = np.stack((qw, kw, vw), axis=1)
qkv_weight_dim = 3 * qw_out_size
qb = NumpyHelper.to_array(q_bias)
kb = NumpyHelper.to_array(k_bias)
vb = NumpyHelper.to_array(v_bias)
if has_bias:
qb = NumpyHelper.to_array(q_bias)
kb = NumpyHelper.to_array(k_bias)
vb = NumpyHelper.to_array(v_bias)
q_bias_shape = np.prod(qb.shape)
k_bias_shape = np.prod(kb.shape)
v_bias_shape = np.prod(vb.shape)
q_bias_shape = np.prod(qb.shape)
k_bias_shape = np.prod(kb.shape)
v_bias_shape = np.prod(vb.shape)
assert q_bias_shape == k_bias_shape == qw_out_size
assert v_bias_shape == vw_out_size
assert q_bias_shape == k_bias_shape == qw_out_size
assert v_bias_shape == vw_out_size
qkv_bias_dim = 0
if is_qkv_diff_dims:
qkv_bias = np.concatenate((qb, kb, vb), axis=0)
qkv_bias_dim = q_bias_shape + k_bias_shape + v_bias_shape
else:
qkv_bias = np.stack((qb, kb, vb), axis=0)
qkv_bias_dim = 3 * q_bias_shape
qkv_bias_dim = 0
if is_qkv_diff_dims:
qkv_bias = np.concatenate((qb, kb, vb), axis=0)
qkv_bias_dim = q_bias_shape + k_bias_shape + v_bias_shape
else:
qkv_bias = np.stack((qb, kb, vb), axis=0)
qkv_bias_dim = 3 * q_bias_shape
attention_node_name = self.model.create_node_name("Attention")
@ -324,15 +331,17 @@ class FusionAttention(Fusion):
weight.CopyFrom(numpy_helper.from_array(NumpyHelper.to_array(weight).astype(np.float16), weight.name))
self.model.add_initializer(weight, self.this_graph_name)
bias = helper.make_tensor(
name=attention_node_name + "_qkv_bias",
data_type=TensorProto.FLOAT,
dims=[qkv_bias_dim],
vals=qkv_bias.flatten().tolist(),
)
if q_bias.data_type == 10:
bias.CopyFrom(numpy_helper.from_array(NumpyHelper.to_array(bias).astype(np.float16), bias.name))
self.model.add_initializer(bias, self.this_graph_name)
bias = None
if has_bias:
bias = helper.make_tensor(
name=attention_node_name + "_qkv_bias",
data_type=TensorProto.FLOAT,
dims=[qkv_bias_dim],
vals=qkv_bias.flatten().tolist(),
)
if q_bias.data_type == 10:
bias.CopyFrom(numpy_helper.from_array(NumpyHelper.to_array(bias).astype(np.float16), bias.name))
self.model.add_initializer(bias, self.this_graph_name)
# For MultiHeadAttention operator, use separated inputs for query, key and value, and no weights.
if self.use_multi_head_attention:
@ -359,7 +368,7 @@ class FusionAttention(Fusion):
attention_inputs = [
input,
attention_node_name + "_qkv_weight",
attention_node_name + "_qkv_bias",
attention_node_name + "_qkv_bias" if has_bias else "",
]
if mask_index is not None:
attention_inputs.append(mask_index)
@ -379,6 +388,9 @@ class FusionAttention(Fusion):
attention_node.domain = "com.microsoft"
attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)])
if scale is not None:
attention_node.attribute.extend([helper.make_attribute("scale", scale)])
if is_qkv_diff_dims:
attention_node.attribute.extend(
[helper.make_attribute("qkv_hidden_sizes", [qw_out_size, kw_out_size, vw_out_size])]

View file

@ -92,14 +92,16 @@ class T5Decoder(torch.nn.Module):
self.lm_head = lm_head
self.config = config
def forward(self, decoder_input_ids, encoder_attention_mask, encoder_hidden_states, *past):
def forward(self, decoder_input_ids, encoder_attention_mask, *past):
past_key_values = PastKeyValuesHelper.group_by_layer(past, self.config.num_layers)
# This is a hack since only the third dimension of encoder_hidden_states is used here
dummy_encoder_hidden_states = encoder_attention_mask.unsqueeze(2)
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
past_key_values=past_key_values,
encoder_hidden_states=encoder_hidden_states,
encoder_hidden_states=dummy_encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=True,
return_dict=True,
@ -122,12 +124,10 @@ class T5DecoderInputs:
self,
decoder_input_ids,
encoder_attention_mask,
encoder_hidden_states,
past_key_values=None,
):
self.decoder_input_ids: torch.LongTensor = decoder_input_ids
self.encoder_attention_mask: torch.LongTensor = encoder_attention_mask
self.encoder_hidden_states: Union[torch.FloatTensor, torch.HalfTensor] = encoder_hidden_states
self.past_key_values: Union[List[torch.FloatTensor], List[torch.HalfTensor], None] = past_key_values
@staticmethod
@ -181,13 +181,6 @@ class T5DecoderInputs:
)
float_type = torch.float16 if float16 else torch.float32
encoder_hidden_state = torch.rand(
batch_size,
encode_sequence_length,
hidden_size,
dtype=float_type,
device=device,
)
if past_decode_sequence_length > 0:
self_attention_past_shape = [
@ -212,25 +205,22 @@ class T5DecoderInputs:
else:
past = None
return T5DecoderInputs(decoder_input_ids, encoder_inputs.attention_mask, encoder_hidden_state, past)
return T5DecoderInputs(decoder_input_ids, encoder_inputs.attention_mask, past)
def to_list(self) -> List:
input_list = [
self.decoder_input_ids,
self.encoder_attention_mask,
self.encoder_hidden_states,
]
if self.past_key_values:
input_list.extend(self.past_key_values)
return input_list
def to_fp32(self):
encoder_hidden_state = self.encoder_hidden_states.to(dtype=torch.float32)
past = [p.to(dtype=torch.float32) for p in self.past_key_values] if self.past_key_values else None
return T5DecoderInputs(
self.decoder_input_ids.clone(),
self.encoder_attention_mask.clone(),
encoder_hidden_state,
past,
)
@ -278,7 +268,6 @@ class T5DecoderHelper:
# Shape of input tensors (sequence_length==1):
# input_ids: (batch_size, sequence_length)
# encoder_attention_mask: (batch_size, encode_sequence_length)
# encoder_hidden_states: (batch_size, encode_sequence_length, hidden_size)
# past_self_*: (batch_size, num_heads, past_decode_sequence_length, head_size)
# past_cross_*: (batch_size, num_heads, encode_sequence_length, head_size)
@ -289,7 +278,6 @@ class T5DecoderHelper:
input_names = ["input_ids"]
input_names.append("encoder_attention_mask")
input_names.append("encoder_hidden_states")
input_names.extend(input_past_names)
dynamic_axes = {
@ -362,7 +350,6 @@ class T5DecoderHelper:
ort_inputs = {
"input_ids": numpy.ascontiguousarray(inputs.decoder_input_ids.cpu().numpy()),
"encoder_attention_mask": numpy.ascontiguousarray(inputs.encoder_attention_mask.cpu().numpy()),
"encoder_hidden_states": numpy.ascontiguousarray(inputs.encoder_hidden_states.cpu().numpy()),
}
if inputs.past_key_values:
@ -384,7 +371,7 @@ class T5DecoderHelper:
max_cases: int = 4,
):
"""Compare the result from PyTorch and OnnxRuntime to verify the ONNX model is good."""
float16: bool = TypeHelper.get_input_type(ort_session, "encoder_hidden_states") == "tensor(float16)"
float16: bool = TypeHelper.get_input_type(ort_session, "past_key_self_0") == "tensor(float16)"
test_cases = [(4, 11, 3), (1, 2, 5), (3, 1, 1), (8, 5, 2)]
test_cases_max_diff = []

View file

@ -151,21 +151,17 @@ class T5Helper:
def auto_mixed_precision(
onnx_model: OnnxModel,
op_block_list: List[str] = [
"Pow",
"ReduceMean",
"Add",
"Sqrt",
"Div",
"Mul",
"Softmax",
"SimplifiedLayerNormalization",
"SkipSimplifiedLayerNormalization",
"Relu",
"Add",
],
):
"""Convert model to mixed precision.
It detects whether original model has fp16 precision weights, and set parameters for float16 conversion automatically.
Args:
onnx_model (OnnxModel): optimized ONNX model
op_block_list (List[str], optional): . Defaults to ["Pow", "ReduceMean", "Add", "Sqrt", "Div", "Mul", "Softmax", "Relu"]
op_block_list (List[str], optional): . Defaults to ["SimplifiedLayerNormalization", "SkipSimplifiedLayerNormalization", "Relu", "Add"]
Returns:
parameters(dict): a dictionary of parameters used in float16 conversion
"""
@ -235,8 +231,7 @@ class T5Helper:
from fusion_options import FusionOptions
optimization_options = None
if not use_gpu:
# Currently there is no SkipSimplifiedLayerNorm cpu kernel
if is_float16:
optimization_options = FusionOptions("t5")
optimization_options.enable_skip_layer_norm = False
@ -245,10 +240,12 @@ class T5Helper:
model_type="t5",
num_heads=num_attention_heads,
hidden_size=hidden_size,
opt_level=2 if not is_float16 and not use_external_data_format else 0,
opt_level=2 if not use_external_data_format else 0,
optimization_options=optimization_options,
use_gpu=False,
only_onnxruntime=not use_gpu,
)
if is_float16:
if auto_mixed_precision:
T5Helper.auto_mixed_precision(m)

View file

@ -615,7 +615,7 @@ class OnnxModel:
if use_symbolic_shape_infer:
# Use symbolic shape inference since custom operators (like Gelu, SkipLayerNormalization etc)
# are not recognized by onnx shape inference.
shape_infer_helper = SymbolicShapeInferenceHelper(model)
shape_infer_helper = SymbolicShapeInferenceHelper(model, verbose=0)
model = shape_infer_helper.infer_shapes(model, auto_merge=True, guess_output_rank=False)
parameters = {"disable_shape_infer": use_symbolic_shape_infer}
@ -876,66 +876,64 @@ class OnnxModel:
return True
@staticmethod
def graph_topological_sort(graph):
deps_count = [0] * len(graph.node) # dependency count of each node
deps_to_nodes = {} # input to node indice
def graph_topological_sort(graph, is_deterministic=False):
deps_set = set() # dependency set of all node
sorted_node_set = set() # sorted node set
sorted_nodes = [] # initialize sorted_nodes
for node_idx, node in enumerate(graph.node):
# CANNOT use len(node.input) directly because input can be optional
deps_count[node_idx] = sum(1 for _ in node.input if _)
if deps_count[node_idx] == 0: # Constant doesn't depend on any inputs
sorted_nodes.append(graph.node[node_idx])
continue
for input_name in node.input:
if input_name not in deps_to_nodes:
deps_to_nodes[input_name] = [node_idx]
else:
deps_to_nodes[input_name].append(node_idx)
# Note: this logic only applies to top level graph since a sub graph could use intializer from parent graph
initializer_names = [init.name for init in graph.initializer]
graph_input_names = [input.name for input in graph.input]
input_names = initializer_names + graph_input_names
input_names.sort()
prev_input_name = None
if is_deterministic:
input_names.sort()
for input_name in input_names:
if prev_input_name == input_name:
continue
deps_set.add(input_name)
prev_input_name = input_name
if input_name in deps_to_nodes:
for node_idx in deps_to_nodes[input_name]:
deps_count[node_idx] = deps_count[node_idx] - 1
if deps_count[node_idx] == 0:
sorted_nodes.append(graph.node[node_idx])
sorted_node_set_len = -1
graph_nodes = graph.node if not is_deterministic else sorted(graph.node, key=lambda x: x.name)
last_node_name = None
while len(sorted_node_set) != len(graph_nodes):
if len(sorted_node_set) == sorted_node_set_len:
break
sorted_node_set_len = len(sorted_node_set)
for node_idx, node in enumerate(graph_nodes):
if node_idx in sorted_node_set:
continue
input_count = sum(1 for _ in node.input if _)
if input_count == 0:
sorted_nodes.append(node)
sorted_node_set.add(node_idx)
for output in node.output:
deps_set.add(output)
continue
failed = False
for input_name in node.input:
if input_name != "" and input_name not in deps_set:
failed = True
last_node_name = node.name
if not failed:
sorted_nodes.append(node)
sorted_node_set.add(node_idx)
for output in node.output:
deps_set.add(output)
else:
continue
start = 0
end = len(sorted_nodes)
while start < end:
for output in sorted_nodes[start].output:
if output in deps_to_nodes:
for node_idx in deps_to_nodes[output]:
deps_count[node_idx] = deps_count[node_idx] - 1
if deps_count[node_idx] == 0:
sorted_nodes.append(graph.node[node_idx])
end = end + 1
start = start + 1
if end != len(graph.node):
if len(sorted_node_set) != len(graph.node):
raise RuntimeError(
f"Graph is not a DAG: end={end}, len(graph.node)={len(graph.node)}, graph.node[end]={graph.node[end]}"
f"Graph is not a DAG: len(sorted_node_set)={len(sorted_node_set)}, len(graph.node)={len(graph.node)}, failed at node {last_node_name}"
)
graph.ClearField("node")
graph.node.extend(sorted_nodes)
def topological_sort(self):
def topological_sort(self, is_deterministic=False):
# TODO: support graph_topological_sort() in subgraphs
# for graph in self.graphs():
# self.graph_topological_sort(graph)
OnnxModel.graph_topological_sort(self.model.graph)
OnnxModel.graph_topological_sort(self.model.graph, is_deterministic)
@staticmethod
def save(

View file

@ -3,7 +3,7 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging
from typing import Union
from typing import Dict, Union
import numpy as np
from fusion_attention import AttentionMask, FusionAttention
@ -16,7 +16,7 @@ from onnx_model_bert import BertOnnxModel
logger = logging.getLogger(__name__)
# TODO: Support decoder self/cross attention fusion and encoder self attention fusion
class FusionT5Attention(FusionAttention):
"""
Fuse T5 Attention subgraph into one Attention node.
@ -29,25 +29,460 @@ class FusionT5Attention(FusionAttention):
num_heads: int,
attention_mask: AttentionMask,
):
super().__init__(model, hidden_size, num_heads, attention_mask)
super().__init__(
model,
hidden_size,
num_heads,
attention_mask,
use_multi_head_attention=False,
search_op_types=["SkipSimplifiedLayerNormalization", "Add"],
)
self.static_kv = 1
def create_attention_node(
def create_mha_node(
self,
query: str,
key: str,
value: str,
mask_index: str,
matmul: NodeProto,
add: NodeProto,
res_pos_bias: str,
past_key: str,
past_value: str,
output: str,
present_key: str,
present_value: str,
num_heads: int,
hidden_size: int,
input: str,
output: str,
add_qk_str: str,
) -> Union[NodeProto, None]:
# Not implemented yet
return None
assert num_heads > 0
if hidden_size > 0 and (hidden_size % num_heads) != 0:
logger.debug(f"input hidden size {hidden_size} is not a multiple of num of heads {num_heads}")
return None
attention_node_name = self.model.create_node_name("MultiHeadAttention")
attention_inputs = [
query,
"" if key is None else key, # key
"" if value is None else value, # value
"", # bias
]
if mask_index is not None:
attention_inputs.append(mask_index)
else:
attention_inputs.append("")
if res_pos_bias is not None:
attention_inputs.append(res_pos_bias)
else:
attention_inputs.append("")
if past_key is not None:
assert past_value is not None
attention_inputs.append(past_key)
attention_inputs.append(past_value)
attention_outputs = [output]
if present_key is not None:
assert present_value is not None
attention_outputs.append(present_key)
attention_outputs.append(present_value)
attention_node = helper.make_node(
"MultiHeadAttention",
inputs=attention_inputs,
outputs=attention_outputs,
name=attention_node_name,
)
attention_node.domain = "com.microsoft"
attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)])
attention_node.attribute.extend([helper.make_attribute("scale", 1.0)])
if self.mask_filter_value is not None:
attention_node.attribute.extend([helper.make_attribute("mask_filter_value", float(self.mask_filter_value))])
self.increase_counter("MultiHeadAttention")
return attention_node
def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
# Not implemented yet
return
self.fuse_t5_encoder(normalize_node, input_name_to_nodes, output_name_to_node)
self.fuse_t5_decoder(normalize_node, input_name_to_nodes, output_name_to_node)
def fuse_t5_encoder(self, normalize_node, input_name_to_nodes, output_name_to_node):
if normalize_node.op_type != "SkipSimplifiedLayerNormalization" and normalize_node.op_type != "Add":
return
qkv_nodes = self.model.match_parent_path(
normalize_node,
["MatMul", "Reshape", "Transpose", "MatMul"],
[1, 0, 0, 0],
)
if qkv_nodes is None:
return
_, reshape_qkv, transpose_qkv, matmul_qkv = qkv_nodes
qkv_shape_nodes = self.model.match_parent_path(
reshape_qkv,
["Concat", "Unsqueeze", "Gather", "Shape"],
[1, 0, 0, 0],
)
if qkv_shape_nodes is None:
return
input_shape_node = qkv_shape_nodes[-1]
v_nodes = self.model.match_parent_path(
matmul_qkv,
["Transpose", "Reshape", "MatMul"],
[1, 0, 0],
)
if v_nodes is None:
return
_, reshape_v, matmul_v = v_nodes
# todo: check reshape_v parent nodes
qk_nodes = self.model.match_parent_path(
matmul_qkv,
["Softmax", "Add", "MatMul"],
[0, 0, 0],
)
if qk_nodes is None:
return
_, add_qk, matmul_qk = qk_nodes
mask_index = None
mask_nodes = self.model.match_parent_path(
add_qk,
["Add", "Mul", "Sub", "Cast", "Unsqueeze", "Unsqueeze"],
[1, 1, 0, 1, 0, 0],
)
if mask_nodes is None:
return
mul_node = mask_nodes[1]
if mask_nodes[1].op_type != "Mul":
return
_, mul_val = self.model.get_constant_input(mul_node)
if mul_val != -10000:
self.mask_filter_value = mul_val
mask_index = self.attention_mask.process_mask(mask_nodes[-1].input[0])
res_pos_bias = None
rpb_nodes = self.model.match_parent_path(
add_qk,
["Add", "RelativePositionBias"],
[1, 0],
)
if rpb_nodes is None:
return
rpb_add_node = rpb_nodes[0]
res_pos_bias = rpb_add_node.input[0]
k_nodes = self.model.match_parent_path(
matmul_qk,
["Transpose", "Reshape", "MatMul"],
[1, 0, 0],
)
if k_nodes is None:
return
_, reshape_k, matmul_k = k_nodes
# todo: check reshape_k parent nodes
q_nodes = self.model.match_parent_path(
matmul_qk,
["Transpose", "Reshape", "MatMul"],
[0, 0, 0],
)
if q_nodes is None:
return
transpose_q, reshape_q, matmul_q = q_nodes
# todo: check reshape_q parent nodes
if matmul_q.input[0] != input_shape_node.input[0]:
return
q_num_heads, q_hidden_size = self.get_num_heads_and_hidden_size(reshape_q)
new_node = self.create_attention_node(
mask_index,
matmul_q,
matmul_k,
matmul_v,
None,
None,
None,
q_num_heads,
q_hidden_size,
input_shape_node.input[0],
reshape_qkv.output[0],
res_pos_bias,
1.0,
)
if new_node is None:
return
self.nodes_to_add.append(new_node)
self.node_name_to_graph_name[new_node.name] = self.this_graph_name
self.nodes_to_remove.extend(qkv_nodes[1:])
self.nodes_to_remove.extend(qk_nodes)
self.nodes_to_remove.extend(k_nodes[:-1])
if v_nodes is not None:
self.nodes_to_remove.extend(v_nodes[:-1])
self.nodes_to_remove.extend(q_nodes[:-1])
self.prune_graph = True
def fuse_t5_decoder(self, normalize_node, input_name_to_nodes, output_name_to_node):
if normalize_node.op_type != "SkipSimplifiedLayerNormalization" and normalize_node.op_type != "Add":
return
qkv_nodes = self.model.match_parent_path(
normalize_node,
["MatMul", "Reshape", "Transpose", "MatMul"],
[1, 0, 0, 0],
)
if qkv_nodes is None:
return
_, reshape_qkv, transpose_qkv, matmul_qkv = qkv_nodes
qkv_shape_nodes = self.model.match_parent_path(
reshape_qkv,
["Concat", "Unsqueeze", "Gather", "Shape"],
[1, 0, 0, 0],
)
if qkv_shape_nodes is None:
return
input_shape_node = qkv_shape_nodes[-1]
value = None
past_value = None
present_value = None
v_nodes = self.model.match_parent_path(
matmul_qkv,
["Concat", "Transpose", "Reshape", "MatMul"],
[1, 1, 0, 0],
)
if v_nodes is None:
v_nodes = self.model.match_parent_path(
matmul_qkv,
["Transpose", "Reshape", "MatMul"],
[1, 0, 0],
)
if v_nodes is not None:
transpose_v, reshape_v, matmul_v = v_nodes
value = reshape_v.input[0]
present_value = transpose_v.output[0]
if "present_value" not in present_value:
return
if matmul_v.input[0] != input_shape_node.input[0]:
self.static_kv = 1
else:
self.static_kv = 0
else:
past_value = matmul_qkv.input[1]
if past_value in output_name_to_node:
return
if "past_value_cross" not in past_value:
return
self.static_kv = 1
else:
concat_v, _, reshape_v, _ = v_nodes
past_value = concat_v.input[0]
if past_value in output_name_to_node:
return
if "past_value_self" not in past_value:
return
present_value = concat_v.output[0]
if "present_value_self" not in present_value:
return
value = reshape_v.input[0]
self.static_kv = 0
qk_nodes = self.model.match_parent_path(
matmul_qkv,
["Softmax", "Add", "MatMul"],
[0, 0, 0],
)
if qk_nodes is None:
return
_, add_qk, matmul_qk = qk_nodes
mask_index = None
res_pos_bias = None
if self.static_kv == 1:
mask_nodes = self.model.match_parent_path(
add_qk,
["Add", "Mul", "Sub", "Cast", "Unsqueeze", "Unsqueeze"],
[1, 1, 0, 1, 0, 0],
)
if mask_nodes is None:
return
mul_node = mask_nodes[1]
if mask_nodes[1].op_type != "Mul":
return
_, mul_val = self.model.get_constant_input(mul_node)
if mul_val != -10000:
self.mask_filter_value = mul_val
mask_index = self.attention_mask.process_mask(mask_nodes[-1].input[0])
else:
rpb_nodes = self.model.match_parent_path(
add_qk,
["Add", "Slice"],
[1, 0],
)
if rpb_nodes is not None:
res_pos_bias = add_qk.input[1]
else:
rpb_nodes = self.model.match_parent_path(
add_qk,
["Add", "RelativePositionBias"],
[1, 0],
)
if rpb_nodes is None:
return
res_pos_bias = add_qk.input[1]
key = None
past_key = None
present_key = None
if self.static_kv == 1:
k_nodes = self.model.match_parent_path(
matmul_qk,
["Transpose", "Reshape", "MatMul"],
[1, 0, 0],
)
if k_nodes is not None:
transpose_k, reshape_k, _ = k_nodes
key = reshape_k.input[0]
present_key_transpose_nodes = input_name_to_nodes[reshape_k.output[0]]
for present_key_transpose_node in present_key_transpose_nodes:
present_key_candidate = self.model.find_graph_output(present_key_transpose_node.output[0])
if present_key_candidate is not None:
present_key = present_key_candidate.name
break
if present_key is None:
return
if "present_key_cross" not in present_key:
return
else:
k_nodes = self.model.match_parent_path(
matmul_qk,
["Transpose"],
[1],
)
if k_nodes is None:
return
transpose_k = k_nodes[0]
past_key = transpose_k.input[0]
if past_key in output_name_to_node:
return
if "past_key_cross" not in past_key:
return
else:
k_nodes = self.model.match_parent_path(
matmul_qk,
["Transpose", "Concat", "Reshape", "MatMul"],
[1, 0, 1, 0],
)
if k_nodes is not None:
_, concat_k, reshape_k, _ = k_nodes
key = reshape_k.input[0]
past_key_transpose_node = output_name_to_node[concat_k.input[0]]
past_key = past_key_transpose_node.input[0]
if past_key in output_name_to_node:
return
if "past_key_self" not in past_key:
return
present_key_transpose_nodes = input_name_to_nodes[concat_k.output[0]]
for present_key_transpose_node in present_key_transpose_nodes:
# print("present_key_transpose_node:", present_key_transpose_node)
present_key_candidate = self.model.find_graph_output(present_key_transpose_node.output[0])
# print("present_key_candidate:", present_key_candidate)
if present_key_candidate is not None:
present_key = present_key_candidate.name
break
if present_key is None:
return
if "present_key_self" not in present_key:
return
else:
k_nodes = self.model.match_parent_path(
matmul_qk,
["Transpose", "Reshape", "MatMul"],
[1, 0, 0],
)
if k_nodes is None:
return
_, reshape_k, _ = k_nodes
key = reshape_k.input[0]
present_key_transpose_nodes = input_name_to_nodes[reshape_k.output[0]]
for present_key_transpose_node in present_key_transpose_nodes:
present_key_candidate = self.model.find_graph_output(present_key_transpose_node.output[0])
if present_key_candidate is not None:
present_key = present_key_candidate.name
break
if present_key is None:
return
if "present_key_self" not in present_key:
return
q_nodes = self.model.match_parent_path(
matmul_qk,
["Transpose", "Reshape", "MatMul"],
[0, 0, 0],
)
if q_nodes is None:
return
transpose_q, reshape_q, matmul_q = q_nodes
if matmul_q.input[0] != input_shape_node.input[0]:
return
q_num_heads, q_hidden_size = self.get_num_heads_and_hidden_size(reshape_q)
if self.static_kv == 1 and past_key is not None:
key = past_key
value = past_value
past_key = None
past_value = None
new_node = self.create_mha_node(
matmul_q.output[0],
key,
value,
mask_index,
res_pos_bias,
past_key,
past_value,
reshape_qkv.output[0],
present_key,
present_value,
q_num_heads,
q_hidden_size,
)
if new_node is None:
return
self.nodes_to_add.append(new_node)
self.node_name_to_graph_name[new_node.name] = self.this_graph_name
self.nodes_to_remove.extend(qkv_nodes[1:])
self.nodes_to_remove.extend(qk_nodes)
self.nodes_to_remove.extend(k_nodes[:-1])
if v_nodes is not None:
self.nodes_to_remove.extend(v_nodes[:-1])
self.nodes_to_remove.extend(q_nodes[:-1])
self.prune_graph = True
class FusionRelativePositionBiasBlock(Fusion):
@ -135,12 +570,56 @@ class FusionRelativePositionBiasBlock(Fusion):
self.node_name_to_graph_name[rpb_node.name] = self.this_graph_name
class FusionSimplifiedLayerNormalization(Fusion):
def __init__(self, model: OnnxModel):
super().__init__(model, "SimplifiedLayerNormalization", "Mul")
def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict):
if node.op_type != "Mul":
return
sim_ln_nodes = self.model.match_parent_path(
node,
["Mul", "Div", "Sqrt", "Add", "ReduceMean", "Pow", "Add"],
[1, 1, 1, 0, 0, 0, 0],
)
if sim_ln_nodes is None:
return
pow_node = sim_ln_nodes[-2]
if not self.model.find_constant_input(pow_node, 2.0) == 1:
return
root_input = pow_node.input[0]
mul_node_1 = sim_ln_nodes[0]
if root_input != mul_node_1.input[0]:
return
second_add_node = sim_ln_nodes[3]
i, add_weight = self.model.get_constant_input(second_add_node)
if add_weight is None or add_weight <= 0 or add_weight > 1.0e-4:
logger.warning(f"epsilon value is not expeced: {add_weight}")
return
self.nodes_to_remove.extend(sim_ln_nodes[:-1])
normalize_node = helper.make_node(
"SimplifiedLayerNormalization",
inputs=[root_input, node.input[0]],
outputs=[node.output[0]],
name=self.model.create_node_name("SimplifiedLayerNormalization", name_prefix="LayerNorm"),
)
normalize_node.attribute.extend([helper.make_attribute("epsilon", float(add_weight))])
normalize_node.attribute.extend([helper.make_attribute("axis", int(-1))])
normalize_node.attribute.extend([helper.make_attribute("stash_type", int(1))])
self.nodes_to_add.append(normalize_node)
self.node_name_to_graph_name[normalize_node.name] = self.this_graph_name
class FusionSkipSimplifiedLayerNormalization(FusionSkipLayerNormalization):
def __init__(self, model: OnnxModel):
super().__init__(model, "SkipSimplifiedLayerNormalization", "SimplifiedLayerNormalization")
self.shape_infer_helper = self.model.infer_runtime_shape(
{"batch_size": 2, "seq_len": 1, "encode_sequence_length": 8, "past_decode_sequence_length": 4}, update=True
)
def fuse(self, node, input_name_to_nodes, output_name_to_node):
super().fuse(node, input_name_to_nodes, output_name_to_node)
@ -151,6 +630,7 @@ class T5OnnxModel(BertOnnxModel):
super().__init__(model, num_heads, hidden_size)
self.attention_mask = AttentionMask(self)
self.attention_fusion = FusionT5Attention(self, self.hidden_size, self.num_heads, self.attention_mask)
self.layer_norm_fusion = FusionSimplifiedLayerNormalization(self)
self.skip_layer_norm_fusion = FusionSkipSimplifiedLayerNormalization(self)
# TODO: consider retrive max_distance from model.
# math.log(max_distance / (num_buckets // 2))
@ -159,6 +639,9 @@ class T5OnnxModel(BertOnnxModel):
def fuse_attention(self):
self.attention_fusion.apply()
def fuse_layer_norm(self):
self.layer_norm_fusion.apply()
def fuse_skip_layer_norm(self):
self.skip_layer_norm_fusion.apply()
@ -234,8 +717,11 @@ class T5OnnxModel(BertOnnxModel):
nodes_to_remove.append(node)
self.remove_nodes(nodes_to_remove)
def postprocess(self):
def preprocess(self):
self.adjust_reshape_and_expand()
self.rpb_fusion.apply()
def postprocess(self):
# remove get_extended_attention_mask() since it generates all zeros.
self.remove_extended_mask_decoder_init()
self.remove_extended_mask_decoder()

View file

@ -25,11 +25,11 @@ else:
class TestFusion(unittest.TestCase):
def verify_fusion(self, optimized_model, expected_model_filename):
optimized_model.topological_sort()
optimized_model.topological_sort(is_deterministic=True)
expected_model_path = os.path.join(os.path.dirname(__file__), "test_data", "models", expected_model_filename)
expected_model = OnnxModel(onnx.load(expected_model_path))
expected_model.topological_sort()
expected_model.topological_sort(is_deterministic=True)
self.assertEqual(str(optimized_model.model.graph), str(expected_model.model.graph))