diff --git a/onnxruntime/python/tools/transformers/fusion_options.py b/onnxruntime/python/tools/transformers/fusion_options.py index 59bd28603e..0e4547384e 100644 --- a/onnxruntime/python/tools/transformers/fusion_options.py +++ b/onnxruntime/python/tools/transformers/fusion_options.py @@ -24,6 +24,7 @@ class FusionOptions: self.enable_bias_skip_layer_norm = True self.enable_bias_gelu = True self.enable_gelu_approximation = False + self.enable_qordered_matmul = True self.enable_shape_inference = True diff --git a/onnxruntime/python/tools/transformers/fusion_qordered_attention.py b/onnxruntime/python/tools/transformers/fusion_qordered_attention.py new file mode 100644 index 0000000000..b3d8743414 --- /dev/null +++ b/onnxruntime/python/tools/transformers/fusion_qordered_attention.py @@ -0,0 +1,421 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +from logging import getLogger +from typing import Tuple + +import numpy as np +from fusion_attention import AttentionMask +from fusion_base import Fusion +from fusion_utils import FusionUtils, NumpyHelper +from onnx import NodeProto, helper +from onnx_model import OnnxModel + +logger = getLogger(__name__) + + +class FusionQOrderedAttention(Fusion): + def __init__( + self, + model: OnnxModel, + hidden_size: int, + num_heads: int, + attention_mask: AttentionMask, + ): + self.hidden_size = hidden_size + self.num_heads = num_heads + self.attention_mask = attention_mask + + super().__init__(model, "QOrderedAttention", "QOrderedLayerNormalization") + + def get_num_heads_and_hidden_size(self, reshape_q: NodeProto) -> Tuple[int, int]: + """Detect num_heads and hidden_size from a reshape node. + Args: + reshape_q (NodeProto): reshape node for Q + Returns: + Tuple[int, int]: num_heads and hidden_size + """ + + # we assume that reshape fusion has done, so the shape is a tensor like [0, 0, num_heads, head_size] + q_shape = self.model.get_initializer(reshape_q.input[1]) + if q_shape is None: + logger.debug(f"{reshape_q.input[1]} is not initializer.") + + # Check if the second input to Reshape flows through a Constant node + # TODO: Investigate why FusionAttention doesn't have such logic + constant_node = self.model.match_parent_path(reshape_q, ["Constant"], [1]) + + if constant_node is None: + return self.num_heads, self.hidden_size # Fall back to user specified value + else: + constant_node = constant_node[0] + + if len(constant_node.attribute) != 1: + return self.num_heads, self.hidden_size # Fall back to user specified value + + # This is assuming it is a Tensor attribute (this is a safe assumption) + q_shape = constant_node.attribute[0].t + + q_shape_value = NumpyHelper.to_array(q_shape) + if len(q_shape_value) != 4 or (q_shape_value[2] <= 0 or q_shape_value[3] <= 0): + logger.debug(f"q_shape_value={q_shape_value}. Expected value are like [0, 0, num_heads, head_size].") + return self.num_heads, self.hidden_size # Fall back to user specified value + + num_heads = q_shape_value[2] + head_size = q_shape_value[3] + hidden_size = num_heads * head_size + + if self.num_heads > 0 and num_heads != self.num_heads: + if self.num_heads_warning: + logger.warning(f"--num_heads is {self.num_heads}. Detected value is {num_heads}. 
Using detected value.") + self.num_heads_warning = False # Do not show the warning more than once + + if self.hidden_size > 0 and hidden_size != self.hidden_size: + if self.hidden_size_warning: + logger.warning( + f"--hidden_size is {self.hidden_size}. Detected value is {hidden_size}. Using detected value." + ) + self.hidden_size_warning = False # Do not show the warning more than once + + return num_heads, hidden_size + + def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): + add_before_layernorm = self.model.match_parent_path( + normalize_node, + ["QuantizeLinear", "Add"], + [0, 0], + ) + + if add_before_layernorm is not None: + start_node = add_before_layernorm[-1] + else: + return + + # Input QDQ nodes + dequantize_input = self.model.match_parent_path( + start_node, + ["DequantizeLinear"], + [None], + ) + + if dequantize_input is None: + logger.debug("fuse_qordered_attention: failed to match input qdq nodes path") + return + + dequantize_input = dequantize_input[-1] + + # QKV nodes + qkv_nodes = self.model.match_parent_path( + start_node, + ["Add", "MatMul", "Reshape", "Transpose", "DequantizeLinear", "QuantizeLinear", "MatMul"], + [None, None, 0, 0, 0, 0, 0], + ) + + if qkv_nodes is None: + logger.debug("fuse_qordered_attention: failed to match qkv path") + return + + (_, projection_matmul, reshape_qkv, transpose_qkv, dequantize_qkv, quantize_qkv, matmul_qkv) = qkv_nodes + + # Make sure the Q/DQ has the proper zero points and constant per-tensor scales + if not FusionUtils.check_qdq_node_for_fusion(quantize_qkv, self.model): + return + + if not FusionUtils.check_qdq_node_for_fusion(dequantize_qkv, self.model): + return + + # Identify the root input to the Attention node + other_inputs = [] + for i, input in enumerate(start_node.input): + if input not in output_name_to_node: + continue + + if input == qkv_nodes[0].output[0]: + continue + + other_inputs.append(input) + + if len(other_inputs) != 1: + return + + root_input = other_inputs[0] + + # V nodes + v_nodes = self.model.match_parent_path( + matmul_qkv, + ["Transpose", "Reshape", "DequantizeLinear", "QuantizeLinear", "Add", "MatMul"], + [1, 0, 0, 0, 0, None], + ) + + if v_nodes is None: + logger.debug("fuse_qordered_attention: failed to match v path") + return + + (_, _, dequantize_v, quantize_v, add_v, matmul_v) = v_nodes + + # Make sure the Q/DQ has the proper zero points and constant per-tensor scales + if not FusionUtils.check_qdq_node_for_fusion(quantize_v, self.model): + return + + if not FusionUtils.check_qdq_node_for_fusion(dequantize_v, self.model): + return + + # V MatMul weight + dequantize_v_matmul_weight = self.model.match_parent_path(matmul_v, ["DequantizeLinear"], [1]) + + if dequantize_v_matmul_weight is None: + logger.debug("fuse_qordered_attention: failed to match v path") + return + + dequantize_v_matmul_weight = dequantize_v_matmul_weight[0] + + if self.model.get_constant_value(dequantize_v_matmul_weight.input[0]) is None: + return + + # Make sure the upstream DequantizeLinear-1 has the proper zero points and scales + # Per-channel scales are supported for weights alone + if not FusionUtils.check_qdq_node_for_fusion(dequantize_v_matmul_weight, self.model, False): + return + + # QK nodes + qk_nodes = self.model.match_parent_path( + matmul_qkv, + [ + "DequantizeLinear", + "QuantizeLinear", + "Softmax", + "Add", + "Div", + "DequantizeLinear", + "QuantizeLinear", + "MatMul", + ], + [0, 0, 0, 0, None, 0, 0, 0], + ) + + if qk_nodes is None: + logger.debug("fuse_qordered_attention: failed to match 
qk path") + return + + ( + dequantize_qk_softmax, + quantize_qk_softmax, + softmax_qk, + add_qk, + div_qk, + dequantize_qk, + quantize_qk, + matmul_qk, + ) = qk_nodes + + # Make sure the Q/DQ has the proper zero points and constant per-tensor scales + if not FusionUtils.check_qdq_node_for_fusion(quantize_qk_softmax, self.model): + return + + if not FusionUtils.check_qdq_node_for_fusion(dequantize_qk_softmax, self.model): + return + + if not FusionUtils.check_qdq_node_for_fusion(quantize_qk, self.model): + return + + if not FusionUtils.check_qdq_node_for_fusion(dequantize_qk, self.model): + return + + # Q nodes + q_nodes = self.model.match_parent_path( + matmul_qk, + ["Transpose", "Reshape", "DequantizeLinear", "QuantizeLinear", "Add", "MatMul"], + [0, 0, 0, 0, 0, None], + ) + + if q_nodes is None: + logger.debug("fuse_qordered_attention: failed to match q path") + return + + (_, reshape_q, dequantize_q, quantize_q, add_q, matmul_q) = q_nodes + + # Make sure the Q/DQ has the proper zero points and constant per-tensor scales + if not FusionUtils.check_qdq_node_for_fusion(quantize_q, self.model): + return + + if not FusionUtils.check_qdq_node_for_fusion(dequantize_q, self.model): + return + + # Q MatMul weight + dequantize_q_matmul_weight = self.model.match_parent_path(matmul_q, ["DequantizeLinear"], [1]) + + if dequantize_q_matmul_weight is None: + logger.debug("fuse_qordered_attention: failed to match q path") + return + + dequantize_q_matmul_weight = dequantize_q_matmul_weight[0] + + if self.model.get_constant_value(dequantize_q_matmul_weight.input[0]) is None: + return + + # Make sure the upstream DequantizeLinear-1 has the proper zero points and scales + # Per-channel scales are supported for weights alone + if not FusionUtils.check_qdq_node_for_fusion(dequantize_q_matmul_weight, self.model, False): + return + + # K nodes + k_nodes = self.model.match_parent_path( + matmul_qk, + ["Transpose", "Reshape", "DequantizeLinear", "QuantizeLinear", "Add", "MatMul"], + [1, 0, 0, 0, 0, None], + ) + + if k_nodes is None: + logger.debug("fuse_qordered_attention: failed to match k path") + return + + (_, _, dequantize_k, quantize_k, add_k, matmul_k) = k_nodes + + # Make sure the Q/DQ has the proper zero points and constant per-tensor scales + if not FusionUtils.check_qdq_node_for_fusion(quantize_k, self.model): + return + + if not FusionUtils.check_qdq_node_for_fusion(dequantize_k, self.model): + return + + # K MatMul weight + dequantize_k_matmul_weight = self.model.match_parent_path(matmul_k, ["DequantizeLinear"], [1]) + + if dequantize_k_matmul_weight is None: + logger.debug("fuse_qordered_attention: failed to match k path") + return + + dequantize_k_matmul_weight = dequantize_k_matmul_weight[0] + + if self.model.get_constant_value(dequantize_k_matmul_weight.input[0]) is None: + return + + # Make sure the upstream DequantizeLinear-1 has the proper zero points and scales + # Per-channel scales are supported for weights alone + if not FusionUtils.check_qdq_node_for_fusion(dequantize_k_matmul_weight, self.model, False): + return + + # Mask nodes + mask_nodes = self.model.match_parent_path( + add_qk, ["Mul", "Sub", "Cast", "Unsqueeze", "Unsqueeze"], [None, 0, 1, 0, 0] + ) + + if mask_nodes is None: + logger.debug("fuse_qordered_attention: failed to match mask_nodes path") + return + + # Ascertain `qkv_hidden_sizes` attribute value + q_weight = self.model.get_initializer(dequantize_q_matmul_weight.input[0]) + k_weight = self.model.get_initializer(dequantize_k_matmul_weight.input[0]) + v_weight = 
self.model.get_initializer(dequantize_v_matmul_weight.input[0]) + + qw = NumpyHelper.to_array(q_weight) + kw = NumpyHelper.to_array(k_weight) + vw = NumpyHelper.to_array(v_weight) + + qw_out_size = np.prod(qw.shape[1:]) + kw_out_size = np.prod(kw.shape[1:]) + vw_out_size = np.prod(vw.shape[1:]) + + # Form QOrderedAttention node + if matmul_v.input[0] == root_input and matmul_q.input[0] == root_input and matmul_k.input[0] == root_input: + mask_index = self.attention_mask.process_mask(mask_nodes[-1].input[0]) + + # Ascertain `num_heads` and `hidden_size` + num_heads, hidden_size = self.get_num_heads_and_hidden_size(reshape_q) + + # Formulate the inputs + # Actual quantized input + attention_inputs = [dequantize_input.input[0]] + attention_inputs.append(dequantize_input.input[1]) + + attention_inputs.append(dequantize_q.input[1]) + attention_inputs.append(dequantize_k.input[1]) + attention_inputs.append(dequantize_v.input[1]) + + attention_inputs.append(dequantize_q_matmul_weight.input[0]) + attention_inputs.append(dequantize_k_matmul_weight.input[0]) + attention_inputs.append(dequantize_v_matmul_weight.input[0]) + + attention_inputs.append(dequantize_q_matmul_weight.input[1]) + attention_inputs.append(dequantize_k_matmul_weight.input[1]) + attention_inputs.append(dequantize_v_matmul_weight.input[1]) + + if self.model.get_initializer(add_q.input[0]): + attention_inputs.append(add_q.input[0]) + else: # second input is the constant bias + attention_inputs.append(add_q.input[1]) + + if self.model.get_initializer(add_k.input[0]): + attention_inputs.append(add_k.input[0]) + else: # second input is the constant bias + attention_inputs.append(add_k.input[1]) + + if self.model.get_initializer(add_v.input[0]): + attention_inputs.append(add_v.input[0]) + else: # second input is the constant bias + attention_inputs.append(add_v.input[1]) + + attention_inputs.append(quantize_qk.input[1]) + attention_inputs.append(quantize_qk_softmax.input[1]) + attention_inputs.append(dequantize_qkv.input[1]) + + # Mask input + if mask_index is not None: + attention_inputs.append(mask_index) + else: + attention_inputs.append("") + + # The MatMul weight 'B' and 'bias' need some post-processing + # Transpose weight 'B' from order ROW to order COL + # This offline transpose is needed only while using the CUDA EP + # TODO: Make this fusion logic EP-agnostic ? 
+ q_weight_tensor = self.model.get_initializer(dequantize_q_matmul_weight.input[0]) + FusionUtils.transpose_2d_int8_tensor(q_weight_tensor) + + k_weight_tensor = self.model.get_initializer(dequantize_k_matmul_weight.input[0]) + FusionUtils.transpose_2d_int8_tensor(k_weight_tensor) + + v_weight_tensor = self.model.get_initializer(dequantize_v_matmul_weight.input[0]) + FusionUtils.transpose_2d_int8_tensor(v_weight_tensor) + + # Name and create Attention node + attention_node_name = self.model.create_node_name("QOrderedAttention") + + attention_node = helper.make_node( + "QOrderedAttention", + inputs=attention_inputs, + outputs=[reshape_qkv.output[0]], + name=attention_node_name, + ) + + self.model.replace_node_input(dequantize_qkv, dequantize_qkv.input[0], attention_node.output[0]) + self.model.replace_node_input(projection_matmul, projection_matmul.input[0], dequantize_qkv.output[0]) + + attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)]) + attention_node.attribute.extend([helper.make_attribute("order_input", 1)]) + attention_node.attribute.extend([helper.make_attribute("order_weight", 0)]) + attention_node.attribute.extend([helper.make_attribute("order_output", 1)]) + attention_node.attribute.extend( + [helper.make_attribute("qkv_hidden_sizes", [qw_out_size, kw_out_size, vw_out_size])] + ) + + attention_node.domain = "com.microsoft" + + self.nodes_to_add.append(attention_node) + self.node_name_to_graph_name[attention_node.name] = self.this_graph_name + + self.nodes_to_remove.extend([reshape_qkv, transpose_qkv, quantize_qkv, matmul_qkv]) + self.nodes_to_remove.extend(qk_nodes) + self.nodes_to_remove.extend(q_nodes) + self.nodes_to_remove.extend(k_nodes) + self.nodes_to_remove.extend(v_nodes) + self.nodes_to_remove.extend( + [dequantize_q_matmul_weight, dequantize_k_matmul_weight, dequantize_v_matmul_weight] + ) + + # Use prune graph to remove mask nodes since they are shared by all attention nodes. + # self.nodes_to_remove.extend(mask_nodes) + self.prune_graph = True diff --git a/onnxruntime/python/tools/transformers/fusion_qordered_gelu.py b/onnxruntime/python/tools/transformers/fusion_qordered_gelu.py new file mode 100644 index 0000000000..a92c8f94d4 --- /dev/null +++ b/onnxruntime/python/tools/transformers/fusion_qordered_gelu.py @@ -0,0 +1,117 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# -------------------------------------------------------------------------- + +from logging import getLogger +from typing import Dict + +from fusion_base import Fusion +from fusion_utils import FusionUtils +from onnx import helper +from onnx_model import OnnxModel + +logger = getLogger(__name__) + + +class FusionQOrderedGelu(Fusion): + def __init__(self, model: OnnxModel): + super().__init__(model, "QOrderedGelu", ["Gelu", "FastGelu"]) + + def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict): + """ + INPUT PATTERN + Fuse (quantized) Gelu subgraph into one node QOrderedGelu: + -> quantized input -> DQ -> Gelu -> Q -> + + (or) + + -> quantized input -> DQ -> FastGelu -> Q -> + + OUTPUT PATTERN + -> QOrderedGelu -> + """ + gelu_children = self.model.get_children(node, input_name_to_nodes) + + # Should only have 1 child - QuantizeLinear (or) + # Should have 2 children - QuantizeLinear + Shape + if not ( + (len(gelu_children) == 1 and gelu_children[0].op_type == "QuantizeLinear") + or ( + len(gelu_children) == 2 + and gelu_children[0].op_type == "QuantizeLinear" + and gelu_children[1].op_type == "Shape" + ) + ): + return + + downstream_quantize_node = gelu_children[0] + downstream_shape_node = None + + if len(gelu_children) == 2: + downstream_shape_node = gelu_children[1] + + if not FusionUtils.check_qdq_node_for_fusion(downstream_quantize_node, self.model): + return + + # The first input to Gelu should flow through a DequantizeLinear node + first_path_id, first_input_parent_nodes, _ = self.model.match_parent_paths( + node, + [(["DequantizeLinear"], [0])], + output_name_to_node, + ) + + if first_path_id < 0: + return + + upstream_dequantize_node = first_input_parent_nodes[0] + + if not FusionUtils.check_qdq_node_for_fusion(upstream_dequantize_node, self.model): + return + + # Fusion logic + subgraph_nodes = [node] # Gelu/FastGelu + subgraph_nodes.extend([downstream_quantize_node, upstream_dequantize_node]) # Relevant Q, DQ nodes + + if not self.model.is_safe_to_fuse_nodes( + subgraph_nodes, + [node.output[0], downstream_quantize_node.output[0]] + if downstream_shape_node is not None + else downstream_quantize_node.output, + input_name_to_nodes, + output_name_to_node, + ): + logger.debug(f"It is not safe to fuse QOrderedGelu node. Skip") + return + + self.nodes_to_remove.extend(subgraph_nodes) + + ordered_gelu_node = helper.make_node( + "QOrderedGelu", + inputs=[ + upstream_dequantize_node.input[0], + upstream_dequantize_node.input[1], + downstream_quantize_node.input[1], + ], + outputs=[downstream_quantize_node.output[0]], + name=self.model.create_node_name("QOrderedGelu", name_prefix="QOrderedGelu"), + ) + + # Arrange the downstream Shape's input to be fed from the + # downstream QuantizeLinear node, so that fusion will + # be deemed safe + if downstream_shape_node is not None: + self.model.replace_node_input( + downstream_shape_node, downstream_shape_node.input[0], downstream_quantize_node.output[0] + ) + + # TODO: We only support CuBlasLt order ORDER_ROW for now. + # Once we start supporting other data ordering format(s), we + # will support user configuring the data ordering for the op. 
+ ordered_gelu_node.attribute.extend([helper.make_attribute("order_X", 1)]) + ordered_gelu_node.attribute.extend([helper.make_attribute("order_Y", 1)]) + + ordered_gelu_node.domain = "com.microsoft" + + self.nodes_to_add.append(ordered_gelu_node) + self.node_name_to_graph_name[ordered_gelu_node.name] = self.this_graph_name diff --git a/onnxruntime/python/tools/transformers/fusion_qordered_layernorm.py b/onnxruntime/python/tools/transformers/fusion_qordered_layernorm.py new file mode 100644 index 0000000000..f8198bcaa1 --- /dev/null +++ b/onnxruntime/python/tools/transformers/fusion_qordered_layernorm.py @@ -0,0 +1,121 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +from logging import getLogger +from typing import Dict + +from fusion_base import Fusion +from fusion_utils import FusionUtils +from onnx import helper +from onnx_model import OnnxModel + +logger = getLogger(__name__) + + +class FusionQOrderedLayerNormalization(Fusion): + def __init__(self, model: OnnxModel): + super().__init__(model, "QOrderedLayerNormalization", "LayerNormalization") + + def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict): + """ + Fuse (quantized) Layer Normalization subgraph into one node QOrderedLayerNormalization: + quantized input -> DQ + | + | + (other inputs)-> LayerNormalization --> Q --> + + should become + + (quantized input + other inputs)-> QOrderedLayerNormalization --> Q --> + """ + + children = self.model.get_children(node, input_name_to_nodes) + + # Should only have 1 child - QuantizeLinear (or) + # Should have 2 children - QuantizeLinear + Shape + if not ( + (len(children) == 1 and children[0].op_type == "QuantizeLinear") + or (len(children) == 2 and children[0].op_type == "QuantizeLinear" and children[1].op_type == "Shape") + ): + return + + downstream_quantize_node = children[0] + downstream_shape_node = None + + if len(children) == 2: + downstream_shape_node = children[1] + + if not FusionUtils.check_qdq_node_for_fusion(downstream_quantize_node, self.model): + return + + # The first input to LayerNormalization should flow through a DequantizeLinear node + first_path_id, first_input_parent_nodes, _ = self.model.match_parent_paths( + node, + [(["DequantizeLinear"], [0])], + output_name_to_node, + ) + + if first_path_id < 0: + return + + upstream_dequantize_node = first_input_parent_nodes[0] + + if not FusionUtils.check_qdq_node_for_fusion(upstream_dequantize_node, self.model): + return + + # Fusion logic + subgraph_nodes = [node] # LayerNormalization + subgraph_nodes.extend([downstream_quantize_node]) # Q node after LayerNormalization + + upstream_dequantize_node_children = self.model.get_children(upstream_dequantize_node, input_name_to_nodes) + + # In GPT2, the DQ node will be feeding a residual downstream Add and hence, + # we do not want to remove it + if len(upstream_dequantize_node_children) == 1: + subgraph_nodes.extend([upstream_dequantize_node]) # DQ node before LayerNormalization + + if not self.model.is_safe_to_fuse_nodes( + subgraph_nodes, + [node.output[0], downstream_quantize_node.output[0]] + if downstream_shape_node is not None + else downstream_quantize_node.output, + input_name_to_nodes, + output_name_to_node, + ): + logger.debug(f"It is not safe to fuse QOrderedLayerNormalization node. 
Skip") + return + + self.nodes_to_remove.extend(subgraph_nodes) + + normalize_node = helper.make_node( + "QOrderedLayerNormalization", + inputs=[ + upstream_dequantize_node.input[0], + upstream_dequantize_node.input[1], + node.input[1], + node.input[2], + downstream_quantize_node.input[1], + ], + outputs=[downstream_quantize_node.output[0]], + name=self.model.create_node_name("QOrderedLayerNormalization", name_prefix="QOrderedLayerNormalization"), + ) + + # Arrange the downstream Shape's input to be fed from the + # downstream QuantizeLinear node, so that fusion will + # be deemed safe + if downstream_shape_node is not None: + self.model.replace_node_input( + downstream_shape_node, downstream_shape_node.input[0], downstream_quantize_node.output[0] + ) + + # TODO: We only support CuBlasLt order ORDER_ROW for now. + # Once we start supporting other data ordering format(s), we + # will support user configuring the data ordering for the op. + normalize_node.attribute.extend([helper.make_attribute("order_X", 1)]) + normalize_node.attribute.extend([helper.make_attribute("order_Y", 1)]) + + normalize_node.domain = "com.microsoft" + + self.nodes_to_add.append(normalize_node) + self.node_name_to_graph_name[normalize_node.name] = self.this_graph_name diff --git a/onnxruntime/python/tools/transformers/fusion_qordered_matmul.py b/onnxruntime/python/tools/transformers/fusion_qordered_matmul.py new file mode 100644 index 0000000000..2fbd326268 --- /dev/null +++ b/onnxruntime/python/tools/transformers/fusion_qordered_matmul.py @@ -0,0 +1,217 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +from logging import getLogger +from typing import Dict + +from fusion_base import Fusion +from fusion_utils import FusionUtils +from onnx import helper +from onnx_model import OnnxModel + +logger = getLogger(__name__) + + +class FusionQOrderedMatMul(Fusion): + def __init__(self, model: OnnxModel): + super().__init__(model, "QOrderedMatMul", "MatMul") + + def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict): + matmul_children = self.model.get_children(node, input_name_to_nodes) + + # Should only have 1 child - Bias Add + if len(matmul_children) != 1 or matmul_children[0].op_type != "Add": + return + + bias_add_node = matmul_children[0] + + # Atleast one of the inputs to Bias Add node must be a constant + bias_add_node_index = 0 + if ( + self.model.get_constant_value(bias_add_node.input[0]) is None + and self.model.get_constant_value(bias_add_node.input[1]) is None + ): + return + + if self.model.get_constant_value(bias_add_node.input[0]) is None: + bias_add_node_index = 1 + + bias_add_children = self.model.get_children(bias_add_node, input_name_to_nodes) + + if len(bias_add_children) != 1: + return + + bias_add_child = bias_add_children[0] + + # Bias Add can have another Add downstream (Residual Add layer) + residual_add_node = None + + downstream_quantize_node = None + + if bias_add_child.op_type == "Add": + residual_add_node = bias_add_child + + residual_add_children = self.model.get_children(residual_add_node, input_name_to_nodes) + + if len(residual_add_children) != 1 or residual_add_children[0].op_type != "QuantizeLinear": + return + + downstream_quantize_node = residual_add_children[0] + + elif bias_add_child.op_type == "QuantizeLinear": + downstream_quantize_node = bias_add_child + + 
else: + return + + # Make sure the downstream QuantizeLinear has the proper zero points and scales + if not FusionUtils.check_qdq_node_for_fusion(downstream_quantize_node, self.model): + return + + # The first input to MatMul should flow through a DequantizeLinear node + first_path_id, first_input_parent_nodes, _ = self.model.match_parent_paths( + node, + [(["DequantizeLinear"], [0])], + output_name_to_node, + ) + + # If Attention is not fused, this is the pattern to look for + # leading upto the MatMul + reshape_node_0 = None + transpose_node_0 = None + if first_path_id < 0: + first_path_id, first_input_parent_nodes, _ = self.model.match_parent_paths( + node, + [(["Reshape", "Transpose", "DequantizeLinear", "QuantizeLinear"], [0, 0, 0, 0])], + output_name_to_node, + ) + + if first_path_id < 0: + return + + reshape_node_0 = first_input_parent_nodes[0] + transpose_node_0 = first_input_parent_nodes[1] + dequantize_node_0 = first_input_parent_nodes[2] + else: + dequantize_node_0 = first_input_parent_nodes[0] + + # Make sure the upstream DequantizeLinear-0 has the proper zero points and scales + if not FusionUtils.check_qdq_node_for_fusion(dequantize_node_0, self.model): + return + + # The second input to MatMul should flow through a DequantizeLinear node + dequantize_node_1 = None + is_weight_transpose_required = True + + weight_path_id, weight_nodes, _ = self.model.match_parent_paths( + node, + [(["DequantizeLinear", "QuantizeLinear", "Transpose", "DequantizeLinear"], [1, 0, 0, 0])], + output_name_to_node, + ) + + if weight_path_id < 0: + weight_path_id, weight_nodes, _ = self.model.match_parent_paths( + node, + [(["DequantizeLinear"], [1])], + output_name_to_node, + ) + + if weight_path_id < 0: + return + + dequantize_node_1 = weight_nodes[0] + else: + is_weight_transpose_required = False + dequantize_node_1 = weight_nodes[3] + + # Check if weight 'B' is a constant + if self.model.get_constant_value(dequantize_node_1.input[0]) is None: + return + + # Make sure the upstream DequantizeLinear-1 has the proper zero points and scales + # Per-channel scales are supported for weights alone + if not FusionUtils.check_qdq_node_for_fusion(dequantize_node_1, self.model, False): + return + + # Make sure the upstream flow into the Residual Add node flows through a DQ node + residual_add_dequantize_node = None + + if residual_add_node is not None: + residual_path_id, residual_input_parent_nodes, _ = self.model.match_parent_paths( + residual_add_node, + [ + (["DequantizeLinear"], [1]), + ], + output_name_to_node, + ) + + if residual_path_id < 0: + return + + residual_add_dequantize_node = residual_input_parent_nodes[0] + + # Make sure the upstream DequantizeLinear to the Residual Add has the proper zero points and scales + if residual_add_dequantize_node is not None and not FusionUtils.check_qdq_node_for_fusion( + residual_add_dequantize_node, self.model + ): + return + + # Subgraph nodes to be fused + subgraph_nodes = [node, bias_add_node] # MatMul + Bias Add + + if residual_add_node is not None: + subgraph_nodes.extend([residual_add_node]) # Residual Add + + subgraph_nodes.extend(weight_nodes) + subgraph_nodes.extend([downstream_quantize_node]) # Downstream Q node + + if not self.model.is_safe_to_fuse_nodes( + subgraph_nodes, downstream_quantize_node.output, input_name_to_nodes, output_name_to_node + ): + logger.debug(f"It is not safe to fuse QOrderedMatMul node. 
Skip") + return + + # Deal with the case where-in the Attention subgraph is not fused + if transpose_node_0 is not None: + self.model.replace_node_input(transpose_node_0, transpose_node_0.input[0], dequantize_node_0.input[0]) + + # Make inputs + fused_node_inputs = [ + reshape_node_0.output[0] if reshape_node_0 is not None else dequantize_node_0.input[0], + dequantize_node_0.input[1], + dequantize_node_1.input[0], + dequantize_node_1.input[1], + downstream_quantize_node.input[1], + bias_add_node.input[bias_add_node_index], + ] + + if residual_add_node is not None: + fused_node_inputs.append(residual_add_dequantize_node.input[0]) + fused_node_inputs.append(residual_add_dequantize_node.input[1]) + + # The MatMul weight 'B' and 'bias' need some post-processing + # Transpose weight 'B' from order ROW to order COL + # This offline transpose is needed only while using the CUDA EP + # TODO: Make this fusion logic EP-agnostic ? + if is_weight_transpose_required: + weight_tensor = self.model.get_initializer(dequantize_node_1.input[0]) + FusionUtils.transpose_2d_int8_tensor(weight_tensor) + + fused_node = helper.make_node( + "QOrderedMatMul", + inputs=fused_node_inputs, + outputs=[downstream_quantize_node.output[0]], + name=self.model.create_node_name("QOrderedMatMul", name_prefix="QOrderedMatMul"), + ) + + fused_node.attribute.extend([helper.make_attribute("order_A", 1)]) + fused_node.attribute.extend([helper.make_attribute("order_B", 0)]) + fused_node.attribute.extend([helper.make_attribute("order_Y", 1)]) + + fused_node.domain = "com.microsoft" + + self.nodes_to_remove.extend(subgraph_nodes) + self.nodes_to_add.append(fused_node) + self.node_name_to_graph_name[fused_node.name] = self.this_graph_name diff --git a/onnxruntime/python/tools/transformers/fusion_utils.py b/onnxruntime/python/tools/transformers/fusion_utils.py index a3ae4fb60e..865c1542c1 100644 --- a/onnxruntime/python/tools/transformers/fusion_utils.py +++ b/onnxruntime/python/tools/transformers/fusion_utils.py @@ -5,8 +5,10 @@ from logging import getLogger from typing import Tuple +import numpy from numpy import array_equal, ndarray -from onnx import TensorProto, helper, numpy_helper +from onnx import NodeProto, TensorProto, helper, numpy_helper +from onnx import onnx_pb as onnx_proto from onnx_model import OnnxModel logger = getLogger(__name__) @@ -83,6 +85,73 @@ class FusionUtils: else: return value == expected_value + @staticmethod + def transpose_2d_int8_tensor(tensor: onnx_proto.TensorProto): + """Transpose a 2-D INT8 TensorProto + Args: + tensor (TensorProto): tensor to be transposed + Returns: + tensor (TensorProto): transposed tensor + """ + if not isinstance(tensor, onnx_proto.TensorProto): + raise ValueError("Expected input type is an ONNX TensorProto but got %s" % type(tensor)) + + if len(tensor.dims) != 2 or tensor.data_type != onnx_proto.TensorProto.INT8: + raise ValueError("Only INT8 2-D tensors can be transposed") + + if tensor.raw_data: + int32_data = numpy.reshape(numpy.frombuffer(tensor.raw_data, dtype="int8"), tensor.dims) + int32_transposed_data = numpy.transpose(int32_data, [1, 0]) + tensor.raw_data = int32_transposed_data.tobytes() + + else: + raise ValueError("only raw buffer supported") + + return tensor + + @staticmethod + def check_qdq_node_for_fusion(node: NodeProto, model: OnnxModel, allow_per_tensor_quantization_only=True): + """Verify if a provided QuantizeLinear (Q) / DequantizeLinear (DQ) node is a good candidate for fusion. 
+        It is a good candidate for fusion if:
+            (1) The Q/DQ node is for per-tensor quantization if allow_per_tensor_quantization_only is `True`
+            (2) The Q/DQ node should have constant scale
+            (3) The Q/DQ node should have a zero point of 0
+        Args:
+            node (NodeProto): a Q/DQ node to check
+            model (OnnxModel): the model containing the node
+            allow_per_tensor_quantization_only (bool): accept only per-tensor (scalar scale) quantization if True
+        Returns:
+            bool: whether the check is passed or not
+        """
+        if node.op_type not in {"QuantizeLinear", "DequantizeLinear"}:
+            logger.debug(f"Provided node is not a Q/DQ node. Op Type: {node.op_type}")
+            return False
+
+        scale = model.get_constant_value(node.input[1])
+
+        # Scale is not constant
+        if scale is None:
+            return False
+
+        # Not per-tensor quantization
+        scale_has_single_element = scale.ndim == 0 or (scale.ndim == 1 and scale.shape[0] == 1)
+        if allow_per_tensor_quantization_only and not scale_has_single_element:
+            return False
+
+        # If the Q/DQ node has no zero point input, it is assumed to be 0 (per ONNX spec)
+        if len(node.input) == 2:
+            return True
+
+        # Zero point should be constant
+        zero_point = model.get_constant_value(node.input[2])
+        if zero_point is None:
+            return False
+
+        # Zero point and scale should have the same number of dims
+        if scale.ndim != zero_point.ndim:
+            return False
+
+        # Zero point should have a value of 0
+        return numpy.all(zero_point == 0)
+
     def check_node_input_value(self, node, input_index: int, expected_value):
         """Verify that a node has expected input value
diff --git a/onnxruntime/python/tools/transformers/onnx_model_bert.py b/onnxruntime/python/tools/transformers/onnx_model_bert.py
index 72ee1211e4..3461b9d618 100644
--- a/onnxruntime/python/tools/transformers/onnx_model_bert.py
+++ b/onnxruntime/python/tools/transformers/onnx_model_bert.py
@@ -14,6 +14,10 @@ from fusion_gelu import FusionGelu
 from fusion_gelu_approximation import FusionGeluApproximation
 from fusion_layernorm import FusionLayerNormalization, FusionLayerNormalizationTF
 from fusion_options import FusionOptions
+from fusion_qordered_attention import FusionQOrderedAttention
+from fusion_qordered_gelu import FusionQOrderedGelu
+from fusion_qordered_layernorm import FusionQOrderedLayerNormalization
+from fusion_qordered_matmul import FusionQOrderedMatMul
 from fusion_reshape import FusionReshape
 from fusion_shape import FusionShape
 from fusion_skiplayernorm import FusionBiasSkipLayerNormalization, FusionSkipLayerNormalization
@@ -49,16 +53,24 @@ class BertOnnxModel(OnnxModel):
         self.attention_mask = AttentionMask(self)
         self.attention_fusion = FusionAttention(self, self.hidden_size, self.num_heads, self.attention_mask)
+        self.qordered_attention_fusion = FusionQOrderedAttention(
+            self, self.hidden_size, self.num_heads, self.attention_mask
+        )
         self.utils = FusionUtils(self)
 
     def fuse_attention(self):
         self.attention_fusion.apply()
+        # Only relevant in models with Q-DQ nodes
+        self.qordered_attention_fusion.apply()
 
     def fuse_gelu(self):
         fusion = FusionGelu(self)
         fusion.apply()
         fusion = FusionFastGelu(self)
         fusion.apply()
+        # Only relevant in models with Q-DQ nodes
+        fusion = FusionQOrderedGelu(self)
+        fusion.apply()
 
     def fuse_bias_gelu(self, is_fastgelu):
         fusion = FusionBiasGelu(self, is_fastgelu)
@@ -91,10 +103,19 @@ class BertOnnxModel(OnnxModel):
         fusion = FusionLayerNormalizationTF(self)
         fusion.apply()
 
+        # Only relevant in models with Q-DQ nodes
+        fusion = FusionQOrderedLayerNormalization(self)
+        fusion.apply()
+
     def fuse_skip_layer_norm(self):
         fusion = FusionSkipLayerNormalization(self)
         fusion.apply()
 
+    # Only relevant in models with Q-DQ nodes
+    def fuse_qordered_matmul(self):
+        fusion = FusionQOrderedMatMul(self)
+        fusion.apply()
+
     def get_graph_inputs_from_node_type(self, op_type: str, input_indices: List[int], casted: bool):
         """
         Get graph inputs that feed into node type (like EmbedLayerNormalization or Attention).
@@ -364,6 +385,11 @@ class BertOnnxModel(OnnxModel):
             self.attention_mask.set_mask_format(options.attention_mask_format)
 
         self.fuse_attention()
 
+        # Perform the MatMul fusion after the Attention fusion as we do not
+        # want to fuse the MatMuls inside the Attention subgraphs
+        if (options is None) or options.enable_qordered_matmul:
+            self.fuse_qordered_matmul()
+
         self.fuse_shape()
 
         if (options is None) or options.enable_embed_layer_norm:
@@ -403,11 +429,15 @@ class BertOnnxModel(OnnxModel):
         ops = [
             "EmbedLayerNormalization",
             "Attention",
+            "QOrderedAttention",
             "Gelu",
+            "QOrderedGelu",
             "FastGelu",
             "BiasGelu",
             "LayerNormalization",
+            "QOrderedLayerNormalization",
             "SkipLayerNormalization",
+            "QOrderedMatMul",
         ]
         for op in ops:
             nodes = self.get_nodes_by_op_type(op)
@@ -421,7 +451,7 @@ class BertOnnxModel(OnnxModel):
         """
         op_count = self.get_fused_operator_statistics()
         embed = op_count["EmbedLayerNormalization"]
-        attention = op_count["Attention"]
+        attention = op_count["Attention"] + op_count["QOrderedAttention"]
         gelu = op_count["Gelu"] + op_count["BiasGelu"] + op_count["FastGelu"]
         layer_norm = op_count["LayerNormalization"] + op_count["SkipLayerNormalization"]
         is_perfect = (embed > 0) and (attention > 0) and (attention == gelu) and (layer_norm >= 2 * attention)
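
Reviewer note (not part of the patch): the eligibility rule that FusionUtils.check_qdq_node_for_fusion enforces can be restated independently of the OnnxModel plumbing. The sketch below is a simplified standalone paraphrase for illustration only; the function name qdq_params_ok is hypothetical and it operates directly on the scale and zero-point values rather than on graph nodes.

from typing import Optional

import numpy as np


def qdq_params_ok(scale, zero_point: Optional[np.ndarray], per_tensor_only: bool = True) -> bool:
    """Paraphrase of the rule: constant per-tensor scale, zero point absent or all zeros."""
    per_tensor = scale.ndim == 0 or (scale.ndim == 1 and scale.shape[0] == 1)
    if per_tensor_only and not per_tensor:
        return False  # per-channel scales are rejected for activations
    if zero_point is None:
        return True  # a missing zero point defaults to 0 per the ONNX spec
    if zero_point.ndim != scale.ndim:
        return False
    return bool(np.all(zero_point == 0))


# Per-tensor scale with a zero offset passes; a per-channel scale does not.
assert qdq_params_ok(np.float32(0.02), np.int8(0))
assert not qdq_params_ok(np.array([0.1, 0.2], dtype=np.float32), None)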
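The offline weight handling in the attention and MatMul fusions can be reproduced with numpy: FusionUtils.transpose_2d_int8_tensor rewrites the raw INT8 buffer in transposed order while leaving tensor.dims untouched, which matches the fused nodes being created with order_weight / order_B set to 0 (column-major, going by the "order ROW to order COL" comments in this change). The initializer below is a made-up example, not taken from the patch.

import numpy as np
from onnx import numpy_helper

# Made-up 2-D INT8 initializer standing in for a Q/K/V or MatMul weight.
weight = numpy_helper.from_array(np.arange(6, dtype=np.int8).reshape(2, 3), name="w_int8")

# Same effect as FusionUtils.transpose_2d_int8_tensor: transpose the raw buffer
# in place; the logical dims stay the same and only the data ordering changes.
row_major = np.frombuffer(weight.raw_data, dtype=np.int8).reshape(tuple(weight.dims))
weight.raw_data = np.ascontiguousarray(row_major.T).tobytes()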
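For reference, this is the shape of the transformation FusionQOrderedGelu performs, written out as a toy before/after with onnx.helper (tensor names are placeholders). The DequantizeLinear -> Gelu -> QuantizeLinear chain collapses into a single QOrderedGelu that takes the quantized input plus the input and output scales, with row-major ordering on both sides.

from onnx import helper

# Pattern matched by the fusion: DQ -> Gelu -> Q with per-tensor scales and zero offsets.
dq = helper.make_node("DequantizeLinear", ["x_q", "scale_x", "zp_x"], ["x_f"], name="dq")
gelu = helper.make_node("Gelu", ["x_f"], ["y_f"], name="gelu", domain="com.microsoft")
q = helper.make_node("QuantizeLinear", ["y_f", "scale_y", "zp_y"], ["y_q"], name="q")

# What the fusion emits in their place:
fused = helper.make_node(
    "QOrderedGelu",
    inputs=["x_q", "scale_x", "scale_y"],  # quantized input, input scale, output scale
    outputs=["y_q"],
    name="qordered_gelu",
    domain="com.microsoft",
    order_X=1,  # only CuBlasLt ORDER_ROW is supported for now
    order_Y=1,
)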
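Finally, a minimal usage sketch showing how these passes are reached through the public optimizer entry point, assuming an already quantized (Q/DQ) BERT-style model; the file names and shape parameters are placeholders, and enable_qordered_matmul is the new FusionOptions flag added in this change.

from onnxruntime.transformers.fusion_options import FusionOptions
from onnxruntime.transformers.optimizer import optimize_model

options = FusionOptions("bert")
options.enable_qordered_matmul = True  # set to False to skip QOrderedMatMul fusion

opt_model = optimize_model(
    "bert_int8_qdq.onnx",          # placeholder input model with Q/DQ nodes
    model_type="bert",
    num_heads=12,                  # placeholder; 0 lets the fusions detect it
    hidden_size=768,               # placeholder; 0 lets the fusions detect it
    optimization_options=options,
)

# The statistics now include the QOrdered* operators added by this change.
print(opt_model.get_fused_operator_statistics())
opt_model.save_model_to_file("bert_int8_qdq_opt.onnx")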