diff --git a/onnxruntime/python/tools/transformers/fusion_attention.py b/onnxruntime/python/tools/transformers/fusion_attention.py index 041d3dd0f6..6bc0f8d4b2 100644 --- a/onnxruntime/python/tools/transformers/fusion_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_attention.py @@ -131,15 +131,18 @@ class FusionAttention(Fusion): weight = helper.make_tensor(name=attention_node_name + '_qkv_weight', data_type=TensorProto.FLOAT, dims=[self.hidden_size, 3 * self.hidden_size], - vals=bytes(qkv_weight.flatten()), - raw=True) + vals=qkv_weight.flatten().tolist()) + # Sometimes weights and bias are stored in fp16 + if q_weight.data_type == 10: + weight.CopyFrom(numpy_helper.from_array(numpy_helper.to_array(weight).astype(np.float16), weight.name)) self.model.add_initializer(weight) bias = helper.make_tensor(name=attention_node_name + '_qkv_bias', data_type=TensorProto.FLOAT, dims=[3 * self.hidden_size], - vals=bytes(qkv_bias.flatten()), - raw=True) + vals=qkv_bias.flatten().tolist()) + if q_bias.data_type == 10: + bias.CopyFrom(numpy_helper.from_array(numpy_helper.to_array(bias).astype(np.float16), bias.name)) self.model.add_initializer(bias) attnetion_inputs = [input, attention_node_name + '_qkv_weight', attention_node_name + '_qkv_bias'] diff --git a/onnxruntime/python/tools/transformers/fusion_layernorm.py b/onnxruntime/python/tools/transformers/fusion_layernorm.py index ade39da01c..ee4f44dda2 100644 --- a/onnxruntime/python/tools/transformers/fusion_layernorm.py +++ b/onnxruntime/python/tools/transformers/fusion_layernorm.py @@ -121,24 +121,28 @@ class FusionLayerNormalizationTF(Fusion): def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict): """ - Layer Norm from Keras in Tensorflow: - +----------------------+ - | | - | v (B) (B) (A) - Add --> ReduceMean --> Sub --> Mul --> ReduceMean --> Add --> Sqrt --> Reciprocol --> Mul --> Mul --> Sub --> Add - | | | ^ ^ - | | | | | - | +----------------------------------------------------------------------------|-------+ | - | v | - +-------------------------------------------------------------------------------------> Mul--------------------+ + Layer Norm from Tensorflow model(using keras2onnx or tf2onnx): + +------------------------------------+ + | | + | | + (Cast_1) | + | | + | v (B) (B) (A) + Add --> (Cast_1) --> ReduceMean --> Sub --> Mul --> ReduceMean --> (Cast_3) --> Add --> Sqrt --> Reciprocol --> Mul --> Mul --> Sub --> Add + | | | ^ ^ + | | | | | + | +--------------------------------------------------(Cast_2)-------------------------------|-------+ | + | v | + +---------------------------------------------------------------------------------------------------------------> Mul--------------------+ """ return_indice = [] - parent_nodes = self.model.match_parent_path( + _, parent_nodes, return_indice = self.model.match_parent_paths( node, - ['Sub', 'Mul', 'Mul', 'Reciprocal', 'Sqrt', 'Add', 'ReduceMean', 'Mul', 'Sub', 'ReduceMean'], - [ 1, 1, None, 0, 0, 0, None, 0, 0, None], - output_name_to_node, - return_indice=return_indice) # yapf: disable + [(['Sub', 'Mul', 'Mul', 'Reciprocal', 'Sqrt', 'Add', 'ReduceMean', 'Mul', 'Sub', 'ReduceMean'], + [ 1, 1, None, 0, 0, 0, None, 0, 0, None]), + (['Sub', 'Mul', 'Mul', 'Reciprocal', 'Sqrt', 'Add', 'Cast', 'ReduceMean', 'Mul', 'Sub', 'ReduceMean'], + [ 1, 1, None, 0, 0, 0, 0, None, 0, 0, None])], + output_name_to_node) # yapf: disable if parent_nodes is None: return @@ -148,24 +152,35 @@ class FusionLayerNormalizationTF(Fusion): logger.debug("return indice is exepected in [0, 1], but got {return_indice}") return - sub_node_0, mul_node_0, mul_node_1, reciprocol_node, sqrt_node, add_node_0, reduce_mean_node_0, mul_node_2, sub_node_1, reduce_mean_node_1 = parent_nodes + sub_node_0, mul_node_0, mul_node_1, reciprocol_node, sqrt_node, add_node_0 = parent_nodes[:6] + reduce_mean_node_0, mul_node_2, sub_node_1, reduce_mean_node_1 = parent_nodes[-4:] + + cast_node_3 = None + if len(parent_nodes) == 11: + cast_node_3 = parent_nodes[6] + assert(cast_node_3.op_type == 'Cast') mul_node_3 = self.model.match_parent(node, 'Mul', 0, output_name_to_node) if mul_node_3 is None: logger.debug("mul_node_3 not found") return - root_node = self.model.get_parent(reduce_mean_node_1, 0, output_name_to_node) + node_before_reduce = self.model.get_parent(reduce_mean_node_1, 0, output_name_to_node) + root_node = node_before_reduce if cast_node_3 is None else self.model.get_parent(node_before_reduce, 0, output_name_to_node) if root_node is None: logger.debug("root node is none") return i, epsilon = self.model.get_constant_input(add_node_0) - if epsilon is None or epsilon <= 0 or epsilon > 1.0E-5: + if epsilon is None or epsilon <= 0 or (epsilon > 1.0E-5 and cast_node_3 is None): logger.debug("epsilon is not matched") return - if reduce_mean_node_1.input[0] not in mul_node_3.input or reduce_mean_node_1.input[0] not in sub_node_1.input: + if cast_node_3 is None and (reduce_mean_node_1.input[0] not in mul_node_3.input or reduce_mean_node_1.input[0] not in sub_node_1.input): + logger.debug("reduce_mean_node_1 and mul_node_3 shall link from root node") + return + + if cast_node_3 is not None and (node_before_reduce.input[0] not in mul_node_3.input or reduce_mean_node_1.input[0] not in sub_node_1.input): logger.debug("reduce_mean_node_1 and mul_node_3 shall link from root node") return @@ -177,6 +192,14 @@ class FusionLayerNormalizationTF(Fusion): node, sub_node_0, mul_node_0, mul_node_1, reciprocol_node, sqrt_node, add_node_0, reduce_mean_node_0, mul_node_2, sub_node_1, reduce_mean_node_1, mul_node_3 ] + + if cast_node_3 is not None: + cast_node_2 = self.model.match_parent(mul_node_0, 'Cast', 0, output_name_to_node) + if cast_node_2 is None: + logger.debug("cast_node_2 not found") + return + subgraph_nodes.extend([node_before_reduce, cast_node_2, cast_node_3]) + if not self.model.is_safe_to_fuse_nodes(subgraph_nodes, node.output, self.model.input_name_to_nodes(), self.model.output_name_to_node()): logger.debug("not safe to fuse layer normalization") @@ -189,7 +212,8 @@ class FusionLayerNormalizationTF(Fusion): #TODO: add epsilon attribute fused_node = helper.make_node('LayerNormalization', - inputs=[reduce_mean_node_1.input[0], weight_input, bias_input], + inputs=[mul_node_3.input[0], weight_input, bias_input], outputs=[node.output[0]]) fused_node.attribute.extend([helper.make_attribute("epsilon", float(epsilon))]) self.nodes_to_add.append(fused_node) + diff --git a/onnxruntime/python/tools/transformers/onnx_model_bert_tf.py b/onnxruntime/python/tools/transformers/onnx_model_bert_tf.py index a13d72485e..6c45958d9b 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_bert_tf.py +++ b/onnxruntime/python/tools/transformers/onnx_model_bert_tf.py @@ -9,7 +9,7 @@ import sys import argparse import numpy as np from collections import deque -from onnx import ModelProto, TensorProto, numpy_helper +from onnx import ModelProto, TensorProto, numpy_helper, helper from onnx_model_bert import BertOnnxModel logger = logging.getLogger(__name__) @@ -295,6 +295,126 @@ class BertOnnxModelTF(BertOnnxModel): self.prune_graph() break + def check_attention_input(self, matmul_q, matmul_k, matmul_v, parent, output_name_to_node): + for x in [matmul_q, matmul_k, matmul_v]: + root_input = x.input[0] + root_node = output_name_to_node[root_input] + if root_node == parent: + continue + logger.debug(f"Check attention input failed:{root_input}, {parent.output[0]}") + return False + + return True + + def fuse_attention(self): + output_name_to_node = self.output_name_to_node() + + nodes_to_remove = [] + attention_count = 0 + + skip_layer_norm_nodes = self.get_nodes_by_op_type("SkipLayerNormalization") + for normalize_node in skip_layer_norm_nodes: + # SkipLayerNormalization has two inputs, and one of them is the root input for attention. + parent = self.get_parent(normalize_node, 1) + if parent is None or parent.op_type not in ["SkipLayerNormalization", "LayerNormalization", "Reshape"]: + parent = self.get_parent(normalize_node, 0) + if parent is None or parent.op_type not in ["SkipLayerNormalization", "LayerNormalization", "Reshape"]: + logger.debug("Failed to match parent of normalize_node") + continue + + qkv_nodes = self.match_parent_path(normalize_node, ['Add', 'MatMul', 'Reshape', 'Transpose', 'MatMul'], + [0, 0, 0, 0, 0]) + if qkv_nodes is None: + qkv_nodes = self.match_parent_path(normalize_node, ['MatMul', 'Reshape', 'Transpose', 'MatMul'], + [1, 0, 0, 0]) + if qkv_nodes is None: + logger.debug("Failed to match qkv nodes") + continue + + (reshape_qkv, transpose_qkv, matmul_qkv) = qkv_nodes[-3:] + v_nodes = self.match_parent_path(matmul_qkv, ['Transpose', 'Reshape', 'Add', 'MatMul'], [1, 0, 0, 0]) + if v_nodes is None: + logger.debug("Failed to match v path") + continue + + (transpose_v, reshape_v, add_v, matmul_v) = v_nodes + qk_nodes = self.match_parent_path(matmul_qkv, ['Softmax', 'Add', "Mul", 'MatMul'], [0, 0, 0, 0]) + if qk_nodes is None: + logger.debug("Failed to match qk_paths") + continue + (softmax_qk, add_qk, mul_qk, matmul_qk) = qk_nodes + + q_nodes = self.match_parent_path(matmul_qk, ['Transpose', 'Reshape', 'Add', 'MatMul'], [0, 0, 0, 0]) + if q_nodes is None: + logger.debug("Failed to match q path") + continue + (transpose_q, reshape_q, add_q, matmul_q) = q_nodes + + k_nodes = self.match_parent_path(matmul_qk, ['Transpose', 'Reshape', 'Add', 'MatMul'], [1, 0, 0, 0]) + if k_nodes is None: + logger.debug("Failed to match k path") + continue + (transpose_k, reshape_k, add_k, matmul_k) = k_nodes + + mask_nodes = self.match_parent_path(add_qk, ['Mul', 'Sub', 'Unsqueeze'], [1, 0, 1]) + if mask_nodes is None: + mask_nodes = self.match_parent_path(add_qk, ['Mul', 'Sub', 'Cast', 'Unsqueeze', 'Mul'], [1, 0, 1, 0, 0]) + if mask_nodes is None: + logger.debug("Failed to match mask path") + continue + + if not self.has_constant_input(mask_nodes[1], 1): + logger.debug("Sub node expected to have an input with constant value 1.0.") + continue + + # add a squeeze node to convert a 3-d mask to 2-d + squeeze_node = self.match_parent_path(mask_nodes[-1], ['Squeeze'], [0]) + squeeze_node_name = "Squeeze_3d_to_2d_mask" + squeeze_output_name = squeeze_node_name + "_output" + if squeeze_node is None and len(mask_nodes) == 5: + mask_input = mask_nodes[-1].input[1] + self.add_node( + helper.make_node("Squeeze", [mask_input], [squeeze_output_name], squeeze_node_name, axes=[1])) + mask_nodes[-1].input[0] = squeeze_output_name + + is_same_root = self.check_attention_input(matmul_q, matmul_k, matmul_v, parent, output_name_to_node) + if is_same_root: + mask_index = self.attention_mask.process_mask(squeeze_output_name) + logger.debug("Create an Attention node.") + attention_node = self.attention_fusion.create_attention_node(mask_index, matmul_q, matmul_k, matmul_v, + add_q, add_k, add_v, parent.output[0], + reshape_qkv.output[0]) + if parent.op_type == 'Reshape': + # Temporary work around: we require the skiplayernorm and attention op be fed with 3-d input + hidden_size = numpy_helper.to_array(self.get_initializer(parent.input[1]))[1] + tensor = helper.make_tensor( + name=parent.name + "_modified", + data_type=TensorProto.INT64, + dims=[3], + vals=np.int64([[1, -1, hidden_size]]).tobytes(), + raw=True) + self.add_initializer(tensor) + parent.input[1] = parent.name + "_modified" + + if attention_node is None: + continue + + self.add_node(attention_node) + attention_count += 1 + + nodes_to_remove.extend([reshape_qkv, transpose_qkv, matmul_qkv]) + nodes_to_remove.extend(qk_nodes) + nodes_to_remove.extend(q_nodes) + nodes_to_remove.extend(k_nodes) + nodes_to_remove.extend(v_nodes) + nodes_to_remove.extend(mask_nodes) + else: + logger.debug("Root node not matched.") + continue + self.remove_nodes(nodes_to_remove) + self.update_graph() + logger.info(f"Fused Attention count:{attention_count}") + def preprocess(self): self.remove_identity() self.process_embedding() @@ -315,4 +435,5 @@ class BertOnnxModelTF(BertOnnxModel): def postprocess(self): self.remove_reshape_before_first_attention() - self.prune_graph() + # Temporary work around for the following comment as it will cause topological issues for a bert model + # self.prune_graph() diff --git a/onnxruntime/python/tools/transformers/test_data/other_models/bert_tf2onnx_0.onnx b/onnxruntime/python/tools/transformers/test_data/other_models/bert_tf2onnx_0.onnx new file mode 100644 index 0000000000..a7c9110185 Binary files /dev/null and b/onnxruntime/python/tools/transformers/test_data/other_models/bert_tf2onnx_0.onnx differ diff --git a/onnxruntime/python/tools/transformers/test_optimizer.py b/onnxruntime/python/tools/transformers/test_optimizer.py index 7e736e5f85..5ecfe7597b 100644 --- a/onnxruntime/python/tools/transformers/test_optimizer.py +++ b/onnxruntime/python/tools/transformers/test_optimizer.py @@ -30,6 +30,7 @@ BERT_TEST_MODELS = { "gpt2_past": ('gpt2_pytorch1.5_opset11', 'gpt2_past.onnx'), "gpt2_past_mask": ('FUSION', 'gpt2_past_mask_one_layer.onnx'), "multiple_embed": ('FUSION', 'embed_layer_norm_multiple.onnx'), + "bert_tf2onnx_0": ('other_models', 'bert_tf2onnx_0.onnx') } skip_on_ort_version = pytest.mark.skipif(onnxruntime.__version__ == ('1.3.0'), @@ -297,6 +298,20 @@ class TestBertOptimization(unittest.TestCase): } self.verify_node_count(model, expected_node_count, 'test_multiple_embed') + def test_bert_tf2onnx_0(self): + input = _get_test_model_path('bert_tf2onnx_0') + model = optimize_model(input, 'bert_tf', num_heads=2, hidden_size=8) + expected_node_count = { + 'EmbedLayerNormalization': 0, + 'Attention': 6, + 'Gelu': 0, + 'FastGelu': 6, + 'BiasGelu': 0, + 'LayerNormalization': 0, + 'SkipLayerNormalization': 13 + } + self.verify_node_count(model, expected_node_count, 'test_bert_tf2onnx_0') + def test_huggingface_bert_fusion(self): self.test_optimizer_on_huggingface_model("bert-base-uncased", [1, 12, 0, 0, 12, 0, 24], inputs_count=1) self.test_optimizer_on_huggingface_model("bert-base-uncased", [1, 12, 0, 0, 12, 0, 24], inputs_count=2) @@ -343,4 +358,4 @@ class TestBertOptimization(unittest.TestCase): if __name__ == '__main__': - unittest.main() + unittest.main() \ No newline at end of file