mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-03 23:49:44 +00:00
Add SLN support for t5 model with beam search (#14429)
### Description <!-- Describe your changes. --> ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> --------- Co-authored-by: Ubuntu <wy@v100-2.0cdb2e52twzevn1i4fi45bylyg.jx.internal.cloudapp.net>
This commit is contained in:
parent
638f21b969
commit
999e5bf45e
7 changed files with 130 additions and 10 deletions
|
|
@ -198,6 +198,7 @@ class SymbolicShapeInference:
|
|||
"LayerNormalization": self._infer_LayerNormalization,
|
||||
"LongformerAttention": self._infer_LongformerAttention,
|
||||
"PythonOp": self._infer_PythonOp,
|
||||
"SimplifiedLayerNormalization": self._infer_LayerNormalization,
|
||||
"SkipLayerNormalization": self._infer_SkipLayerNormalization,
|
||||
"SkipSimplifiedLayerNormalization": self._infer_SkipLayerNormalization,
|
||||
"GroupNorm": self._infer_GroupNorm,
|
||||
|
|
@ -433,7 +434,9 @@ class SymbolicShapeInference:
|
|||
"GemmFastGelu",
|
||||
"LayerNormalization",
|
||||
"LongformerAttention",
|
||||
"SimplifiedLayerNormalization",
|
||||
"SkipLayerNormalization",
|
||||
"SkipSimplifiedLayerNormalization",
|
||||
"PythonOp",
|
||||
"MultiHeadAttention",
|
||||
"GroupNorm",
|
||||
|
|
|
|||
|
|
@ -483,7 +483,7 @@ def t5_to_onnx(args: argparse.Namespace):
|
|||
Path(args.output).parent,
|
||||
use_gpu=args.use_gpu,
|
||||
use_external_data_format=args.use_external_data_format,
|
||||
optimize_onnx=False,
|
||||
optimize_onnx=(args.precision != Precision.FLOAT16),
|
||||
precision=args.precision,
|
||||
verbose=False,
|
||||
use_decoder_start_token=False,
|
||||
|
|
|
|||
|
|
@ -19,8 +19,13 @@ class FusionSkipLayerNormalization(Fusion):
|
|||
Note: This fusion does not check the input shape of Add and LayerNormalization.
|
||||
"""
|
||||
|
||||
def __init__(self, model: OnnxModel):
|
||||
super().__init__(model, "SkipLayerNormalization", "LayerNormalization")
|
||||
def __init__(
|
||||
self,
|
||||
model: OnnxModel,
|
||||
fused_op_type: str = "SkipLayerNormalization",
|
||||
search_op_types: str = "LayerNormalization",
|
||||
):
|
||||
super().__init__(model, fused_op_type, search_op_types)
|
||||
# Update shape inference is needed since other fusions might add new edge which does not have shape info yet.
|
||||
self.shape_infer_helper = self.model.infer_runtime_shape({"batch_size": 4, "seq_len": 7}, update=True)
|
||||
|
||||
|
|
@ -44,6 +49,9 @@ class FusionSkipLayerNormalization(Fusion):
|
|||
if len(self.model.get_parents(add)) != 2:
|
||||
return
|
||||
|
||||
# Root Mean Square Layer Normalization
|
||||
simplified = node.op_type == "SimplifiedLayerNormalization"
|
||||
|
||||
if self.shape_infer_helper is not None:
|
||||
if not self.shape_infer_helper.compare_shape(add.input[0], add.input[1]):
|
||||
logger.debug(
|
||||
|
|
@ -89,12 +97,16 @@ class FusionSkipLayerNormalization(Fusion):
|
|||
):
|
||||
self.nodes_to_remove.extend([add, node])
|
||||
|
||||
inputs = [add.input[0], add.input[1], node.input[1], node.input[2]]
|
||||
inputs = (
|
||||
[add.input[0], add.input[1], node.input[1], node.input[2]]
|
||||
if not simplified
|
||||
else [add.input[0], add.input[1], node.input[1]]
|
||||
)
|
||||
normalize_node = helper.make_node(
|
||||
"SkipLayerNormalization",
|
||||
self.fused_op_type,
|
||||
inputs=inputs,
|
||||
outputs=outputs,
|
||||
name=self.model.create_node_name("SkipLayerNormalization", name_prefix="SkipLayerNorm"),
|
||||
name=self.model.create_node_name(self.fused_op_type, name_prefix="SkipLayerNorm"),
|
||||
)
|
||||
normalize_node.domain = "com.microsoft"
|
||||
|
||||
|
|
|
|||
|
|
@ -203,6 +203,7 @@ def export_onnx_models(
|
|||
config.hidden_size,
|
||||
use_external_data_format,
|
||||
auto_mixed_precision=not disable_auto_mixed_precision,
|
||||
use_gpu=use_gpu,
|
||||
)
|
||||
else:
|
||||
logger.info(f"Skip optimizing: existed ONNX model {onnx_path}")
|
||||
|
|
|
|||
|
|
@ -228,15 +228,25 @@ class T5Helper:
|
|||
hidden_size: int,
|
||||
use_external_data_format: bool = False,
|
||||
auto_mixed_precision: bool = True,
|
||||
use_gpu: bool = False,
|
||||
):
|
||||
"""Optimize ONNX model with an option to convert it to use mixed precision."""
|
||||
|
||||
from fusion_options import FusionOptions
|
||||
|
||||
optimization_options = None
|
||||
if not use_gpu:
|
||||
# Currently there is no SkipSimplifiedLayerNorm cpu kernel
|
||||
optimization_options = FusionOptions("t5")
|
||||
optimization_options.enable_skip_layer_norm = False
|
||||
|
||||
m = optimize_model(
|
||||
onnx_model_path,
|
||||
model_type="bert", # TODO: support optimization for t5
|
||||
model_type="t5",
|
||||
num_heads=num_attention_heads,
|
||||
hidden_size=hidden_size,
|
||||
opt_level=0,
|
||||
optimization_options=None,
|
||||
opt_level=2 if not is_float16 and not use_external_data_format else 0,
|
||||
optimization_options=optimization_options,
|
||||
use_gpu=False,
|
||||
)
|
||||
if is_float16:
|
||||
|
|
|
|||
92
onnxruntime/python/tools/transformers/onnx_model_t5.py
Normal file
92
onnxruntime/python/tools/transformers/onnx_model_t5.py
Normal file
|
|
@ -0,0 +1,92 @@
|
|||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
import logging
|
||||
from typing import Union
|
||||
|
||||
from fusion_attention import AttentionMask, FusionAttention
|
||||
from fusion_base import Fusion
|
||||
from fusion_skiplayernorm import FusionSkipLayerNormalization
|
||||
from onnx import NodeProto
|
||||
from onnx_model import OnnxModel
|
||||
from onnx_model_bert import BertOnnxModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# TODO: Support decoder self/cross attention fusion and encoder self attention fusion
|
||||
class FusionT5Attention(FusionAttention):
    """Fuse a T5 attention subgraph into one Attention node.

    This is currently a placeholder: it satisfies the FusionAttention
    interface, but both ``create_attention_node`` and ``fuse`` return
    immediately without inspecting or modifying the graph.
    """

    def __init__(
        self,
        model: OnnxModel,
        hidden_size: int,
        num_heads: int,
        attention_mask: AttentionMask,
    ):
        """Forward every argument unchanged to the FusionAttention constructor."""
        super().__init__(model, hidden_size, num_heads, attention_mask)

    def create_attention_node(
        self,
        mask_index: str,
        matmul: NodeProto,
        add: NodeProto,
        num_heads: int,
        hidden_size: int,
        input: str,
        output: str,
        add_qk_str: str,
    ) -> Union[NodeProto, None]:
        """Return the fused attention node, or None when no node is created.

        Not implemented yet: all arguments are ignored and the result is
        always None.
        """
        return None

    def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
        """Try to fuse the attention pattern found around ``normalize_node``.

        Not implemented yet: nothing is matched, so nothing is added to or
        removed from the graph.
        """
        return None
|
||||
|
||||
|
||||
# It's much easier to export it with the custom op. TODO: revisit later
|
||||
class FusionRelativePositionBiasBlock(Fusion):
    """Placeholder for fusing a relative position bias subgraph.

    Registers with the Fusion base class using "RelativePositionBias" as the
    fused op type and "Add" as the op type to search from. The matching logic
    itself is not written yet.
    """

    def __init__(self, model: OnnxModel, max_distance: int, is_bidirectional: bool):
        """Store the bias settings for later use by ``fuse``.

        Args:
            model: Model wrapper the fusion operates on.
            max_distance: Kept on the instance as ``self.max_distance``.
            is_bidirectional: Kept on the instance as ``self.is_bidirectional``.
        """
        super().__init__(model, "RelativePositionBias", "Add")
        self.max_distance, self.is_bidirectional = max_distance, is_bidirectional

    def fuse(self, node, input_name_to_nodes, output_name_to_node):
        """Try to fuse the pattern found around ``node``.

        Not implemented yet: returns immediately without changing the graph.
        """
        return None
|
||||
|
||||
|
||||
class FusionSkipSimplifiedLayerNormalization(FusionSkipLayerNormalization):
    """Skip-layer-norm fusion configured for SimplifiedLayerNormalization nodes.

    Reuses the FusionSkipLayerNormalization logic, but searches for
    "SimplifiedLayerNormalization" nodes and emits
    "SkipSimplifiedLayerNormalization" as the fused op type.
    """

    def __init__(self, model: OnnxModel):
        """Set up the fusion and refresh the runtime shape helper for ``model``."""
        super().__init__(model, "SkipSimplifiedLayerNormalization", "SimplifiedLayerNormalization")

        # Concrete values substituted for the named dynamic axes during shape
        # inference. NOTE(review): presumably these axis names match the ones
        # used by the T5 export scripts - confirm against the exporter.
        dynamic_axis_values = {
            "batch_size": 2,
            "seq_len": 1,
            "encode_sequence_length": 8,
            "past_decode_sequence_length": 4,
        }
        self.shape_infer_helper = self.model.infer_runtime_shape(dynamic_axis_values, update=True)

    def fuse(self, node, input_name_to_nodes, output_name_to_node):
        """Delegate unchanged to FusionSkipLayerNormalization.fuse."""
        super().fuse(node, input_name_to_nodes, output_name_to_node)
|
||||
|
||||
|
||||
class T5OnnxModel(BertOnnxModel):
    """BertOnnxModel variant that plugs in the T5-specific fusion passes."""

    def __init__(self, model, num_heads, hidden_size):
        """Create the model wrapper and the T5 fusion helpers.

        Args:
            model: Forwarded to the BertOnnxModel constructor.
            num_heads: Forwarded to the BertOnnxModel constructor.
            hidden_size: Forwarded to the BertOnnxModel constructor.
        """
        super().__init__(model, num_heads, hidden_size)

        mask = AttentionMask(self)
        self.attention_mask = mask
        self.attention_fusion = FusionT5Attention(self, self.hidden_size, self.num_heads, mask)
        self.skip_layer_norm_fusion = FusionSkipSimplifiedLayerNormalization(self)

        # TODO: these relative position bias settings are hard-coded for now;
        # double check them later.
        self.rpb_fusion = FusionRelativePositionBiasBlock(self, max_distance=32, is_bidirectional=True)

    def fuse_attention(self):
        """Run the T5 attention fusion pass."""
        self.attention_fusion.apply()

    def fuse_skip_layer_norm(self):
        """Run the skip + simplified layer normalization fusion pass."""
        self.skip_layer_norm_fusion.apply()

    def postprocess(self):
        """Apply the relative position bias fusion, then clean and prune the graph."""
        self.rpb_fusion.apply()
        self.clean_graph()
        self.prune_graph()
|
||||
|
|
@ -30,6 +30,7 @@ from onnx_model_bert import BertOnnxModel
|
|||
from onnx_model_bert_keras import BertOnnxModelKeras
|
||||
from onnx_model_bert_tf import BertOnnxModelTF
|
||||
from onnx_model_gpt2 import Gpt2OnnxModel
|
||||
from onnx_model_t5 import T5OnnxModel
|
||||
from onnx_model_tnlr import TnlrOnnxModel
|
||||
from onnx_model_unet import UnetOnnxModel
|
||||
|
||||
|
|
@ -49,6 +50,7 @@ MODEL_TYPES = {
|
|||
), # might add a class for GPT2OnnxModel for TF later.
|
||||
"tnlr": (TnlrOnnxModel, "pytorch", 1),
|
||||
"unet": (UnetOnnxModel, "pytorch", 1),
|
||||
"t5": (T5OnnxModel, "pytorch", 2),
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -248,7 +250,7 @@ def optimize_model(
|
|||
else [
|
||||
"MatMulScaleFusion",
|
||||
"MatMulAddFusion",
|
||||
"SimplifiedLayerNormFusion",
|
||||
"MatmulTransposeFusion",
|
||||
"GemmActivationFusion",
|
||||
"BiasSoftmaxFusion",
|
||||
]
|
||||
|
|
|
|||
Loading…
Reference in a new issue