Add SLN support for t5 model with beam search (#14429)

### Description
<!-- Describe your changes. -->



### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->

---------

Co-authored-by: Ubuntu <wy@v100-2.0cdb2e52twzevn1i4fi45bylyg.jx.internal.cloudapp.net>
This commit is contained in:
Ye Wang 2023-02-03 11:38:18 -08:00 committed by GitHub
parent 638f21b969
commit 999e5bf45e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 130 additions and 10 deletions

View file

@ -198,6 +198,7 @@ class SymbolicShapeInference:
"LayerNormalization": self._infer_LayerNormalization,
"LongformerAttention": self._infer_LongformerAttention,
"PythonOp": self._infer_PythonOp,
"SimplifiedLayerNormalization": self._infer_LayerNormalization,
"SkipLayerNormalization": self._infer_SkipLayerNormalization,
"SkipSimplifiedLayerNormalization": self._infer_SkipLayerNormalization,
"GroupNorm": self._infer_GroupNorm,
@ -433,7 +434,9 @@ class SymbolicShapeInference:
"GemmFastGelu",
"LayerNormalization",
"LongformerAttention",
"SimplifiedLayerNormalization",
"SkipLayerNormalization",
"SkipSimplifiedLayerNormalization",
"PythonOp",
"MultiHeadAttention",
"GroupNorm",

View file

@ -483,7 +483,7 @@ def t5_to_onnx(args: argparse.Namespace):
Path(args.output).parent,
use_gpu=args.use_gpu,
use_external_data_format=args.use_external_data_format,
optimize_onnx=False,
optimize_onnx=(args.precision != Precision.FLOAT16),
precision=args.precision,
verbose=False,
use_decoder_start_token=False,

View file

@ -19,8 +19,13 @@ class FusionSkipLayerNormalization(Fusion):
Note: This fusion does not check the input shape of Add and LayerNormalization.
"""
def __init__(self, model: OnnxModel):
super().__init__(model, "SkipLayerNormalization", "LayerNormalization")
def __init__(
self,
model: OnnxModel,
fused_op_type: str = "SkipLayerNormalization",
search_op_types: str = "LayerNormalization",
):
super().__init__(model, fused_op_type, search_op_types)
# Update shape inference is needed since other fusions might add new edge which does not have shape info yet.
self.shape_infer_helper = self.model.infer_runtime_shape({"batch_size": 4, "seq_len": 7}, update=True)
@ -44,6 +49,9 @@ class FusionSkipLayerNormalization(Fusion):
if len(self.model.get_parents(add)) != 2:
return
# Root Mean Square Layer Normalization
simplified = node.op_type == "SimplifiedLayerNormalization"
if self.shape_infer_helper is not None:
if not self.shape_infer_helper.compare_shape(add.input[0], add.input[1]):
logger.debug(
@ -89,12 +97,16 @@ class FusionSkipLayerNormalization(Fusion):
):
self.nodes_to_remove.extend([add, node])
inputs = [add.input[0], add.input[1], node.input[1], node.input[2]]
inputs = (
[add.input[0], add.input[1], node.input[1], node.input[2]]
if not simplified
else [add.input[0], add.input[1], node.input[1]]
)
normalize_node = helper.make_node(
"SkipLayerNormalization",
self.fused_op_type,
inputs=inputs,
outputs=outputs,
name=self.model.create_node_name("SkipLayerNormalization", name_prefix="SkipLayerNorm"),
name=self.model.create_node_name(self.fused_op_type, name_prefix="SkipLayerNorm"),
)
normalize_node.domain = "com.microsoft"

View file

@ -203,6 +203,7 @@ def export_onnx_models(
config.hidden_size,
use_external_data_format,
auto_mixed_precision=not disable_auto_mixed_precision,
use_gpu=use_gpu,
)
else:
logger.info(f"Skip optimizing: existed ONNX model {onnx_path}")

View file

@ -228,15 +228,25 @@ class T5Helper:
hidden_size: int,
use_external_data_format: bool = False,
auto_mixed_precision: bool = True,
use_gpu: bool = False,
):
"""Optimize ONNX model with an option to convert it to use mixed precision."""
from fusion_options import FusionOptions
optimization_options = None
if not use_gpu:
# Currently there is no SkipSimplifiedLayerNorm cpu kernel
optimization_options = FusionOptions("t5")
optimization_options.enable_skip_layer_norm = False
m = optimize_model(
onnx_model_path,
model_type="bert", # TODO: support optimization for t5
model_type="t5",
num_heads=num_attention_heads,
hidden_size=hidden_size,
opt_level=0,
optimization_options=None,
opt_level=2 if not is_float16 and not use_external_data_format else 0,
optimization_options=optimization_options,
use_gpu=False,
)
if is_float16:

View file

@ -0,0 +1,92 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging
from typing import Union
from fusion_attention import AttentionMask, FusionAttention
from fusion_base import Fusion
from fusion_skiplayernorm import FusionSkipLayerNormalization
from onnx import NodeProto
from onnx_model import OnnxModel
from onnx_model_bert import BertOnnxModel
logger = logging.getLogger(__name__)
# TODO: Support decoder self/cross attention fusion and encoder self attention fusion
class FusionT5Attention(FusionAttention):
"""
Fuse T5 Attention subgraph into one Attention node.
"""
def __init__(
self,
model: OnnxModel,
hidden_size: int,
num_heads: int,
attention_mask: AttentionMask,
):
super().__init__(model, hidden_size, num_heads, attention_mask)
def create_attention_node(
self,
mask_index: str,
matmul: NodeProto,
add: NodeProto,
num_heads: int,
hidden_size: int,
input: str,
output: str,
add_qk_str: str,
) -> Union[NodeProto, None]:
# Not implemented yet
return None
def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
# Not implemented yet
return
# It's much easier to export it with the custom op. TODO: revisit later
class FusionRelativePositionBiasBlock(Fusion):
def __init__(self, model: OnnxModel, max_distance: int, is_bidirectional: bool):
super().__init__(model, "RelativePositionBias", "Add")
self.max_distance = max_distance
self.is_bidirectional = is_bidirectional
def fuse(self, node, input_name_to_nodes, output_name_to_node):
# Not implemented yet
return
class FusionSkipSimplifiedLayerNormalization(FusionSkipLayerNormalization):
def __init__(self, model: OnnxModel):
super().__init__(model, "SkipSimplifiedLayerNormalization", "SimplifiedLayerNormalization")
self.shape_infer_helper = self.model.infer_runtime_shape(
{"batch_size": 2, "seq_len": 1, "encode_sequence_length": 8, "past_decode_sequence_length": 4}, update=True
)
def fuse(self, node, input_name_to_nodes, output_name_to_node):
super().fuse(node, input_name_to_nodes, output_name_to_node)
class T5OnnxModel(BertOnnxModel):
def __init__(self, model, num_heads, hidden_size):
super().__init__(model, num_heads, hidden_size)
self.attention_mask = AttentionMask(self)
self.attention_fusion = FusionT5Attention(self, self.hidden_size, self.num_heads, self.attention_mask)
self.skip_layer_norm_fusion = FusionSkipSimplifiedLayerNormalization(self)
# TODO: hardcode for now. double check later
self.rpb_fusion = FusionRelativePositionBiasBlock(self, 32, True)
def fuse_attention(self):
self.attention_fusion.apply()
def fuse_skip_layer_norm(self):
self.skip_layer_norm_fusion.apply()
def postprocess(self):
self.rpb_fusion.apply()
self.clean_graph()
self.prune_graph()

View file

@ -30,6 +30,7 @@ from onnx_model_bert import BertOnnxModel
from onnx_model_bert_keras import BertOnnxModelKeras
from onnx_model_bert_tf import BertOnnxModelTF
from onnx_model_gpt2 import Gpt2OnnxModel
from onnx_model_t5 import T5OnnxModel
from onnx_model_tnlr import TnlrOnnxModel
from onnx_model_unet import UnetOnnxModel
@ -49,6 +50,7 @@ MODEL_TYPES = {
), # might add a class for GPT2OnnxModel for TF later.
"tnlr": (TnlrOnnxModel, "pytorch", 1),
"unet": (UnetOnnxModel, "pytorch", 1),
"t5": (T5OnnxModel, "pytorch", 2),
}
@ -248,7 +250,7 @@ def optimize_model(
else [
"MatMulScaleFusion",
"MatMulAddFusion",
"SimplifiedLayerNormFusion",
"MatmulTransposeFusion",
"GemmActivationFusion",
"BiasSoftmaxFusion",
]