diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py index 689235b630..dbc939bce2 100755 --- a/onnxruntime/python/tools/symbolic_shape_infer.py +++ b/onnxruntime/python/tools/symbolic_shape_infer.py @@ -198,6 +198,7 @@ class SymbolicShapeInference: "LayerNormalization": self._infer_LayerNormalization, "LongformerAttention": self._infer_LongformerAttention, "PythonOp": self._infer_PythonOp, + "SimplifiedLayerNormalization": self._infer_LayerNormalization, "SkipLayerNormalization": self._infer_SkipLayerNormalization, "SkipSimplifiedLayerNormalization": self._infer_SkipLayerNormalization, "GroupNorm": self._infer_GroupNorm, @@ -433,7 +434,9 @@ class SymbolicShapeInference: "GemmFastGelu", "LayerNormalization", "LongformerAttention", + "SimplifiedLayerNormalization", "SkipLayerNormalization", + "SkipSimplifiedLayerNormalization", "PythonOp", "MultiHeadAttention", "GroupNorm", diff --git a/onnxruntime/python/tools/transformers/convert_generation.py b/onnxruntime/python/tools/transformers/convert_generation.py index 122a574064..a106d906d0 100644 --- a/onnxruntime/python/tools/transformers/convert_generation.py +++ b/onnxruntime/python/tools/transformers/convert_generation.py @@ -483,7 +483,7 @@ def t5_to_onnx(args: argparse.Namespace): Path(args.output).parent, use_gpu=args.use_gpu, use_external_data_format=args.use_external_data_format, - optimize_onnx=False, + optimize_onnx=(args.precision != Precision.FLOAT16), precision=args.precision, verbose=False, use_decoder_start_token=False, diff --git a/onnxruntime/python/tools/transformers/fusion_skiplayernorm.py b/onnxruntime/python/tools/transformers/fusion_skiplayernorm.py index 5a32415aba..7c54649553 100644 --- a/onnxruntime/python/tools/transformers/fusion_skiplayernorm.py +++ b/onnxruntime/python/tools/transformers/fusion_skiplayernorm.py @@ -19,8 +19,13 @@ class FusionSkipLayerNormalization(Fusion): Note: This fusion does not check the input shape of Add and LayerNormalization. """ - def __init__(self, model: OnnxModel): - super().__init__(model, "SkipLayerNormalization", "LayerNormalization") + def __init__( + self, + model: OnnxModel, + fused_op_type: str = "SkipLayerNormalization", + search_op_types: str = "LayerNormalization", + ): + super().__init__(model, fused_op_type, search_op_types) # Update shape inference is needed since other fusions might add new edge which does not have shape info yet. self.shape_infer_helper = self.model.infer_runtime_shape({"batch_size": 4, "seq_len": 7}, update=True) @@ -44,6 +49,9 @@ class FusionSkipLayerNormalization(Fusion): if len(self.model.get_parents(add)) != 2: return + # Root Mean Square Layer Normalization + simplified = node.op_type == "SimplifiedLayerNormalization" + if self.shape_infer_helper is not None: if not self.shape_infer_helper.compare_shape(add.input[0], add.input[1]): logger.debug( @@ -89,12 +97,16 @@ class FusionSkipLayerNormalization(Fusion): ): self.nodes_to_remove.extend([add, node]) - inputs = [add.input[0], add.input[1], node.input[1], node.input[2]] + inputs = ( + [add.input[0], add.input[1], node.input[1], node.input[2]] + if not simplified + else [add.input[0], add.input[1], node.input[1]] + ) normalize_node = helper.make_node( - "SkipLayerNormalization", + self.fused_op_type, inputs=inputs, outputs=outputs, - name=self.model.create_node_name("SkipLayerNormalization", name_prefix="SkipLayerNorm"), + name=self.model.create_node_name(self.fused_op_type, name_prefix="SkipLayerNorm"), ) normalize_node.domain = "com.microsoft" diff --git a/onnxruntime/python/tools/transformers/models/t5/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/t5/convert_to_onnx.py index ae6995dd77..eff24f58a0 100644 --- a/onnxruntime/python/tools/transformers/models/t5/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/t5/convert_to_onnx.py @@ -203,6 +203,7 @@ def export_onnx_models( config.hidden_size, use_external_data_format, auto_mixed_precision=not disable_auto_mixed_precision, + use_gpu=use_gpu, ) else: logger.info(f"Skip optimizing: existed ONNX model {onnx_path}") diff --git a/onnxruntime/python/tools/transformers/models/t5/t5_helper.py b/onnxruntime/python/tools/transformers/models/t5/t5_helper.py index 4d853a6544..c91c0da178 100644 --- a/onnxruntime/python/tools/transformers/models/t5/t5_helper.py +++ b/onnxruntime/python/tools/transformers/models/t5/t5_helper.py @@ -228,15 +228,25 @@ class T5Helper: hidden_size: int, use_external_data_format: bool = False, auto_mixed_precision: bool = True, + use_gpu: bool = False, ): """Optimize ONNX model with an option to convert it to use mixed precision.""" + + from fusion_options import FusionOptions + + optimization_options = None + if not use_gpu: + # Currently there is no SkipSimplifiedLayerNorm cpu kernel + optimization_options = FusionOptions("t5") + optimization_options.enable_skip_layer_norm = False + m = optimize_model( onnx_model_path, - model_type="bert", # TODO: support optimization for t5 + model_type="t5", num_heads=num_attention_heads, hidden_size=hidden_size, - opt_level=0, - optimization_options=None, + opt_level=2 if not is_float16 and not use_external_data_format else 0, + optimization_options=optimization_options, use_gpu=False, ) if is_float16: diff --git a/onnxruntime/python/tools/transformers/onnx_model_t5.py b/onnxruntime/python/tools/transformers/onnx_model_t5.py new file mode 100644 index 0000000000..528467b9f2 --- /dev/null +++ b/onnxruntime/python/tools/transformers/onnx_model_t5.py @@ -0,0 +1,92 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +import logging +from typing import Union + +from fusion_attention import AttentionMask, FusionAttention +from fusion_base import Fusion +from fusion_skiplayernorm import FusionSkipLayerNormalization +from onnx import NodeProto +from onnx_model import OnnxModel +from onnx_model_bert import BertOnnxModel + +logger = logging.getLogger(__name__) + +# TODO: Support decoder self/cross attention fusion and encoder self attention fusion +class FusionT5Attention(FusionAttention): + """ + Fuse T5 Attention subgraph into one Attention node. + """ + + def __init__( + self, + model: OnnxModel, + hidden_size: int, + num_heads: int, + attention_mask: AttentionMask, + ): + super().__init__(model, hidden_size, num_heads, attention_mask) + + def create_attention_node( + self, + mask_index: str, + matmul: NodeProto, + add: NodeProto, + num_heads: int, + hidden_size: int, + input: str, + output: str, + add_qk_str: str, + ) -> Union[NodeProto, None]: + # Not implemented yet + return None + + def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): + # Not implemented yet + return + + +# It's much easier to export it with the custom op. TODO: revisit later +class FusionRelativePositionBiasBlock(Fusion): + def __init__(self, model: OnnxModel, max_distance: int, is_bidirectional: bool): + super().__init__(model, "RelativePositionBias", "Add") + self.max_distance = max_distance + self.is_bidirectional = is_bidirectional + + def fuse(self, node, input_name_to_nodes, output_name_to_node): + # Not implemented yet + return + + +class FusionSkipSimplifiedLayerNormalization(FusionSkipLayerNormalization): + def __init__(self, model: OnnxModel): + super().__init__(model, "SkipSimplifiedLayerNormalization", "SimplifiedLayerNormalization") + self.shape_infer_helper = self.model.infer_runtime_shape( + {"batch_size": 2, "seq_len": 1, "encode_sequence_length": 8, "past_decode_sequence_length": 4}, update=True + ) + + def fuse(self, node, input_name_to_nodes, output_name_to_node): + super().fuse(node, input_name_to_nodes, output_name_to_node) + + +class T5OnnxModel(BertOnnxModel): + def __init__(self, model, num_heads, hidden_size): + super().__init__(model, num_heads, hidden_size) + self.attention_mask = AttentionMask(self) + self.attention_fusion = FusionT5Attention(self, self.hidden_size, self.num_heads, self.attention_mask) + self.skip_layer_norm_fusion = FusionSkipSimplifiedLayerNormalization(self) + # TODO: hardcode for now. double check later + self.rpb_fusion = FusionRelativePositionBiasBlock(self, 32, True) + + def fuse_attention(self): + self.attention_fusion.apply() + + def fuse_skip_layer_norm(self): + self.skip_layer_norm_fusion.apply() + + def postprocess(self): + self.rpb_fusion.apply() + self.clean_graph() + self.prune_graph() diff --git a/onnxruntime/python/tools/transformers/optimizer.py b/onnxruntime/python/tools/transformers/optimizer.py index 56076eedda..a18535c105 100644 --- a/onnxruntime/python/tools/transformers/optimizer.py +++ b/onnxruntime/python/tools/transformers/optimizer.py @@ -30,6 +30,7 @@ from onnx_model_bert import BertOnnxModel from onnx_model_bert_keras import BertOnnxModelKeras from onnx_model_bert_tf import BertOnnxModelTF from onnx_model_gpt2 import Gpt2OnnxModel +from onnx_model_t5 import T5OnnxModel from onnx_model_tnlr import TnlrOnnxModel from onnx_model_unet import UnetOnnxModel @@ -49,6 +50,7 @@ MODEL_TYPES = { ), # might add a class for GPT2OnnxModel for TF later. "tnlr": (TnlrOnnxModel, "pytorch", 1), "unet": (UnetOnnxModel, "pytorch", 1), + "t5": (T5OnnxModel, "pytorch", 2), } @@ -248,7 +250,7 @@ def optimize_model( else [ "MatMulScaleFusion", "MatMulAddFusion", - "SimplifiedLayerNormFusion", + "MatmulTransposeFusion", "GemmActivationFusion", "BiasSoftmaxFusion", ]