diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py index a20757d800..581a61ef3f 100755 --- a/onnxruntime/python/tools/symbolic_shape_infer.py +++ b/onnxruntime/python/tools/symbolic_shape_infer.py @@ -163,6 +163,7 @@ class SymbolicShapeInference: "Reciprocal": self._pass_on_shape_and_type, "ReduceSum": self._infer_ReduceSum, "ReduceProd": self._infer_ReduceProd, + "RelativePositionBias": self._infer_RelativePositionBias, "Reshape": self._infer_Reshape, "Resize": self._infer_Resize, "Round": self._pass_on_shape_and_type, @@ -378,6 +379,17 @@ class SymbolicShapeInference: assert name in self.initializers_ return list(self.initializers_[name].dims) + def _try_get_shape(self, node, idx): + if idx > len(node.input) - 1: + return None + name = node.input[idx] + if name in self.known_vi_: + vi = self.known_vi_[name] + return get_shape_from_value_info(vi) + if name in self.initializers_: + return list(self.initializers_[name].dims) + return None + def _get_shape_rank(self, node, idx): return len(self._get_shape(node, idx)) @@ -437,6 +449,7 @@ class SymbolicShapeInference: "GemmFastGelu", "LayerNormalization", "LongformerAttention", + "RelativePositionBias", "SimplifiedLayerNormalization", "SkipLayerNormalization", "SkipSimplifiedLayerNormalization", @@ -1495,6 +1508,19 @@ class SymbolicShapeInference: if data is not None: self.sympy_data_[node.output[0]] = sympy_reduce_product(data) + def _infer_RelativePositionBias(self, node): + seq_len = self._try_get_value(node, 1) + real_seq_len = self._try_get_value(node, 2) + if seq_len is None or real_seq_len is None: + return + num_heads = self._get_sympy_shape(node, 0)[1] + + new_shape = [1, num_heads, str(seq_len), str(real_seq_len)] + + output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type + vi = self.known_vi_[node.output[0]] + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, new_shape)) + def _infer_Reshape(self, node): shape_value = self._try_get_value(node, 1) vi = self.known_vi_[node.output[0]] @@ -2030,14 +2056,18 @@ class SymbolicShapeInference: def _infer_Attention(self, node): shape = self._get_shape(node, 0) - shape_bias = self._get_shape(node, 2) - if shape and len(shape) == 3 and shape_bias and len(shape_bias) == 1: + shape_weights = self._get_shape(node, 1) + shape_bias = self._try_get_shape(node, 2) + if shape_bias is not None: + assert len(shape_bias) == 1 + tripled_hidden_size = shape_bias[0] if shape_bias is not None else shape_weights[1] + if shape and len(shape) == 3: qkv_hidden_sizes_attr = get_attribute(node, "qkv_hidden_sizes") if qkv_hidden_sizes_attr is not None: assert len(qkv_hidden_sizes_attr) == 3 shape[2] = int(qkv_hidden_sizes_attr[2]) - elif isinstance(shape_bias[0], int): - shape[2] = int(shape_bias[0] / 3) + elif isinstance(tripled_hidden_size, int): + shape[2] = int(tripled_hidden_size / 3) output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type vi = self.known_vi_[node.output[0]] vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, shape)) @@ -2068,8 +2098,8 @@ class SymbolicShapeInference: # Output 0 has shape (batch_size, sequence_length, v_hidden_size) # Q, K and V without packing: # Input 0 (query) has shape (batch_size, sequence_length, hidden_size) - # Input 1 (key) has shape (batch_size, kv_sequence_length, hidden_size) - # Input 2 (value) has shape (batch_size, kv_sequence_length, v_hidden_size) + # Input 1 (key) has shape (batch_size, kv_sequence_length, hidden_size) or (batch_size, num_heads, kv_sequence_length, head_size) + # Input 2 (value) has shape (batch_size, kv_sequence_length, v_hidden_size) or (batch_size, num_heads, kv_sequence_length, head_size) # Packed KV: # Input 0 (query) has shape (batch_size, sequence_length, hidden_size) # Input 1 (batch_size, kv_sequence_length, num_heads, 2, head_size) @@ -2080,29 +2110,65 @@ class SymbolicShapeInference: # Input 2 nullptr query_shape = self._get_shape(node, 0) + total_sequence_length = None + output_dtype = None if query_shape is not None: if len(query_shape) == 3: - key_shape = self._get_shape(node, 1) + key_shape = self._try_get_shape(node, 1) # By default, hidden size is same for Q/K/V. Only need check v_hidden_size when value is provided. output_shape = query_shape - if key_shape and len(key_shape) == 3: - value_shape = self._get_shape(node, 2) - if value_shape and len(value_shape) == 3: + if key_shape is not None and len(key_shape) == 3: + value_shape = self._try_get_shape(node, 2) + if value_shape is not None and len(value_shape) == 3: output_shape[2] = value_shape[2] + total_sequence_length = key_shape[1] output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type vi = self.known_vi_[node.output[0]] vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape)) + elif len(query_shape) == 5: if isinstance(query_shape[2], int) and isinstance(query_shape[4], int): output_shape = [query_shape[0], query_shape[1], query_shape[2] * query_shape[4]] else: output_shape = [query_shape[0], query_shape[1], f"{query_shape[2]}*{query_shape[4]}"] + total_sequence_length = query_shape[1] + output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type vi = self.known_vi_[node.output[0]] vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape)) + if len(node.output) > 1: + batch_size = query_shape[0] + num_heads = get_attribute(node, "num_heads") + + head_size = None + if len(query_shape) == 3: + head_size = ( + int(query_shape[2] / num_heads) + if isinstance(query_shape[2], int) + else f"{query_shape[2]}/{num_heads}" + ) + else: + head_size = query_shape[4] + + past_shape = self._try_get_shape(node, 6) + + if past_shape is not None: + if isinstance(past_shape[2], int) and isinstance(total_sequence_length, int): + total_sequence_length = past_shape[2] + total_sequence_length + else: + total_sequence_length = f"{past_shape[2]}+{total_sequence_length}" + + present_shape = [batch_size, num_heads, total_sequence_length, head_size] + + assert output_dtype is not None + vi = self.known_vi_[node.output[1]] + vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, present_shape)) + vi = self.known_vi_[node.output[2]] + vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, present_shape)) + def _infer_FastGelu(self, node): self._propagate_shape_and_type(node) @@ -2140,8 +2206,6 @@ class SymbolicShapeInference: def _infer_SkipLayerNormalization(self, node): self._propagate_shape_and_type(node) - if len(node.output) > 3: - self._propagate_shape_and_type(node, 0, 3) # If the SkipLayerNormalization node contains the optional # output for inference, infer the shape and type for it too @@ -2348,7 +2412,9 @@ class SymbolicShapeInference: for i_o in range(len(node.output)): # Special case: We do not care about the training related # outputs of SkipLayerNormalization - if node.op_type == "SkipLayerNormalization" and i_o in [1, 2]: + if ( + node.op_type == "SkipLayerNormalization" or node.op_type == "SkipSimplifiedLayerNormalization" + ) and i_o in [1, 2]: continue vi = self.known_vi_[node.output[i_o]] diff --git a/onnxruntime/python/tools/transformers/convert_generation.py b/onnxruntime/python/tools/transformers/convert_generation.py index 3e2c1bdf08..37855c8c02 100644 --- a/onnxruntime/python/tools/transformers/convert_generation.py +++ b/onnxruntime/python/tools/transformers/convert_generation.py @@ -706,13 +706,12 @@ def verify_t5_decoder_subgraph(graph: onnx.GraphProto, precision: Precision): float_type = TensorProto.FLOAT16 if is_float16 else TensorProto.FLOAT input_count = len(graph.input) - layer_count = (input_count - 3) // 4 + layer_count = (input_count - 2) // 4 assert layer_count >= 1 # Expect inputs: # input_ids: int32 (B, 1) # encoder_attention_mask: int32 (B, encode_sequence_length) - # encoder_hidden_states: (B, encode_sequence_length, encoder_hidden_size) # past_key_self_0: (B, num_heads, past_decode_sequence_length, head_size) # past_value_self_0: (B, num_heads, past_decode_sequence_length, head_size) @@ -723,7 +722,7 @@ def verify_t5_decoder_subgraph(graph: onnx.GraphProto, precision: Precision): # ... (for each cross attention layer) # TODO: encoder_hidden_states is optional - expected_inputs = ["input_ids", "encoder_attention_mask", "encoder_hidden_states"] + expected_inputs = ["input_ids", "encoder_attention_mask"] for i in range(layer_count): expected_inputs.append(f"past_key_self_{i}") expected_inputs.append(f"past_value_self_{i}") @@ -1406,6 +1405,43 @@ def generate_gpt2_init_decoder( return True +def make_dim_proto_numeric_t5(model, config): + """Make dim_proto numeric. + + Args: + model: T5 encoder and decoder model. + config: T5 config. + """ + sequence_length = str(1) + num_heads = str(config.num_heads) + hidden_size = str(config.d_model) + head_size = str(config.d_kv) + + for tensor in model.graph.output: + for dim_proto in tensor.type.tensor_type.shape.dim: + if dim_proto.HasField("dim_param") and dim_proto.dim_param in [ + sequence_length, + num_heads, + hidden_size, + head_size, + ]: + dim_value = int(dim_proto.dim_param) + dim_proto.Clear() + dim_proto.dim_value = dim_value + + for tensor in model.graph.input: + for dim_proto in tensor.type.tensor_type.shape.dim: + if dim_proto.HasField("dim_param") and dim_proto.dim_param in [ + sequence_length, + num_heads, + hidden_size, + head_size, + ]: + dim_value = int(dim_proto.dim_param) + dim_proto.Clear() + dim_proto.dim_value = dim_value + + def convert_generation_model(args: argparse.Namespace, generation_type: GenerationType = GenerationType.BEAMSEARCH): """Convert model according to command line arguments. @@ -1686,6 +1722,9 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati # ) # initializers.extend(moved_initializers) + make_dim_proto_numeric_t5(encoder_model, config) + make_dim_proto_numeric_t5(decoder_model, config) + node.attribute.extend( [ onnx.helper.make_attribute("encoder", encoder_model.graph), diff --git a/onnxruntime/python/tools/transformers/fusion_attention.py b/onnxruntime/python/tools/transformers/fusion_attention.py index 342d43306e..65ba83b8d9 100644 --- a/onnxruntime/python/tools/transformers/fusion_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_attention.py @@ -2,11 +2,8 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. # -------------------------------------------------------------------------- -from enum import Enum from logging import getLogger -from os import name -from sys import path -from typing import Tuple, Union +from typing import List, Optional, Tuple, Union import numpy as np from fusion_base import Fusion @@ -14,7 +11,6 @@ from fusion_options import AttentionMaskFormat from fusion_utils import FusionUtils, NumpyHelper from onnx import NodeProto, TensorProto, helper, numpy_helper from onnx_model import OnnxModel -from shape_infer_helper import SymbolicShapeInferenceHelper, get_shape_from_type_proto logger = getLogger(__name__) @@ -94,9 +90,10 @@ class FusionAttention(Fusion): num_heads: int, attention_mask: AttentionMask, use_multi_head_attention: bool = False, + search_op_types: List[str] = ["SkipLayerNormalization", "LayerNormalization"], ): attention_op_name = "MultiHeadAttention" if use_multi_head_attention else "Attention" - super().__init__(model, attention_op_name, ["SkipLayerNormalization", "LayerNormalization"]) + super().__init__(model, attention_op_name, search_op_types) self.hidden_size = hidden_size self.num_heads = num_heads self.attention_mask = attention_mask @@ -211,6 +208,7 @@ class FusionAttention(Fusion): input: str, output: str, add_qk_str: str, + scale: Optional[float] = None, ) -> Union[NodeProto, None]: """Create an Attention node. @@ -236,12 +234,22 @@ class FusionAttention(Fusion): logger.debug(f"input hidden size {hidden_size} is not a multiple of num of heads {num_heads}") return None + has_bias = True + if q_add is None and k_add is None and v_add is None: + has_bias = False + q_weight = self.model.get_initializer(q_matmul.input[1]) k_weight = self.model.get_initializer(k_matmul.input[1]) v_weight = self.model.get_initializer(v_matmul.input[1]) - q_bias = self.model.get_initializer(q_add.input[1]) or self.model.get_initializer(q_add.input[0]) - k_bias = self.model.get_initializer(k_add.input[1]) or self.model.get_initializer(k_add.input[0]) - v_bias = self.model.get_initializer(v_add.input[1]) or self.model.get_initializer(v_add.input[0]) + + q_bias, k_bias, v_bias = None, None, None + if has_bias: + q_bias = self.model.get_initializer(q_add.input[1]) or self.model.get_initializer(q_add.input[0]) + k_bias = self.model.get_initializer(k_add.input[1]) or self.model.get_initializer(k_add.input[0]) + v_bias = self.model.get_initializer(v_add.input[1]) or self.model.get_initializer(v_add.input[0]) + + if not (k_weight and v_weight and q_bias and k_bias): + return None if q_weight is None: print( @@ -249,8 +257,6 @@ class FusionAttention(Fusion): "Please set do_constant_folding=True in torch.onnx.export to unblock attention fusion" ) return None - if not (k_weight and v_weight and q_bias and k_bias): - return None qw = NumpyHelper.to_array(q_weight) kw = NumpyHelper.to_array(k_weight) @@ -290,24 +296,25 @@ class FusionAttention(Fusion): qkv_weight = np.stack((qw, kw, vw), axis=1) qkv_weight_dim = 3 * qw_out_size - qb = NumpyHelper.to_array(q_bias) - kb = NumpyHelper.to_array(k_bias) - vb = NumpyHelper.to_array(v_bias) + if has_bias: + qb = NumpyHelper.to_array(q_bias) + kb = NumpyHelper.to_array(k_bias) + vb = NumpyHelper.to_array(v_bias) - q_bias_shape = np.prod(qb.shape) - k_bias_shape = np.prod(kb.shape) - v_bias_shape = np.prod(vb.shape) + q_bias_shape = np.prod(qb.shape) + k_bias_shape = np.prod(kb.shape) + v_bias_shape = np.prod(vb.shape) - assert q_bias_shape == k_bias_shape == qw_out_size - assert v_bias_shape == vw_out_size + assert q_bias_shape == k_bias_shape == qw_out_size + assert v_bias_shape == vw_out_size - qkv_bias_dim = 0 - if is_qkv_diff_dims: - qkv_bias = np.concatenate((qb, kb, vb), axis=0) - qkv_bias_dim = q_bias_shape + k_bias_shape + v_bias_shape - else: - qkv_bias = np.stack((qb, kb, vb), axis=0) - qkv_bias_dim = 3 * q_bias_shape + qkv_bias_dim = 0 + if is_qkv_diff_dims: + qkv_bias = np.concatenate((qb, kb, vb), axis=0) + qkv_bias_dim = q_bias_shape + k_bias_shape + v_bias_shape + else: + qkv_bias = np.stack((qb, kb, vb), axis=0) + qkv_bias_dim = 3 * q_bias_shape attention_node_name = self.model.create_node_name("Attention") @@ -324,15 +331,17 @@ class FusionAttention(Fusion): weight.CopyFrom(numpy_helper.from_array(NumpyHelper.to_array(weight).astype(np.float16), weight.name)) self.model.add_initializer(weight, self.this_graph_name) - bias = helper.make_tensor( - name=attention_node_name + "_qkv_bias", - data_type=TensorProto.FLOAT, - dims=[qkv_bias_dim], - vals=qkv_bias.flatten().tolist(), - ) - if q_bias.data_type == 10: - bias.CopyFrom(numpy_helper.from_array(NumpyHelper.to_array(bias).astype(np.float16), bias.name)) - self.model.add_initializer(bias, self.this_graph_name) + bias = None + if has_bias: + bias = helper.make_tensor( + name=attention_node_name + "_qkv_bias", + data_type=TensorProto.FLOAT, + dims=[qkv_bias_dim], + vals=qkv_bias.flatten().tolist(), + ) + if q_bias.data_type == 10: + bias.CopyFrom(numpy_helper.from_array(NumpyHelper.to_array(bias).astype(np.float16), bias.name)) + self.model.add_initializer(bias, self.this_graph_name) # For MultiHeadAttention operator, use separated inputs for query, key and value, and no weights. if self.use_multi_head_attention: @@ -359,7 +368,7 @@ class FusionAttention(Fusion): attention_inputs = [ input, attention_node_name + "_qkv_weight", - attention_node_name + "_qkv_bias", + attention_node_name + "_qkv_bias" if has_bias else "", ] if mask_index is not None: attention_inputs.append(mask_index) @@ -379,6 +388,9 @@ class FusionAttention(Fusion): attention_node.domain = "com.microsoft" attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)]) + if scale is not None: + attention_node.attribute.extend([helper.make_attribute("scale", scale)]) + if is_qkv_diff_dims: attention_node.attribute.extend( [helper.make_attribute("qkv_hidden_sizes", [qw_out_size, kw_out_size, vw_out_size])] diff --git a/onnxruntime/python/tools/transformers/models/t5/t5_decoder.py b/onnxruntime/python/tools/transformers/models/t5/t5_decoder.py index 8ff5b23cef..bfa14d67c3 100644 --- a/onnxruntime/python/tools/transformers/models/t5/t5_decoder.py +++ b/onnxruntime/python/tools/transformers/models/t5/t5_decoder.py @@ -92,14 +92,16 @@ class T5Decoder(torch.nn.Module): self.lm_head = lm_head self.config = config - def forward(self, decoder_input_ids, encoder_attention_mask, encoder_hidden_states, *past): + def forward(self, decoder_input_ids, encoder_attention_mask, *past): past_key_values = PastKeyValuesHelper.group_by_layer(past, self.config.num_layers) + # This is a hack since only the third dimension of encoder_hidden_states is used here + dummy_encoder_hidden_states = encoder_attention_mask.unsqueeze(2) decoder_outputs = self.decoder( input_ids=decoder_input_ids, past_key_values=past_key_values, - encoder_hidden_states=encoder_hidden_states, + encoder_hidden_states=dummy_encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, use_cache=True, return_dict=True, @@ -122,12 +124,10 @@ class T5DecoderInputs: self, decoder_input_ids, encoder_attention_mask, - encoder_hidden_states, past_key_values=None, ): self.decoder_input_ids: torch.LongTensor = decoder_input_ids self.encoder_attention_mask: torch.LongTensor = encoder_attention_mask - self.encoder_hidden_states: Union[torch.FloatTensor, torch.HalfTensor] = encoder_hidden_states self.past_key_values: Union[List[torch.FloatTensor], List[torch.HalfTensor], None] = past_key_values @staticmethod @@ -181,13 +181,6 @@ class T5DecoderInputs: ) float_type = torch.float16 if float16 else torch.float32 - encoder_hidden_state = torch.rand( - batch_size, - encode_sequence_length, - hidden_size, - dtype=float_type, - device=device, - ) if past_decode_sequence_length > 0: self_attention_past_shape = [ @@ -212,25 +205,22 @@ class T5DecoderInputs: else: past = None - return T5DecoderInputs(decoder_input_ids, encoder_inputs.attention_mask, encoder_hidden_state, past) + return T5DecoderInputs(decoder_input_ids, encoder_inputs.attention_mask, past) def to_list(self) -> List: input_list = [ self.decoder_input_ids, self.encoder_attention_mask, - self.encoder_hidden_states, ] if self.past_key_values: input_list.extend(self.past_key_values) return input_list def to_fp32(self): - encoder_hidden_state = self.encoder_hidden_states.to(dtype=torch.float32) past = [p.to(dtype=torch.float32) for p in self.past_key_values] if self.past_key_values else None return T5DecoderInputs( self.decoder_input_ids.clone(), self.encoder_attention_mask.clone(), - encoder_hidden_state, past, ) @@ -278,7 +268,6 @@ class T5DecoderHelper: # Shape of input tensors (sequence_length==1): # input_ids: (batch_size, sequence_length) # encoder_attention_mask: (batch_size, encode_sequence_length) - # encoder_hidden_states: (batch_size, encode_sequence_length, hidden_size) # past_self_*: (batch_size, num_heads, past_decode_sequence_length, head_size) # past_cross_*: (batch_size, num_heads, encode_sequence_length, head_size) @@ -289,7 +278,6 @@ class T5DecoderHelper: input_names = ["input_ids"] input_names.append("encoder_attention_mask") - input_names.append("encoder_hidden_states") input_names.extend(input_past_names) dynamic_axes = { @@ -362,7 +350,6 @@ class T5DecoderHelper: ort_inputs = { "input_ids": numpy.ascontiguousarray(inputs.decoder_input_ids.cpu().numpy()), "encoder_attention_mask": numpy.ascontiguousarray(inputs.encoder_attention_mask.cpu().numpy()), - "encoder_hidden_states": numpy.ascontiguousarray(inputs.encoder_hidden_states.cpu().numpy()), } if inputs.past_key_values: @@ -384,7 +371,7 @@ class T5DecoderHelper: max_cases: int = 4, ): """Compare the result from PyTorch and OnnxRuntime to verify the ONNX model is good.""" - float16: bool = TypeHelper.get_input_type(ort_session, "encoder_hidden_states") == "tensor(float16)" + float16: bool = TypeHelper.get_input_type(ort_session, "past_key_self_0") == "tensor(float16)" test_cases = [(4, 11, 3), (1, 2, 5), (3, 1, 1), (8, 5, 2)] test_cases_max_diff = [] diff --git a/onnxruntime/python/tools/transformers/models/t5/t5_helper.py b/onnxruntime/python/tools/transformers/models/t5/t5_helper.py index c91c0da178..17ea255386 100644 --- a/onnxruntime/python/tools/transformers/models/t5/t5_helper.py +++ b/onnxruntime/python/tools/transformers/models/t5/t5_helper.py @@ -151,21 +151,17 @@ class T5Helper: def auto_mixed_precision( onnx_model: OnnxModel, op_block_list: List[str] = [ - "Pow", - "ReduceMean", - "Add", - "Sqrt", - "Div", - "Mul", - "Softmax", + "SimplifiedLayerNormalization", + "SkipSimplifiedLayerNormalization", "Relu", + "Add", ], ): """Convert model to mixed precision. It detects whether original model has fp16 precision weights, and set parameters for float16 conversion automatically. Args: onnx_model (OnnxModel): optimized ONNX model - op_block_list (List[str], optional): . Defaults to ["Pow", "ReduceMean", "Add", "Sqrt", "Div", "Mul", "Softmax", "Relu"] + op_block_list (List[str], optional): . Defaults to ["SimplifiedLayerNormalization", "SkipSimplifiedLayerNormalization", "Relu", "Add"] Returns: parameters(dict): a dictionary of parameters used in float16 conversion """ @@ -235,8 +231,7 @@ class T5Helper: from fusion_options import FusionOptions optimization_options = None - if not use_gpu: - # Currently there is no SkipSimplifiedLayerNorm cpu kernel + if is_float16: optimization_options = FusionOptions("t5") optimization_options.enable_skip_layer_norm = False @@ -245,10 +240,12 @@ class T5Helper: model_type="t5", num_heads=num_attention_heads, hidden_size=hidden_size, - opt_level=2 if not is_float16 and not use_external_data_format else 0, + opt_level=2 if not use_external_data_format else 0, optimization_options=optimization_options, use_gpu=False, + only_onnxruntime=not use_gpu, ) + if is_float16: if auto_mixed_precision: T5Helper.auto_mixed_precision(m) diff --git a/onnxruntime/python/tools/transformers/onnx_model.py b/onnxruntime/python/tools/transformers/onnx_model.py index dd58f82171..cce2cbe5a4 100644 --- a/onnxruntime/python/tools/transformers/onnx_model.py +++ b/onnxruntime/python/tools/transformers/onnx_model.py @@ -615,7 +615,7 @@ class OnnxModel: if use_symbolic_shape_infer: # Use symbolic shape inference since custom operators (like Gelu, SkipLayerNormalization etc) # are not recognized by onnx shape inference. - shape_infer_helper = SymbolicShapeInferenceHelper(model) + shape_infer_helper = SymbolicShapeInferenceHelper(model, verbose=0) model = shape_infer_helper.infer_shapes(model, auto_merge=True, guess_output_rank=False) parameters = {"disable_shape_infer": use_symbolic_shape_infer} @@ -876,66 +876,64 @@ class OnnxModel: return True @staticmethod - def graph_topological_sort(graph): - deps_count = [0] * len(graph.node) # dependency count of each node - deps_to_nodes = {} # input to node indice + def graph_topological_sort(graph, is_deterministic=False): + deps_set = set() # dependency set of all node + sorted_node_set = set() # sorted node set sorted_nodes = [] # initialize sorted_nodes - for node_idx, node in enumerate(graph.node): - # CANNOT use len(node.input) directly because input can be optional - deps_count[node_idx] = sum(1 for _ in node.input if _) - if deps_count[node_idx] == 0: # Constant doesn't depend on any inputs - sorted_nodes.append(graph.node[node_idx]) - continue - for input_name in node.input: - if input_name not in deps_to_nodes: - deps_to_nodes[input_name] = [node_idx] - else: - deps_to_nodes[input_name].append(node_idx) - - # Note: this logic only applies to top level graph since a sub graph could use intializer from parent graph initializer_names = [init.name for init in graph.initializer] graph_input_names = [input.name for input in graph.input] input_names = initializer_names + graph_input_names - input_names.sort() - prev_input_name = None + + if is_deterministic: + input_names.sort() + for input_name in input_names: - if prev_input_name == input_name: - continue + deps_set.add(input_name) - prev_input_name = input_name - if input_name in deps_to_nodes: - for node_idx in deps_to_nodes[input_name]: - deps_count[node_idx] = deps_count[node_idx] - 1 - if deps_count[node_idx] == 0: - sorted_nodes.append(graph.node[node_idx]) + sorted_node_set_len = -1 + graph_nodes = graph.node if not is_deterministic else sorted(graph.node, key=lambda x: x.name) + last_node_name = None + while len(sorted_node_set) != len(graph_nodes): + if len(sorted_node_set) == sorted_node_set_len: + break + sorted_node_set_len = len(sorted_node_set) + for node_idx, node in enumerate(graph_nodes): + if node_idx in sorted_node_set: + continue + input_count = sum(1 for _ in node.input if _) + if input_count == 0: + sorted_nodes.append(node) + sorted_node_set.add(node_idx) + for output in node.output: + deps_set.add(output) + continue + failed = False + for input_name in node.input: + if input_name != "" and input_name not in deps_set: + failed = True + last_node_name = node.name + if not failed: + sorted_nodes.append(node) + sorted_node_set.add(node_idx) + for output in node.output: + deps_set.add(output) + else: + continue - start = 0 - end = len(sorted_nodes) - - while start < end: - for output in sorted_nodes[start].output: - if output in deps_to_nodes: - for node_idx in deps_to_nodes[output]: - deps_count[node_idx] = deps_count[node_idx] - 1 - if deps_count[node_idx] == 0: - sorted_nodes.append(graph.node[node_idx]) - end = end + 1 - start = start + 1 - - if end != len(graph.node): + if len(sorted_node_set) != len(graph.node): raise RuntimeError( - f"Graph is not a DAG: end={end}, len(graph.node)={len(graph.node)}, graph.node[end]={graph.node[end]}" + f"Graph is not a DAG: len(sorted_node_set)={len(sorted_node_set)}, len(graph.node)={len(graph.node)}, failed at node {last_node_name}" ) graph.ClearField("node") graph.node.extend(sorted_nodes) - def topological_sort(self): + def topological_sort(self, is_deterministic=False): # TODO: support graph_topological_sort() in subgraphs # for graph in self.graphs(): # self.graph_topological_sort(graph) - OnnxModel.graph_topological_sort(self.model.graph) + OnnxModel.graph_topological_sort(self.model.graph, is_deterministic) @staticmethod def save( diff --git a/onnxruntime/python/tools/transformers/onnx_model_t5.py b/onnxruntime/python/tools/transformers/onnx_model_t5.py index 02e67e9a2a..0a1c62da59 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_t5.py +++ b/onnxruntime/python/tools/transformers/onnx_model_t5.py @@ -3,7 +3,7 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- import logging -from typing import Union +from typing import Dict, Union import numpy as np from fusion_attention import AttentionMask, FusionAttention @@ -16,7 +16,7 @@ from onnx_model_bert import BertOnnxModel logger = logging.getLogger(__name__) -# TODO: Support decoder self/cross attention fusion and encoder self attention fusion + class FusionT5Attention(FusionAttention): """ Fuse T5 Attention subgraph into one Attention node. @@ -29,25 +29,460 @@ class FusionT5Attention(FusionAttention): num_heads: int, attention_mask: AttentionMask, ): - super().__init__(model, hidden_size, num_heads, attention_mask) + super().__init__( + model, + hidden_size, + num_heads, + attention_mask, + use_multi_head_attention=False, + search_op_types=["SkipSimplifiedLayerNormalization", "Add"], + ) + self.static_kv = 1 - def create_attention_node( + def create_mha_node( self, + query: str, + key: str, + value: str, mask_index: str, - matmul: NodeProto, - add: NodeProto, + res_pos_bias: str, + past_key: str, + past_value: str, + output: str, + present_key: str, + present_value: str, num_heads: int, hidden_size: int, - input: str, - output: str, - add_qk_str: str, ) -> Union[NodeProto, None]: - # Not implemented yet - return None + + assert num_heads > 0 + + if hidden_size > 0 and (hidden_size % num_heads) != 0: + logger.debug(f"input hidden size {hidden_size} is not a multiple of num of heads {num_heads}") + return None + + attention_node_name = self.model.create_node_name("MultiHeadAttention") + attention_inputs = [ + query, + "" if key is None else key, # key + "" if value is None else value, # value + "", # bias + ] + if mask_index is not None: + attention_inputs.append(mask_index) + else: + attention_inputs.append("") + + if res_pos_bias is not None: + attention_inputs.append(res_pos_bias) + else: + attention_inputs.append("") + + if past_key is not None: + assert past_value is not None + attention_inputs.append(past_key) + attention_inputs.append(past_value) + + attention_outputs = [output] + if present_key is not None: + assert present_value is not None + attention_outputs.append(present_key) + attention_outputs.append(present_value) + + attention_node = helper.make_node( + "MultiHeadAttention", + inputs=attention_inputs, + outputs=attention_outputs, + name=attention_node_name, + ) + + attention_node.domain = "com.microsoft" + attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)]) + attention_node.attribute.extend([helper.make_attribute("scale", 1.0)]) + if self.mask_filter_value is not None: + attention_node.attribute.extend([helper.make_attribute("mask_filter_value", float(self.mask_filter_value))]) + + self.increase_counter("MultiHeadAttention") + return attention_node def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): - # Not implemented yet - return + self.fuse_t5_encoder(normalize_node, input_name_to_nodes, output_name_to_node) + self.fuse_t5_decoder(normalize_node, input_name_to_nodes, output_name_to_node) + + def fuse_t5_encoder(self, normalize_node, input_name_to_nodes, output_name_to_node): + if normalize_node.op_type != "SkipSimplifiedLayerNormalization" and normalize_node.op_type != "Add": + return + + qkv_nodes = self.model.match_parent_path( + normalize_node, + ["MatMul", "Reshape", "Transpose", "MatMul"], + [1, 0, 0, 0], + ) + if qkv_nodes is None: + return + + _, reshape_qkv, transpose_qkv, matmul_qkv = qkv_nodes + + qkv_shape_nodes = self.model.match_parent_path( + reshape_qkv, + ["Concat", "Unsqueeze", "Gather", "Shape"], + [1, 0, 0, 0], + ) + if qkv_shape_nodes is None: + return + input_shape_node = qkv_shape_nodes[-1] + + v_nodes = self.model.match_parent_path( + matmul_qkv, + ["Transpose", "Reshape", "MatMul"], + [1, 0, 0], + ) + if v_nodes is None: + return + _, reshape_v, matmul_v = v_nodes + # todo: check reshape_v parent nodes + + qk_nodes = self.model.match_parent_path( + matmul_qkv, + ["Softmax", "Add", "MatMul"], + [0, 0, 0], + ) + if qk_nodes is None: + return + _, add_qk, matmul_qk = qk_nodes + + mask_index = None + mask_nodes = self.model.match_parent_path( + add_qk, + ["Add", "Mul", "Sub", "Cast", "Unsqueeze", "Unsqueeze"], + [1, 1, 0, 1, 0, 0], + ) + if mask_nodes is None: + return + mul_node = mask_nodes[1] + if mask_nodes[1].op_type != "Mul": + return + + _, mul_val = self.model.get_constant_input(mul_node) + if mul_val != -10000: + self.mask_filter_value = mul_val + + mask_index = self.attention_mask.process_mask(mask_nodes[-1].input[0]) + + res_pos_bias = None + rpb_nodes = self.model.match_parent_path( + add_qk, + ["Add", "RelativePositionBias"], + [1, 0], + ) + if rpb_nodes is None: + return + rpb_add_node = rpb_nodes[0] + res_pos_bias = rpb_add_node.input[0] + + k_nodes = self.model.match_parent_path( + matmul_qk, + ["Transpose", "Reshape", "MatMul"], + [1, 0, 0], + ) + if k_nodes is None: + return + _, reshape_k, matmul_k = k_nodes + # todo: check reshape_k parent nodes + + q_nodes = self.model.match_parent_path( + matmul_qk, + ["Transpose", "Reshape", "MatMul"], + [0, 0, 0], + ) + if q_nodes is None: + return + + transpose_q, reshape_q, matmul_q = q_nodes + # todo: check reshape_q parent nodes + + if matmul_q.input[0] != input_shape_node.input[0]: + return + + q_num_heads, q_hidden_size = self.get_num_heads_and_hidden_size(reshape_q) + + new_node = self.create_attention_node( + mask_index, + matmul_q, + matmul_k, + matmul_v, + None, + None, + None, + q_num_heads, + q_hidden_size, + input_shape_node.input[0], + reshape_qkv.output[0], + res_pos_bias, + 1.0, + ) + if new_node is None: + return + + self.nodes_to_add.append(new_node) + self.node_name_to_graph_name[new_node.name] = self.this_graph_name + + self.nodes_to_remove.extend(qkv_nodes[1:]) + self.nodes_to_remove.extend(qk_nodes) + self.nodes_to_remove.extend(k_nodes[:-1]) + if v_nodes is not None: + self.nodes_to_remove.extend(v_nodes[:-1]) + self.nodes_to_remove.extend(q_nodes[:-1]) + + self.prune_graph = True + + def fuse_t5_decoder(self, normalize_node, input_name_to_nodes, output_name_to_node): + if normalize_node.op_type != "SkipSimplifiedLayerNormalization" and normalize_node.op_type != "Add": + return + + qkv_nodes = self.model.match_parent_path( + normalize_node, + ["MatMul", "Reshape", "Transpose", "MatMul"], + [1, 0, 0, 0], + ) + if qkv_nodes is None: + return + + _, reshape_qkv, transpose_qkv, matmul_qkv = qkv_nodes + + qkv_shape_nodes = self.model.match_parent_path( + reshape_qkv, + ["Concat", "Unsqueeze", "Gather", "Shape"], + [1, 0, 0, 0], + ) + if qkv_shape_nodes is None: + return + input_shape_node = qkv_shape_nodes[-1] + + value = None + past_value = None + present_value = None + v_nodes = self.model.match_parent_path( + matmul_qkv, + ["Concat", "Transpose", "Reshape", "MatMul"], + [1, 1, 0, 0], + ) + if v_nodes is None: + v_nodes = self.model.match_parent_path( + matmul_qkv, + ["Transpose", "Reshape", "MatMul"], + [1, 0, 0], + ) + if v_nodes is not None: + transpose_v, reshape_v, matmul_v = v_nodes + value = reshape_v.input[0] + present_value = transpose_v.output[0] + if "present_value" not in present_value: + return + if matmul_v.input[0] != input_shape_node.input[0]: + self.static_kv = 1 + else: + self.static_kv = 0 + else: + past_value = matmul_qkv.input[1] + if past_value in output_name_to_node: + return + if "past_value_cross" not in past_value: + return + self.static_kv = 1 + else: + concat_v, _, reshape_v, _ = v_nodes + past_value = concat_v.input[0] + if past_value in output_name_to_node: + return + if "past_value_self" not in past_value: + return + present_value = concat_v.output[0] + if "present_value_self" not in present_value: + return + value = reshape_v.input[0] + self.static_kv = 0 + + qk_nodes = self.model.match_parent_path( + matmul_qkv, + ["Softmax", "Add", "MatMul"], + [0, 0, 0], + ) + if qk_nodes is None: + return + _, add_qk, matmul_qk = qk_nodes + + mask_index = None + res_pos_bias = None + if self.static_kv == 1: + mask_nodes = self.model.match_parent_path( + add_qk, + ["Add", "Mul", "Sub", "Cast", "Unsqueeze", "Unsqueeze"], + [1, 1, 0, 1, 0, 0], + ) + if mask_nodes is None: + return + mul_node = mask_nodes[1] + if mask_nodes[1].op_type != "Mul": + return + + _, mul_val = self.model.get_constant_input(mul_node) + if mul_val != -10000: + self.mask_filter_value = mul_val + + mask_index = self.attention_mask.process_mask(mask_nodes[-1].input[0]) + else: + rpb_nodes = self.model.match_parent_path( + add_qk, + ["Add", "Slice"], + [1, 0], + ) + if rpb_nodes is not None: + res_pos_bias = add_qk.input[1] + else: + rpb_nodes = self.model.match_parent_path( + add_qk, + ["Add", "RelativePositionBias"], + [1, 0], + ) + if rpb_nodes is None: + return + res_pos_bias = add_qk.input[1] + + key = None + past_key = None + present_key = None + if self.static_kv == 1: + k_nodes = self.model.match_parent_path( + matmul_qk, + ["Transpose", "Reshape", "MatMul"], + [1, 0, 0], + ) + if k_nodes is not None: + transpose_k, reshape_k, _ = k_nodes + key = reshape_k.input[0] + present_key_transpose_nodes = input_name_to_nodes[reshape_k.output[0]] + for present_key_transpose_node in present_key_transpose_nodes: + present_key_candidate = self.model.find_graph_output(present_key_transpose_node.output[0]) + if present_key_candidate is not None: + present_key = present_key_candidate.name + break + if present_key is None: + return + if "present_key_cross" not in present_key: + return + else: + k_nodes = self.model.match_parent_path( + matmul_qk, + ["Transpose"], + [1], + ) + if k_nodes is None: + return + transpose_k = k_nodes[0] + + past_key = transpose_k.input[0] + if past_key in output_name_to_node: + return + if "past_key_cross" not in past_key: + return + else: + k_nodes = self.model.match_parent_path( + matmul_qk, + ["Transpose", "Concat", "Reshape", "MatMul"], + [1, 0, 1, 0], + ) + if k_nodes is not None: + _, concat_k, reshape_k, _ = k_nodes + key = reshape_k.input[0] + past_key_transpose_node = output_name_to_node[concat_k.input[0]] + past_key = past_key_transpose_node.input[0] + if past_key in output_name_to_node: + return + if "past_key_self" not in past_key: + return + present_key_transpose_nodes = input_name_to_nodes[concat_k.output[0]] + for present_key_transpose_node in present_key_transpose_nodes: + # print("present_key_transpose_node:", present_key_transpose_node) + present_key_candidate = self.model.find_graph_output(present_key_transpose_node.output[0]) + # print("present_key_candidate:", present_key_candidate) + if present_key_candidate is not None: + present_key = present_key_candidate.name + break + if present_key is None: + return + if "present_key_self" not in present_key: + return + else: + k_nodes = self.model.match_parent_path( + matmul_qk, + ["Transpose", "Reshape", "MatMul"], + [1, 0, 0], + ) + if k_nodes is None: + return + _, reshape_k, _ = k_nodes + key = reshape_k.input[0] + present_key_transpose_nodes = input_name_to_nodes[reshape_k.output[0]] + for present_key_transpose_node in present_key_transpose_nodes: + present_key_candidate = self.model.find_graph_output(present_key_transpose_node.output[0]) + if present_key_candidate is not None: + present_key = present_key_candidate.name + break + if present_key is None: + return + if "present_key_self" not in present_key: + return + + q_nodes = self.model.match_parent_path( + matmul_qk, + ["Transpose", "Reshape", "MatMul"], + [0, 0, 0], + ) + if q_nodes is None: + return + + transpose_q, reshape_q, matmul_q = q_nodes + + if matmul_q.input[0] != input_shape_node.input[0]: + return + + q_num_heads, q_hidden_size = self.get_num_heads_and_hidden_size(reshape_q) + + if self.static_kv == 1 and past_key is not None: + key = past_key + value = past_value + past_key = None + past_value = None + + new_node = self.create_mha_node( + matmul_q.output[0], + key, + value, + mask_index, + res_pos_bias, + past_key, + past_value, + reshape_qkv.output[0], + present_key, + present_value, + q_num_heads, + q_hidden_size, + ) + if new_node is None: + return + + self.nodes_to_add.append(new_node) + self.node_name_to_graph_name[new_node.name] = self.this_graph_name + + self.nodes_to_remove.extend(qkv_nodes[1:]) + self.nodes_to_remove.extend(qk_nodes) + self.nodes_to_remove.extend(k_nodes[:-1]) + if v_nodes is not None: + self.nodes_to_remove.extend(v_nodes[:-1]) + self.nodes_to_remove.extend(q_nodes[:-1]) + + self.prune_graph = True class FusionRelativePositionBiasBlock(Fusion): @@ -135,12 +570,56 @@ class FusionRelativePositionBiasBlock(Fusion): self.node_name_to_graph_name[rpb_node.name] = self.this_graph_name +class FusionSimplifiedLayerNormalization(Fusion): + def __init__(self, model: OnnxModel): + super().__init__(model, "SimplifiedLayerNormalization", "Mul") + + def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict): + if node.op_type != "Mul": + return + + sim_ln_nodes = self.model.match_parent_path( + node, + ["Mul", "Div", "Sqrt", "Add", "ReduceMean", "Pow", "Add"], + [1, 1, 1, 0, 0, 0, 0], + ) + if sim_ln_nodes is None: + return + + pow_node = sim_ln_nodes[-2] + if not self.model.find_constant_input(pow_node, 2.0) == 1: + return + + root_input = pow_node.input[0] + + mul_node_1 = sim_ln_nodes[0] + if root_input != mul_node_1.input[0]: + return + + second_add_node = sim_ln_nodes[3] + i, add_weight = self.model.get_constant_input(second_add_node) + if add_weight is None or add_weight <= 0 or add_weight > 1.0e-4: + logger.warning(f"epsilon value is not expeced: {add_weight}") + return + + self.nodes_to_remove.extend(sim_ln_nodes[:-1]) + + normalize_node = helper.make_node( + "SimplifiedLayerNormalization", + inputs=[root_input, node.input[0]], + outputs=[node.output[0]], + name=self.model.create_node_name("SimplifiedLayerNormalization", name_prefix="LayerNorm"), + ) + normalize_node.attribute.extend([helper.make_attribute("epsilon", float(add_weight))]) + normalize_node.attribute.extend([helper.make_attribute("axis", int(-1))]) + normalize_node.attribute.extend([helper.make_attribute("stash_type", int(1))]) + self.nodes_to_add.append(normalize_node) + self.node_name_to_graph_name[normalize_node.name] = self.this_graph_name + + class FusionSkipSimplifiedLayerNormalization(FusionSkipLayerNormalization): def __init__(self, model: OnnxModel): super().__init__(model, "SkipSimplifiedLayerNormalization", "SimplifiedLayerNormalization") - self.shape_infer_helper = self.model.infer_runtime_shape( - {"batch_size": 2, "seq_len": 1, "encode_sequence_length": 8, "past_decode_sequence_length": 4}, update=True - ) def fuse(self, node, input_name_to_nodes, output_name_to_node): super().fuse(node, input_name_to_nodes, output_name_to_node) @@ -151,6 +630,7 @@ class T5OnnxModel(BertOnnxModel): super().__init__(model, num_heads, hidden_size) self.attention_mask = AttentionMask(self) self.attention_fusion = FusionT5Attention(self, self.hidden_size, self.num_heads, self.attention_mask) + self.layer_norm_fusion = FusionSimplifiedLayerNormalization(self) self.skip_layer_norm_fusion = FusionSkipSimplifiedLayerNormalization(self) # TODO: consider retrive max_distance from model. # math.log(max_distance / (num_buckets // 2)) @@ -159,6 +639,9 @@ class T5OnnxModel(BertOnnxModel): def fuse_attention(self): self.attention_fusion.apply() + def fuse_layer_norm(self): + self.layer_norm_fusion.apply() + def fuse_skip_layer_norm(self): self.skip_layer_norm_fusion.apply() @@ -234,8 +717,11 @@ class T5OnnxModel(BertOnnxModel): nodes_to_remove.append(node) self.remove_nodes(nodes_to_remove) - def postprocess(self): + def preprocess(self): + self.adjust_reshape_and_expand() self.rpb_fusion.apply() + + def postprocess(self): # remove get_extended_attention_mask() since it generates all zeros. self.remove_extended_mask_decoder_init() self.remove_extended_mask_decoder() diff --git a/onnxruntime/test/python/transformers/test_attention_fusion.py b/onnxruntime/test/python/transformers/test_attention_fusion.py index 657d52cc15..d3de8a50c8 100644 --- a/onnxruntime/test/python/transformers/test_attention_fusion.py +++ b/onnxruntime/test/python/transformers/test_attention_fusion.py @@ -25,11 +25,11 @@ else: class TestFusion(unittest.TestCase): def verify_fusion(self, optimized_model, expected_model_filename): - optimized_model.topological_sort() + optimized_model.topological_sort(is_deterministic=True) expected_model_path = os.path.join(os.path.dirname(__file__), "test_data", "models", expected_model_filename) expected_model = OnnxModel(onnx.load(expected_model_path)) - expected_model.topological_sort() + expected_model.topological_sort(is_deterministic=True) self.assertEqual(str(optimized_model.model.graph), str(expected_model.model.graph))