From 56b37e55e52b72148fab222cb3bdbe5be29c2dd1 Mon Sep 17 00:00:00 2001 From: Ye Wang <52801275+wangyems@users.noreply.github.com> Date: Tue, 24 Aug 2021 18:13:46 -0700 Subject: [PATCH] Add new transformers model type: Bart (#8698) * update * bart-base encoder attention fusion * update * update * update * update * update * yapf * review comments --- .../tools/transformers/fusion_reshape.py | 37 +-- .../tools/transformers/huggingface_models.py | 8 +- .../tools/transformers/onnx_model_bart.py | 256 ++++++++++++++++++ .../python/tools/transformers/optimizer.py | 2 + 4 files changed, 282 insertions(+), 21 deletions(-) create mode 100644 onnxruntime/python/tools/transformers/onnx_model_bart.py diff --git a/onnxruntime/python/tools/transformers/fusion_reshape.py b/onnxruntime/python/tools/transformers/fusion_reshape.py index a325014767..3d336d5ce2 100644 --- a/onnxruntime/python/tools/transformers/fusion_reshape.py +++ b/onnxruntime/python/tools/transformers/fusion_reshape.py @@ -3,11 +3,11 @@ # Licensed under the MIT License. #-------------------------------------------------------------------------- +from fusion_base import Fusion from logging import getLogger +import numpy as np from onnx import helper, numpy_helper, TensorProto from onnx_model import OnnxModel -from fusion_base import Fusion -import numpy as np logger = getLogger(__name__) @@ -16,6 +16,23 @@ class FusionReshape(Fusion): def __init__(self, model: OnnxModel): super().__init__(model, "Reshape", "Reshape") + def replace_reshape_node(self, shape, reshape_node, concat_node): + shape_value = np.asarray(shape, dtype=np.int64) + constant_shape_name = self.model.create_node_name('Constant', 'constant_shape') + new_node = helper.make_node('Constant', + inputs=[], + outputs=[constant_shape_name], + value=helper.make_tensor(name='const_tensor', + data_type=TensorProto.INT64, + dims=shape_value.shape, + vals=bytes(shape_value), + raw=True)) + reshape_node.input[1] = constant_shape_name + reshape_node.name = self.model.create_node_name('Reshape', 'Reshape_Fuse') + self.nodes_to_remove.extend([concat_node]) + self.nodes_to_add.append(new_node) + self.node_name_to_graph_name[new_node.name] = self.this_graph_name + def fuse(self, reshape_node, input_name_to_nodes, output_name_to_node): if reshape_node.input[1] not in output_name_to_node: return @@ -117,23 +134,9 @@ class FusionReshape(Fusion): if not same_shape_input: return - shape_value = np.asarray(shape, dtype=np.int64) + self.replace_reshape_node(shape, reshape_node, concat_node) - constant_shape_name = self.model.create_node_name('Constant', 'constant_shape') - new_node = helper.make_node('Constant', - inputs=[], - outputs=[constant_shape_name], - value=helper.make_tensor(name='const_tensor', - data_type=TensorProto.INT64, - dims=shape_value.shape, - vals=bytes(shape_value), - raw=True)) - reshape_node.input[1] = constant_shape_name - reshape_node.name = self.model.create_node_name('Reshape', 'Reshape_Fuse') - self.nodes_to_remove.extend([concat_node]) self.nodes_to_remove.extend(path0) self.nodes_to_remove.extend(path1) self.nodes_to_remove.extend(path2) self.nodes_to_remove.extend(path3) - self.nodes_to_add.append(new_node) - self.node_name_to_graph_name[new_node.name] = self.this_graph_name diff --git a/onnxruntime/python/tools/transformers/huggingface_models.py b/onnxruntime/python/tools/transformers/huggingface_models.py index a2cd823e9f..051480ebb0 100644 --- a/onnxruntime/python/tools/transformers/huggingface_models.py +++ b/onnxruntime/python/tools/transformers/huggingface_models.py @@ -87,10 +87,10 @@ MODELS = { "flaubert/flaubert_base_cased": (["input_ids"], 11, False, "bert"), #"flaubert/flaubert_large_cased": (["input_ids"], 11, False, "bert"), # Bart - "facebook/bart-large": (["input_ids"], 11, False, "bert"), - "facebook/bart-base": (["input_ids"], 11, False, "bert"), - "facebook/bart-large-mnli": (["input_ids"], 11, False, "bert"), - "facebook/bart-large-cnn": (["input_ids"], 11, False, "bert"), + "facebook/bart-large": (["input_ids", "attention_mask"], 11, False, "bart"), + "facebook/bart-base": (["input_ids", "attention_mask"], 11, False, "bart"), + "facebook/bart-large-mnli": (["input_ids", "attention_mask"], 11, False, "bart"), + "facebook/bart-large-cnn": (["input_ids", "attention_mask"], 11, False, "bart"), # DialoGPT "microsoft/DialoGPT-small": (["input_ids"], 11, False, "gpt2"), diff --git a/onnxruntime/python/tools/transformers/onnx_model_bart.py b/onnxruntime/python/tools/transformers/onnx_model_bart.py new file mode 100644 index 0000000000..7ba3104c19 --- /dev/null +++ b/onnxruntime/python/tools/transformers/onnx_model_bart.py @@ -0,0 +1,256 @@ +#------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +#-------------------------------------------------------------------------- +import logging +from fusion_attention import FusionAttention, AttentionMask +from fusion_reshape import FusionReshape +from onnx import numpy_helper +from onnx_model import OnnxModel +from onnx_model_bert import BertOnnxModel + +logger = logging.getLogger(__name__) + + +class FusionBartEncoderAttention(FusionAttention): + """ + Fuse Bart Attention subgraph into one Attention node. + """ + def __init__(self, model: OnnxModel, hidden_size: int, num_heads: int, attention_mask: AttentionMask): + super().__init__(model, hidden_size, num_heads, attention_mask) + + def check_runtime_shape_path(self, reshape_qkv_2, reshape_qkv_1, reshape_q_2, reshape_k_2, reshape_v_2, root_input): + concat_qkv_2_path = self.model.match_parent_path(reshape_qkv_2, ['Concat'], [1]) + if concat_qkv_2_path is None: + return False + concat_qkv_2 = concat_qkv_2_path[0] + + reshape_qkv_2_path_1 = self.model.match_parent_path(concat_qkv_2, ['Unsqueeze', 'Gather', 'Shape'], [0, 0, 0]) + reshape_qkv_2_path_2 = self.model.match_parent_path(concat_qkv_2, ['Unsqueeze', 'Gather', 'Shape'], [1, 0, 0]) + reshape_qkv_2_path_3 = self.model.match_parent_path(concat_qkv_2, ['Unsqueeze', 'Gather', 'Shape'], [2, 0, 0]) + if reshape_qkv_2_path_1 is None or reshape_qkv_2_path_2 is None or reshape_qkv_2_path_3 is None: + return False + + _, gather_1, shape_1 = reshape_qkv_2_path_1 + _, gather_2, shape_2 = reshape_qkv_2_path_2 + _, _, shape_3 = reshape_qkv_2_path_3 + + if shape_1.input[0] != root_input or shape_2.input[0] != root_input or shape_3.input[0] != root_input: + return False + + reshape_qkv_1_path_1 = self.model.match_parent_path(reshape_qkv_1, ['Concat', 'Unsqueeze', 'Gather'], [1, 0, 0]) + reshape_qkv_1_path_2 = self.model.match_parent_path(reshape_qkv_1, ['Concat', 'Unsqueeze', 'Gather'], [1, 2, 0]) + if reshape_qkv_1_path_1 is None or reshape_qkv_1_path_2 is None: + return False + if reshape_qkv_1_path_1[-1].name != gather_1.name or reshape_qkv_1_path_2[-1].name != gather_2.name: + return False + + reshape_q_2_path = self.model.match_parent_path(reshape_q_2, ['Concat', 'Unsqueeze', 'Mul'], [1, 0, 0]) + reshape_k_2_path = self.model.match_parent_path(reshape_k_2, ['Concat', 'Unsqueeze', 'Mul'], [1, 0, 0]) + reshape_v_2_path = self.model.match_parent_path(reshape_v_2, ['Concat', 'Unsqueeze', 'Mul'], [1, 0, 0]) + if reshape_q_2_path is None or reshape_k_2_path is None or reshape_v_2_path is None: + return False + + mul_q = reshape_q_2_path[-1] + mul_k = reshape_k_2_path[-1] + mul_v = reshape_v_2_path[-1] + + gather_1_out = gather_1.output[0] + if mul_q.input[0] != gather_1_out or mul_k.input[0] != gather_1_out or mul_v.input[0] != gather_1_out: + return False + + return True + + def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): + # SkipLayerNormalization has two inputs, and one of them is the root input for attention. + qkv_nodes = self.model.match_parent_path(normalize_node, + ['Add', 'MatMul', 'Reshape', 'Transpose', 'Reshape', 'MatMul'], + [None, 1, 0, 0, 0, 0]) + if qkv_nodes is not None: + (add_out, matmul_out, reshape_qkv_2, transpose_qkv, reshape_qkv_1, matmul_qkv) = qkv_nodes + else: + return + + other_inputs = [] + for i, input in enumerate(normalize_node.input): + if input not in output_name_to_node: + continue + if input == qkv_nodes[0].output[0]: + continue + other_inputs.append(input) + if len(other_inputs) != 1: + return + + root_input = other_inputs[0] + children = input_name_to_nodes[root_input] + children_types = [child.op_type for child in children] + if children_types.count('MatMul') != 3: + return + + v_nodes = self.model.match_parent_path(matmul_qkv, ['Reshape', 'Transpose', 'Reshape', 'Add', 'MatMul'], + [1, 0, 0, 0, None]) + if v_nodes is None: + logger.debug("fuse_attention: failed to match v path") + return + (reshape_v_2, transpose_v, reshape_v_1, add_v, matmul_v) = v_nodes + + qk_nodes = self.model.match_parent_path(matmul_qkv, ['Softmax', 'MatMul'], [0, 0]) + if qk_nodes is not None: + _, matmul_qk = qk_nodes + else: + return + + q_nodes = self.model.match_parent_path(matmul_qk, ['Reshape', 'Transpose', 'Reshape', 'Mul', 'Add', 'MatMul'], + [0, 0, 0, 0, 0, 1]) + if q_nodes is not None: + reshape_q_2, _, reshape_q_1, _, add_q, matmul_q = q_nodes + else: + return + + k_nodes = self.model.match_parent_path(matmul_qk, + ['Transpose', 'Reshape', 'Transpose', 'Reshape', 'Add', 'MatMul'], + [1, 0, 0, 0, 0, 1]) + if k_nodes is not None: + _, reshape_k_2, _, reshape_k_1, add_k, matmul_k = k_nodes + else: + return + + if not self.check_runtime_shape_path(reshape_qkv_2, reshape_qkv_1, reshape_q_2, reshape_k_2, reshape_v_2, + root_input): + return + + if matmul_v.input[0] == root_input and matmul_q.input[0] == root_input and matmul_v.input[0] == root_input: + + mask_nodes = [] + mask_index = None + attention_last_node = reshape_qkv_2 + + num_heads, hidden_size = self.get_num_heads_and_hidden_size(reshape_q_1) + + if num_heads <= 0 or hidden_size <= 0 or (hidden_size % num_heads) != 0: + logger.debug("fuse_attention: failed to detect num_heads or hidden_size") + return + + new_node = self.create_attention_node(mask_index, matmul_q, matmul_k, matmul_v, add_q, add_k, add_v, + num_heads, hidden_size, root_input, attention_last_node.output[0], + None) + if new_node is None: + return + + self.nodes_to_add.append(new_node) + self.node_name_to_graph_name[new_node.name] = self.this_graph_name + + self.nodes_to_remove.extend([attention_last_node, transpose_qkv, matmul_qkv]) + self.nodes_to_remove.extend(qk_nodes) + self.nodes_to_remove.extend(q_nodes) + self.nodes_to_remove.extend(k_nodes) + self.nodes_to_remove.extend(v_nodes) + + # Use prune graph to remove mask nodes since they are shared by all attention nodes. + self.nodes_to_remove.extend(mask_nodes) + self.prune_graph = True + + +class FusionBartReshape(FusionReshape): + def __init__(self, model: OnnxModel): + super().__init__(model) + + def fuse(self, reshape_node, input_name_to_nodes, output_name_to_node): + if reshape_node.input[1] not in output_name_to_node: + return + + concat_node = output_name_to_node[reshape_node.input[1]] + if concat_node.op_type != 'Concat' or len(concat_node.input) != 4: + return + + path0 = self.model.match_parent_path(concat_node, ['Unsqueeze', 'Gather', 'Shape'], [0, 0, 0], + output_name_to_node) + if path0 is None: + return + + (_, gather_0, shape_0) = path0 + + shape = [] + gather_value = self.model.get_constant_value(gather_0.input[1]) + if gather_value == 0: + shape.append(0) + + path1 = self.model.match_parent_path(concat_node, ['Unsqueeze', 'Gather', 'Shape'], [1, 0, 0], + output_name_to_node) + if path1 is None: + input_1_proto = self.model.get_initializer(concat_node.input[1]) + input_2_proto = self.model.get_initializer(concat_node.input[2]) + input_3_proto = self.model.get_initializer(concat_node.input[3]) + if input_1_proto is None or input_2_proto is None or input_3_proto is None: + return + + input_1 = numpy_helper.to_array(input_1_proto) + input_2 = numpy_helper.to_array(input_2_proto) + input_3 = numpy_helper.to_array(input_3_proto) + if len(input_1) != 1 or len(input_2) != 1 or len(input_3) != 1: + return + + if not (input_1[0] == -1 and input_2[0] > 0 and input_3[0] > 0): + return + + shape.extend(input_1) + shape.extend(input_2) + shape.extend(input_3) + gemm_path = self.model.match_parent_path(reshape_node, ['Add', 'MatMul'], [0, 1], output_name_to_node) + if gemm_path is None: + return + + top_matmul = gemm_path[-1] + root_input = top_matmul.input[0] + if shape_0.input[0] != root_input: + return + + self.replace_reshape_node(shape, reshape_node, concat_node) + else: + (_, gather_1, shape_1) = path1 + + gather_value = self.model.get_constant_value(gather_1.input[1]) + if gather_value == 1: + shape.append(0) + + input_2_proto = self.model.get_initializer(concat_node.input[2]) + input_3_proto = self.model.get_initializer(concat_node.input[3]) + if input_2_proto is None or input_3_proto is None: + return + + input_2 = numpy_helper.to_array(input_2_proto) + input_3 = numpy_helper.to_array(input_3_proto) + if len(input_2) != 1 or len(input_3) != 1: + return + + if not (input_2[0] > 0 and input_3[0] > 0): + return + + shape.extend(input_2) + shape.extend(input_3) + gemm_path = self.model.match_parent_path(reshape_node, ['Mul', 'Add', 'MatMul'], [0, 0, 1], + output_name_to_node) + if gemm_path is None: + return + + top_matmul = gemm_path[-1] + root_input = top_matmul.input[0] + if shape_0.input[0] != root_input or shape_1.input[0] != root_input: + return + + self.replace_reshape_node(shape, reshape_node, concat_node) + + +class BartOnnxModel(BertOnnxModel): + def __init__(self, model, num_heads, hidden_size): + super().__init__(model, num_heads, hidden_size) + self.attention_mask = AttentionMask(self) + self.attention_fusion = FusionBartEncoderAttention(self, self.hidden_size, self.num_heads, self.attention_mask) + self.bart_reshape_fusion_preprocess = FusionBartReshape(self) + + def fuse_attention(self): + self.attention_fusion.apply() + + def preprocess(self): + self.adjust_reshape_and_expand() + self.bart_reshape_fusion_preprocess.apply() diff --git a/onnxruntime/python/tools/transformers/optimizer.py b/onnxruntime/python/tools/transformers/optimizer.py index 541311485a..9399c60793 100644 --- a/onnxruntime/python/tools/transformers/optimizer.py +++ b/onnxruntime/python/tools/transformers/optimizer.py @@ -27,6 +27,7 @@ import numpy as np from typing import Dict from collections import deque from onnx import ModelProto, TensorProto, numpy_helper, load_model +from onnx_model_bart import BartOnnxModel from onnx_model_bert import BertOnnxModel, BertOptimizationOptions from onnx_model_bert_tf import BertOnnxModelTF from onnx_model_bert_keras import BertOnnxModelKeras @@ -37,6 +38,7 @@ logger = logging.getLogger(__name__) # Map model type to tuple: optimizer class, export tools (pytorch, tf2onnx, keras2onnx), and default opt_level MODEL_TYPES = { + "bart": (BartOnnxModel, "pytorch", 1), "bert": (BertOnnxModel, "pytorch", 1), "bert_tf": (BertOnnxModelTF, "tf2onnx", 0), "bert_keras": (BertOnnxModelKeras, "keras2onnx", 0),