From aaf32fb1b1d1cb3a3a4250cc2ee132576b7ddf00 Mon Sep 17 00:00:00 2001 From: Ye Wang <52801275+wangyems@users.noreply.github.com> Date: Mon, 5 Feb 2024 18:15:16 +0000 Subject: [PATCH] phi2 conversion/optimization script (#19338) ### Description This PR adds onnx conversion script for dynamo exported phi2, optimization script, and inference example script A readme file is added as documentation. https://github.com/microsoft/onnxruntime/tree/wangye/phi2_doc/onnxruntime/python/tools/transformers/models/phi2#readme ### Motivation and Context --------- Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com> --- cmake/onnxruntime_python.cmake | 7 + .../python/tools/symbolic_shape_infer.py | 28 + .../tools/transformers/dynamo_onnx_helper.py | 92 ++ .../python/tools/transformers/float16.py | 8 +- .../tools/transformers/fusion_options.py | 15 + .../tools/transformers/models/phi2/README.md | 119 +++ .../transformers/models/phi2/__init__.py | 12 + .../models/phi2/convert_to_onnx.py | 458 ++++++++++ .../models/phi2/inference_example.py | 215 +++++ .../transformers/models/phi2/requirements.txt | 3 + .../python/tools/transformers/onnx_model.py | 5 + .../tools/transformers/onnx_model_phi.py | 839 ++++++++++++++++++ .../python/tools/transformers/optimizer.py | 2 + setup.py | 1 + 14 files changed, 1801 insertions(+), 3 deletions(-) create mode 100644 onnxruntime/python/tools/transformers/dynamo_onnx_helper.py create mode 100644 onnxruntime/python/tools/transformers/models/phi2/README.md create mode 100644 onnxruntime/python/tools/transformers/models/phi2/__init__.py create mode 100644 onnxruntime/python/tools/transformers/models/phi2/convert_to_onnx.py create mode 100644 onnxruntime/python/tools/transformers/models/phi2/inference_example.py create mode 100644 onnxruntime/python/tools/transformers/models/phi2/requirements.txt create mode 100644 onnxruntime/python/tools/transformers/onnx_model_phi.py diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index 456344aa34..3f20787e87 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -473,6 +473,9 @@ file(GLOB onnxruntime_python_transformers_models_llama_src CONFIGURE_DEPENDS file(GLOB onnxruntime_python_transformers_models_longformer_src CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/python/tools/transformers/models/longformer/*.py" ) +file(GLOB onnxruntime_python_transformers_models_phi2_src CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/python/tools/transformers/models/phi2/*.py" +) file(GLOB onnxruntime_python_transformers_models_stable_diffusion_src CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/python/tools/transformers/models/stable_diffusion/*.py" ) @@ -543,6 +546,7 @@ add_custom_command( COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers/models/gpt2 COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers/models/llama COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers/models/longformer + COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers/models/phi2 COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers/models/stable_diffusion COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers/models/t5 COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers/models/whisper @@ -646,6 +650,9 @@ add_custom_command( COMMAND ${CMAKE_COMMAND} -E copy ${onnxruntime_python_transformers_models_longformer_src} $/onnxruntime/transformers/models/longformer/ + COMMAND ${CMAKE_COMMAND} -E copy + ${onnxruntime_python_transformers_models_phi2_src} + $/onnxruntime/transformers/models/phi2/ COMMAND ${CMAKE_COMMAND} -E copy ${onnxruntime_python_transformers_models_stable_diffusion_src} $/onnxruntime/transformers/models/stable_diffusion/ diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py index 9823e8264e..31f3a3a2b3 100755 --- a/onnxruntime/python/tools/symbolic_shape_infer.py +++ b/onnxruntime/python/tools/symbolic_shape_infer.py @@ -205,6 +205,7 @@ class SymbolicShapeInference: "GemmFastGelu": self._infer_GemmFastGelu, "GemmFloat8": self._infer_GemmFloat8, "GroupNorm": self._infer_GroupNorm, + "GroupQueryAttention": self._infer_GroupQueryAttention, "SkipGroupNorm": self._infer_SkipGroupNorm, "LayerNormalization": self._infer_LayerNormalization, "LongformerAttention": self._infer_LongformerAttention, @@ -471,6 +472,7 @@ class SymbolicShapeInference: "PythonOp", "MultiHeadAttention", "GroupNorm", + "GroupQueryAttention", "SkipGroupNorm", "BiasSplitGelu", "BiasAdd", @@ -2409,6 +2411,32 @@ class SymbolicShapeInference: def _infer_GroupNorm(self, node): # noqa: N802 self._propagate_shape_and_type(node) + def _infer_GroupQueryAttention(self, node): # noqa: N802 + output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type + + past_shape = self._try_get_shape(node, 3) + if past_shape is not None: + vi = self.known_vi_[node.output[1]] + vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, past_shape)) + vi = self.known_vi_[node.output[2]] + vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, past_shape)) + + if node.input[1] != "" and node.input[2] != "": + self._propagate_shape_and_type(node, 0, 0) + else: + # combined qkv: (batch_size, sequence_length, num_heads * head_size + 2 * kv_num_heads * head_size) + assert node.input[1] == "" and node.input[2] == "" + num_heads = get_attribute(node, "num_heads") + kv_num_heads = get_attribute(node, "kv_num_heads") + query_shape = self._get_shape(node, 0) + if query_shape is not None: + hidden_size = query_shape[2] + if isinstance(hidden_size, int): + head_size = int(hidden_size / (num_heads + 2 * kv_num_heads)) + query_shape[2] = num_heads * head_size + vi = self.known_vi_[node.output[0]] + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, query_shape)) + def _infer_SkipGroupNorm(self, node): # noqa: N802 self._propagate_shape_and_type(node, 0, 0) if len(node.output) > 1: diff --git a/onnxruntime/python/tools/transformers/dynamo_onnx_helper.py b/onnxruntime/python/tools/transformers/dynamo_onnx_helper.py new file mode 100644 index 0000000000..bca5ace916 --- /dev/null +++ b/onnxruntime/python/tools/transformers/dynamo_onnx_helper.py @@ -0,0 +1,92 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +import logging + +import onnx + + +class DynamoOnnxHelper: + """ + Helper class for processing ONNX models exported by torch Dynamo. + """ + + def __init__(self, model: onnx.ModelProto): + self.model = model + + def update_edges(self, edge_mapping: dict) -> None: + """ + Updates the edges in the model according to the given mapping. + """ + for node in self.model.graph.node: + for i in range(len(node.input)): + if node.input[i] in edge_mapping: + node.input[i] = edge_mapping[node.input[i]] + for i in range(len(node.output)): + if node.output[i] in edge_mapping: + node.output[i] = edge_mapping[node.output[i]] + + for graph_input in self.model.graph.input: + if graph_input.name in edge_mapping: + graph_input.name = edge_mapping[graph_input.name] + for graph_output in self.model.graph.output: + if graph_output.name in edge_mapping: + graph_output.name = edge_mapping[graph_output.name] + + def unroll_function(self, func_name: str) -> None: + """ + Unrolls the function with the given name in the model. + """ + logging.info(f"Unrolling function {func_name}...") + nodes_to_remove = [] + nodes_to_add = [] + edges_to_remove = [] + edges_to_add = [] + for node in self.model.graph.node: + if node.op_type == func_name: + nodes_to_remove.append(node) + edges_to_remove.extend(list(node.input) + list(node.output)) + + func_to_remove = None + for f in self.model.functions: + if f.name == func_name: + nodes_to_add.extend(list(f.node)) + edges_to_add.extend(list(f.input) + list(f.output)) + func_to_remove = f + + assert len(edges_to_remove) == len(edges_to_add) + + for node in nodes_to_remove: + self.model.graph.node.remove(node) + for node in nodes_to_add: + self.model.graph.node.append(node) + if func_to_remove is not None: + self.model.functions.remove(func_to_remove) + + edge_mapping = {} + for i in range(len(edges_to_remove)): + k = edges_to_remove[i] + v = edges_to_add[i] + if k != v: + edge_mapping[k] = v + + return self.update_edges(edge_mapping) + + def remove_dropout_layer(self) -> None: + """ + Removes the dropout layer in the model. + """ + logging.info("Removing dropout layer...") + edge_mapping = {} + nodes_to_remove = [] + for node in self.model.graph.node: + if node.op_type.find("Dropout") != -1: + assert len(node.input) == 1 + assert len(node.output) == 1 + edge_mapping[node.output[0]] = node.input[0] + nodes_to_remove.append(node) + for node in nodes_to_remove: + self.model.graph.node.remove(node) + + self.update_edges(edge_mapping) diff --git a/onnxruntime/python/tools/transformers/float16.py b/onnxruntime/python/tools/transformers/float16.py index f680a15fc2..48c79b1d5f 100644 --- a/onnxruntime/python/tools/transformers/float16.py +++ b/onnxruntime/python/tools/transformers/float16.py @@ -174,6 +174,7 @@ def convert_float_to_float16( node_block_list=None, force_fp16_initializers=False, force_fp16_inputs=None, + use_bfloat16_as_blocked_nodes_dtype=False, ): """Convert tensor float type in the input ONNX model to tensor float16. @@ -436,6 +437,7 @@ def convert_float_to_float16( node.input[i] = output_name break + accuracy_type = TensorProto.BFLOAT16 if use_bfloat16_as_blocked_nodes_dtype else TensorProto.FLOAT # process the nodes in block list that doesn't support tensor(float16) for node in node_list: # if input's name is in the value_info_list meaning input is tensor(float16) type, @@ -450,10 +452,10 @@ def convert_float_to_float16( new_value_info.CopyFrom(value_info) output_name = node.name + "_input_cast_" + str(i) new_value_info.name = output_name - new_value_info.type.tensor_type.elem_type = TensorProto.FLOAT + new_value_info.type.tensor_type.elem_type = accuracy_type # add Cast node (from tensor(float16) to tensor(float) before current node node_name = node.name + "_input_cast" + str(i) - new_node = [helper.make_node("Cast", [input_name], [output_name], to=1, name=node_name)] + new_node = [helper.make_node("Cast", [input_name], [output_name], to=accuracy_type, name=node_name)] model.graph.node.extend(new_node) # change current node's input name node.input[i] = output_name @@ -469,7 +471,7 @@ def convert_float_to_float16( new_value_info.CopyFrom(value_info) input_name = node.name + "_output_cast_" + str(i) new_value_info.name = input_name - new_value_info.type.tensor_type.elem_type = TensorProto.FLOAT + new_value_info.type.tensor_type.elem_type = accuracy_type # add Cast node (from tensor(float) to tensor(float16) after current node node_name = node.name + "_output_cast" + str(i) new_node = [helper.make_node("Cast", [input_name], [output], to=10, name=node_name)] diff --git a/onnxruntime/python/tools/transformers/fusion_options.py b/onnxruntime/python/tools/transformers/fusion_options.py index b9b92d2fe8..c65464a306 100644 --- a/onnxruntime/python/tools/transformers/fusion_options.py +++ b/onnxruntime/python/tools/transformers/fusion_options.py @@ -3,6 +3,7 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- from argparse import ArgumentParser +from enum import Enum class AttentionMaskFormat: @@ -19,6 +20,15 @@ class AttentionMaskFormat: NoMask = 3 +class AttentionOpType(Enum): + Attention = "Attention" + MultiHeadAttention = "MultiHeadAttention" + GroupQueryAttention = "GroupQueryAttention" + + def __str__(self): + return self.value + + class FusionOptions: """Options of fusion in graph optimization""" @@ -57,6 +67,8 @@ class FusionOptions: elif model_type == "vit": self.attention_mask_format = AttentionMaskFormat.NoMask + self.attention_op_type = None + # options for stable diffusion if model_type in ["unet", "vae", "clip"]: self.enable_nhwc_conv = True @@ -76,6 +88,9 @@ class FusionOptions: def disable_attention_mask(self): self.attention_mask_format = AttentionMaskFormat.NoMask + def set_attention_op_type(self, attn_op_type: AttentionOpType): + self.attention_op_type = attn_op_type + @staticmethod def parse(args): options = FusionOptions(args.model_type) diff --git a/onnxruntime/python/tools/transformers/models/phi2/README.md b/onnxruntime/python/tools/transformers/models/phi2/README.md new file mode 100644 index 0000000000..526fdc3dd7 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/phi2/README.md @@ -0,0 +1,119 @@ +# Phi2 Optimizations +## Prerequisites +A Linux machine for [TorchDynamo-based ONNX Exporter](https://pytorch.org/docs/stable/onnx.html#torchdynamo-based-onnx-exporter)\ +Install onnx, onnxscript and transformers by running +```bash +pip install -r requirements.txt +``` +To export ONNX, PyTorch version 2.2.0 or higher is required. The [official website](https://pytorch.org/) offers packages compatible with CUDA 11.8 and 12.1. Please select the appropriate version according to your needs. +\ +\ +**There are two options to run the conversion script:**\ +_From source:_ +```bash +pip install onnxruntime-gpu==1.17.0 # or onnxruntime==1.17.0 if using cpu +git clone git@github.com:microsoft/onnxruntime.git +cd onnxruntime/onnxruntime/python/tools/transformers +python -m models.phi2.convert_to_onnx -h +``` +_From wheel:_ \ +Install [ORT nightly package](https://onnxruntime.ai/docs/install/#inference-install-table-for-all-languages) +```bash +python -m onnxruntime.transformers.models.phi2.convert_to_onnx -h +``` + +## Export optimized phi2 onnx model for different scenarios +**Export FP32 ONNX model for Nvidia GPUs** \ +_From source:_ +``` +python -m models.phi2.convert_to_onnx --fp32_gpu +``` +_From wheel:_ +``` +python -m onnxruntime.transformers.models.phi2.convert_to_onnx --fp32_gpu +``` +\ +**Export FP16 ONNX model for Nvidia GPUs** \ +_From source:_ +``` +python -m models.phi2.convert_to_onnx --fp16_gpu +``` +_From wheel:_ +``` +python -m onnxruntime.transformers.models.phi2.convert_to_onnx --fp16_gpu +``` +\ +**Export INT4 ONNX model for Nvidia GPUs** \ +_From source:_ +``` +python -m models.phi2.convert_to_onnx --int4_gpu +``` +_From wheel:_ +``` +python -m onnxruntime.transformers.models.phi2.convert_to_onnx --int4_gpu +``` +\ +**Export FP16 ONNX model for Nvidia GPUs with CUDA architecture SM=80~89** \ +_From source:_ +``` +python -m models.phi2.convert_to_onnx --fp16_gpu_sm8x +``` +_From wheel:_ +``` +python -m onnxruntime.transformers.models.phi2.convert_to_onnx --fp16_gpu_sm8x +``` +\ +**Export INT4 ONNX model for Nvidia GPUs with CUDA architecture SM=80~89** \ +_From source:_ +``` +python -m models.phi2.convert_to_onnx --int4_gpu_sm8x +``` +_From wheel:_ +``` +python -m onnxruntime.transformers.models.phi2.convert_to_onnx --int4_gpu_sm8x +``` +\ +**Export FP32 ONNX model for CPU** \ +_From source:_ +``` +python -m models.phi2.convert_to_onnx --fp32_cpu +``` +_From wheel:_ +``` +python -m onnxruntime.transformers.models.phi2.convert_to_onnx --fp32_cpu +``` +\ +**Export INT4 ONNX model for CPU** \ +_From source:_ +``` +python -m models.phi2.convert_to_onnx --int4_cpu +``` +_From wheel:_ +``` +python -m onnxruntime.transformers.models.phi2.convert_to_onnx --int4_cpu +``` +\ +**Export all at once** \ +_From source:_ +``` +python -m models.phi2.convert_to_onnx --fp32_cpu --int4_cpu --fp32_gpu --fp16_gpu --int4_gpu --fp16_gpu_sm8x --int4_gpu_sm8x +``` +_From wheel:_ +``` +python -m onnxruntime.transformers.models.phi2.convert_to_onnx --fp32_cpu --int4_cpu --fp32_gpu --fp16_gpu --int4_gpu --fp16_gpu_sm8x --int4_gpu_sm8x +``` +## Run example with ORT +**(e.g) Export FP16 and INT4 ONNX models for Nvidia GPUs with CUDA architecture SM=80~89 and run examples.** \ +_From source:_ +``` +python -m models.phi2.convert_to_onnx --fp16_gpu_sm8x --int4_gpu_sm8x --run_example +``` +_From wheel:_ +``` +python -m onnxruntime.transformers.models.phi2.convert_to_onnx --fp16_gpu_sm8x --int4_gpu_sm8x --run_example +``` +The inference example currently supports all models running on CUDA. + +## Limitations +- TorchDynamo-based ONNX Exporter only supports Linux. +- The program may not run as expected if the machine has limited memory. e.g Dynamo export may use ~11.6GB; Optimization may use ~4.5GB for each. diff --git a/onnxruntime/python/tools/transformers/models/phi2/__init__.py b/onnxruntime/python/tools/transformers/models/phi2/__init__.py new file mode 100644 index 0000000000..e80f36a391 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/phi2/__init__.py @@ -0,0 +1,12 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +import os +import sys + +sys.path.append(os.path.dirname(__file__)) + +transformers_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", "..")) +if transformers_dir not in sys.path: + sys.path.append(transformers_dir) diff --git a/onnxruntime/python/tools/transformers/models/phi2/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/phi2/convert_to_onnx.py new file mode 100644 index 0000000000..ac3ca40e41 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/phi2/convert_to_onnx.py @@ -0,0 +1,458 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +from __future__ import annotations + +import argparse +import logging +import os +from pathlib import Path + +import onnx +import torch +from benchmark_helper import Precision +from fusion_options import AttentionOpType +from transformers import AutoConfig, AutoModelForCausalLM + +from onnxruntime.quantization.matmul_4bits_quantizer import MatMul4BitsQuantizer + + +class ConvertPhi2ToONNX: + def __init__( + self, + device: torch.device, + model_class: str = "microsoft/phi-2", + cache_dir: str = "./cache", + ): + self.model_class = model_class + self.device = device + self.cache_dir = cache_dir + self.phi_config = AutoConfig.from_pretrained(self.model_class, trust_remote_code=True, cache_dir=self.cache_dir) + self.phi_model = None + self.batch_size = 2 + self.sequence_length = 8 + self.attn_op_type = None + self.precision = None + self.block_size = 16 + self.accuracy_level = None + + def set_quantization_params(self, block_size: int, accuracy_level: int | None): + self.block_size = block_size + self.accuracy_level = accuracy_level + + def init_attn_type_and_precision(self, attn_op_type: AttentionOpType, precision: Precision): + self.attn_op_type = attn_op_type + self.precision = precision + + def erase_onnx_model(self, onnx_path: str) -> None: + assert onnx_path.endswith(".onnx") + if not os.path.exists(onnx_path): + return + + model = onnx.load_model(onnx_path, load_external_data=False) + onnx_data_path = None + for initializer in model.graph.initializer: + if initializer.data_location == 1 and initializer.external_data[0].key == "location": + onnx_data_path = "./" + initializer.external_data[0].value + break + logging.info(f"Erasing {onnx_path}...") + os.remove(onnx_path) + if onnx_data_path is not None: + onnx_data_path = os.path.join(Path(onnx_path).parent, onnx_data_path) + logging.info(f"Erasing {onnx_data_path}...") + os.remove(onnx_data_path) + + def get_phi2_torch_model(self): + logging.info("Loading phi2 torch model...") + if self.phi_model is not None: + return + self.phi_model = AutoModelForCausalLM.from_pretrained( + self.model_class, trust_remote_code=True, cache_dir=self.cache_dir + ) + self.phi_model.eval() + self.phi_model.to(self.device) + + def get_phi2_torch_inputs(self, batch_size: int, sequence_length: int): + input_ids = torch.randint( + low=0, + high=self.phi_config.vocab_size, + size=(batch_size, sequence_length), + dtype=torch.int64, + device=self.device, + ) + self.get_phi2_torch_model() + torch_inputs = self.phi_model.prepare_inputs_for_generation( + input_ids, past_key_values=self.phi_model(input_ids, use_cache=True)["past_key_values"] + ) + return torch_inputs["input_ids"], torch_inputs["attention_mask"], torch_inputs["past_key_values"] + + def dynamo_export(self, onnx_path: str): + input_ids, attention_mask, past_key_values = self.get_phi2_torch_inputs(self.batch_size, self.sequence_length) + self.phi_model(input_ids, attention_mask=attention_mask, past_key_values=past_key_values) + + from torch._dynamo import config + + config.capture_scalar_outputs = True + + logging.info("Exporting Phi2 torch model to ONNX...") + torch.onnx.dynamo_export( + self.phi_model, + input_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + export_options=torch.onnx.ExportOptions(dynamic_shapes=True), + ).save(onnx_path) + onnx.checker.check_model(onnx_path) + onnx.shape_inference.infer_shapes_path(onnx_path) + + def optimize_phi2_onnx(self, onnx_path: str, onnx_path_opt: str): + from fusion_options import FusionOptions + from optimizer import optimize_model + + optimization_options = FusionOptions("phi") + optimization_options.set_attention_op_type(self.attn_op_type) + optimizer = optimize_model( + onnx_path, + model_type="phi", + num_heads=self.phi_config.num_attention_heads, + hidden_size=self.phi_config.hidden_size, + opt_level=0, + optimization_options=optimization_options, + only_onnxruntime=False, + ) + + fused_op_count = optimizer.get_fused_operator_statistics() + if optimizer.is_fully_optimized(fused_op_count): + logging.info("Model is fully optimized.") + else: + logging.info("Model is not fully optimized.") + + if self.precision == Precision.FLOAT32: + optimizer.save_model_to_file(onnx_path_opt, use_external_data_format=True) + return + + if ( + self.precision == Precision.FLOAT16 or self.precision == Precision.INT4 + ) and self.attn_op_type != AttentionOpType.MultiHeadAttention: + # We keep last three layers of Attention as float32 or bfloat16 to avoid overflow. + node_block_list = [ + "GroupQueryAttention_29", + "GroupQueryAttention_30", + "GroupQueryAttention_31", + "Attention_29", + "Attention_30", + "Attention_31", + ] + logging.info("Converting onnx model to float16/bfloat16...") + optimizer.convert_float_to_float16( + keep_io_types=False, + node_block_list=node_block_list, + use_symbolic_shape_infer=True, + use_bfloat16_as_blocked_nodes_dtype=self.attn_op_type == AttentionOpType.GroupQueryAttention, + ) + logging.info("Converting onnx model to float16/bfloat16 done.") + + if self.precision == Precision.FLOAT16: + optimizer.save_model_to_file(onnx_path_opt, use_external_data_format=True) + return + else: + assert self.precision == Precision.INT4 + quant = MatMul4BitsQuantizer( + model=optimizer.model, + block_size=self.block_size, + is_symmetric=True, + accuracy_level=self.accuracy_level, + ) + quant.process() + quant.model.save_model_to_file(onnx_path_opt, use_external_data_format=True) + + +def parse_arguments(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--fp32_cpu", + required=False, + action="store_true", + help="Generate fp32 ONNX model for CPU", + ) + + parser.add_argument( + "--int4_cpu", + required=False, + action="store_true", + help="Generate int4 ONNX model for CPU", + ) + + parser.add_argument( + "--fp32_gpu", + required=False, + action="store_true", + help="Generate fp32 ONNX model for Nvidia GPUs", + ) + + parser.add_argument( + "--fp16_gpu", + required=False, + action="store_true", + help="Generate fp16 ONNX model for Nvidia GPUs", + ) + + parser.add_argument( + "--int4_gpu", + required=False, + action="store_true", + help="Generate int4 ONNX model for Nvidia GPUs", + ) + + parser.add_argument( + "--fp16_gpu_sm8x", + required=False, + action="store_true", + help="Generate fp16 ONNX model for Nvidia GPUs with CUDA architecture SM=80~89", + ) + + parser.add_argument( + "--int4_gpu_sm8x", + required=False, + action="store_true", + help="Generate int4 ONNX model for Nvidia GPUs with CUDA architecture SM=80~89", + ) + + parser.add_argument( + "--overwrite", + required=False, + action="store_true", + help="Overwrite existing ONNX models", + ) + + parser.add_argument( + "--cache_dir", + required=False, + type=str, + default="./cache", + help="The cache directory for the pytorch model", + ) + + parser.add_argument( + "--device_id", + required=False, + type=int, + default=0, + help="The device id for the pytorch model", + ) + + parser.add_argument( + "--run_example", + required=False, + action="store_true", + help="Run ORT inference example", + ) + + parser.add_argument( + "--skip_export", + required=False, + action="store_true", + help="Skip exporting ONNX model", + ) + + parser.add_argument( + "--output_dir", + type=str, + help="The output directory for the ONNX models", + default="phi2_onnx_models", + ) + + parser.add_argument( + "--block_size", + required=False, + default=16, + type=int, + help="Block size to quantize with. See https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py for details.", + ) + + parser.add_argument( + "--int4_accuracy_level", + required=False, + type=int, + help="Accuracy level of the 4-bit quantized MatMul computation. " + "Refer to the MatMulNBits contrib op's 'accuracy_level' attribute for details " + "(https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#commicrosoftmatmulnbits).", + ) + + args = parser.parse_args() + return args + + +def main(): + args = parse_arguments() + + device = torch.device("cuda", args.device_id) if torch.cuda.is_available() else torch.device("cpu") + + converter = ConvertPhi2ToONNX(device, cache_dir=args.cache_dir) + converter.set_quantization_params(args.block_size, args.int4_accuracy_level) + + output_dir = args.output_dir + + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + original_onnx_path = os.path.join(output_dir, "phi2_original.onnx") + + if not args.skip_export: + if not os.path.exists(original_onnx_path) or args.overwrite: + converter.dynamo_export(original_onnx_path) + + model_type_to_args = { + "fp32_cpu": ( + AttentionOpType.MultiHeadAttention, + Precision.FLOAT32, + os.path.join(output_dir, "phi2_decoder_fp32_cpu.onnx"), + ), + "int4_cpu": ( + AttentionOpType.MultiHeadAttention, + Precision.INT4, + os.path.join(output_dir, "phi2_decoder_int4_cpu.onnx"), + ), + "fp32_gpu": ( + AttentionOpType.Attention, + Precision.FLOAT32, + os.path.join(output_dir, "phi2_decoder_fp32_gpu.onnx"), + ), + "fp16_gpu": ( + AttentionOpType.Attention, + Precision.FLOAT16, + os.path.join(output_dir, "phi2_decoder_fp16_gpu.onnx"), + ), + "int4_gpu": (AttentionOpType.Attention, Precision.INT4, os.path.join(output_dir, "phi2_decoder_int4_gpu.onnx")), + "fp16_gpu_sm8x": ( + AttentionOpType.GroupQueryAttention, + Precision.FLOAT16, + os.path.join(output_dir, "phi2_decoder_fp16_gpu_sm8x.onnx"), + ), + "int4_gpu_sm8x": ( + AttentionOpType.GroupQueryAttention, + Precision.INT4, + os.path.join(output_dir, "phi2_decoder_int4_gpu_sm8x.onnx"), + ), + } + + if not args.skip_export: + from multiprocessing import Process + + def run_optimize_phi2_onnx( + converter: ConvertPhi2ToONNX, + original_onnx_path: str, + attention_type: AttentionOpType, + precision: Precision, + optimized_onnx_path: str, + ): + converter.init_attn_type_and_precision(attention_type, precision) + converter.optimize_phi2_onnx(original_onnx_path, optimized_onnx_path) + + processes = [] + if args.fp32_cpu: + processes.append( + Process( + target=run_optimize_phi2_onnx, args=(converter, original_onnx_path, *model_type_to_args["fp32_cpu"]) + ) + ) + + if args.int4_cpu: + processes.append( + Process( + target=run_optimize_phi2_onnx, args=(converter, original_onnx_path, *model_type_to_args["int4_cpu"]) + ) + ) + + if args.fp32_gpu: + processes.append( + Process( + target=run_optimize_phi2_onnx, args=(converter, original_onnx_path, *model_type_to_args["fp32_gpu"]) + ) + ) + + if args.fp16_gpu: + processes.append( + Process( + target=run_optimize_phi2_onnx, args=(converter, original_onnx_path, *model_type_to_args["fp16_gpu"]) + ) + ) + + if args.int4_gpu: + processes.append( + Process( + target=run_optimize_phi2_onnx, args=(converter, original_onnx_path, *model_type_to_args["int4_gpu"]) + ) + ) + + if args.fp16_gpu_sm8x: + processes.append( + Process( + target=run_optimize_phi2_onnx, + args=(converter, original_onnx_path, *model_type_to_args["fp16_gpu_sm8x"]), + ) + ) + + if args.int4_gpu_sm8x: + processes.append( + Process( + target=run_optimize_phi2_onnx, + args=(converter, original_onnx_path, *model_type_to_args["int4_gpu_sm8x"]), + ) + ) + + [p.start() for p in processes] + [p.join() for p in processes] + + if args.run_example: + from inference_example import run_phi2 + + if args.fp16_gpu_sm8x: + logging.info("Running fp16_gpu_sm8x example...") + run_phi2( + onnx_model_path=model_type_to_args["fp16_gpu_sm8x"][2], + use_buffer_share=True, + device_id=args.device_id, + use_step=True, + ) + if args.int4_gpu_sm8x: + logging.info("Running int4_gpu_sm8x example...") + run_phi2( + onnx_model_path=model_type_to_args["int4_gpu_sm8x"][2], + use_buffer_share=True, + device_id=args.device_id, + use_step=True, + ) + if args.fp32_gpu: + logging.info("Running fp32_gpu example...") + run_phi2( + onnx_model_path=model_type_to_args["fp32_gpu"][2], + use_buffer_share=False, + device_id=args.device_id, + packed_kv=True, + use_fp16=False, + ) + if args.fp16_gpu: + logging.info("Running fp16_gpu example...") + run_phi2( + onnx_model_path=model_type_to_args["fp16_gpu"][2], + use_buffer_share=False, + device_id=args.device_id, + packed_kv=True, + ) + if args.int4_gpu: + logging.info("Running int4_gpu example...") + run_phi2( + onnx_model_path=model_type_to_args["int4_gpu"][2], + use_buffer_share=False, + device_id=args.device_id, + packed_kv=True, + ) + if args.fp32_cpu or args.int4_cpu: + raise NotImplementedError("CPU inference example is not implemented yet.") + + +if __name__ == "__main__": + main() diff --git a/onnxruntime/python/tools/transformers/models/phi2/inference_example.py b/onnxruntime/python/tools/transformers/models/phi2/inference_example.py new file mode 100644 index 0000000000..28828ffb85 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/phi2/inference_example.py @@ -0,0 +1,215 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +import numpy as np +import torch +from transformers import AutoTokenizer + +import onnxruntime as ort + +pt_to_np = { + "torch.int32": np.int32, + "torch.int64": np.int64, + "torch.float32": np.float32, + "torch.float16": np.float16, +} + + +class ORTGenerator: + def __init__(self, decoder_path): + self.onnx_decoder_path = decoder_path + self.num_heads = 32 + self.head_size = 80 + self.num_layers = 32 + self.max_sequence_length = 2048 + + def get_initial_inputs_and_outputs(self, encodings_dict): + self.torch_dtype = torch.float16 if self.use_fp16 else torch.float32 + + input_ids = torch.tensor(encodings_dict["input_ids"], device=self.device, dtype=torch.int32) + attention_mask = torch.tensor(encodings_dict["attention_mask"], device=self.device, dtype=torch.int32) + step = torch.tensor([0], device=self.device, dtype=torch.int64) + + inputs = { + "input_ids": input_ids.contiguous(), + "attention_mask": attention_mask.contiguous(), + } + + if self.use_step: + inputs["step"] = step.contiguous() + + batch_size, sequence_length = input_ids.shape + + past_seq_length = self.max_sequence_length if self.use_buffer_share else 0 + past_shape = ( + (2, batch_size, self.num_heads, past_seq_length, self.head_size) + if self.packed_kv + else (batch_size, self.num_heads, past_seq_length, self.head_size) + ) + for i in range(self.num_layers): + past = torch.zeros(past_shape, device=self.device, dtype=self.torch_dtype) + inputs.update( + {f"past_key_{i}": past.contiguous(), f"past_value_{i}": past.clone().contiguous()} + ) if not self.packed_kv else inputs.update({f"past_{i}": past.contiguous()}) + + logits = torch.zeros(batch_size, sequence_length, 51200, device=self.device, dtype=self.torch_dtype) + outputs = {"logits": logits.contiguous()} + + if not self.use_buffer_share: + present_shape = ( + (2, batch_size, self.num_heads, sequence_length, self.head_size) + if self.packed_kv + else (batch_size, self.num_heads, sequence_length, self.head_size) + ) + for i in range(self.num_layers): + present = torch.zeros(present_shape, device=self.device, dtype=self.torch_dtype) + outputs.update( + {f"present_key_{i}": present.contiguous(), f"present_value_{i}": present.contiguous()} + ) if not self.packed_kv else outputs.update({f"present_{i}": present.contiguous()}) + + return inputs, outputs + + def apply_io_binding(self, model: ort.InferenceSession, inputs: dict, outputs: dict): + io_binding = model.io_binding() + device = None + + for k, v in inputs.items(): + io_binding.bind_input( + name=k, + device_type=v.device.type, + device_id=0 if v.device.type == "cpu" else v.device.index, + element_type=pt_to_np[repr(v.dtype)], + shape=tuple(v.shape), + buffer_ptr=v.data_ptr(), + ) + device = v.device + + for output in model.get_outputs(): + name = output.name + if self.use_buffer_share and "present" in name: + v = inputs[name.replace("present", "past")] + io_binding.bind_output( + name=name, + device_type=v.device.type, + device_id=v.device.index, + element_type=(np.float16 if self.use_fp16 else np.float32), + shape=tuple(v.shape), + buffer_ptr=v.data_ptr(), + ) + else: + v = outputs[name] + io_binding.bind_output( + name=name, + device_type=device.type, + device_id=0 if device.type == "cpu" else device.index, + element_type=(np.float16 if self.use_fp16 else np.float32), + shape=tuple(v.shape), + buffer_ptr=v.data_ptr(), + ) + + return io_binding + + def create_session(self, device_id, use_fp16=True, use_buffer_share=True, packed_kv=False, use_step=False): + sess_options = ort.SessionOptions() + ep = ("CUDAExecutionProvider", {"device_id": device_id}) if device_id >= 0 else "CPUExecutionProvider" + self.sess = ort.InferenceSession(self.onnx_decoder_path, sess_options=sess_options, providers=[ep]) + + self.device = torch.device("cuda", device_id) if torch.cuda.is_available() else torch.device("cpu") + self.use_fp16 = use_fp16 + self.use_buffer_share = use_buffer_share + self.packed_kv = packed_kv + self.use_step = use_step + + self.tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True) + self.tokenizer.pad_token = "[PAD]" + + def generate(self, prompt, max_length): + encodings_dict = self.tokenizer.batch_encode_plus(prompt, padding=True) + + inputs, outputs = self.get_initial_inputs_and_outputs(encodings_dict) + + all_token_ids = inputs["input_ids"].clone() + batch_size, sequence_length = all_token_ids.shape + + current_length = sequence_length + has_eos = torch.zeros(batch_size, device=self.device, dtype=torch.bool) + + while current_length < max_length: + io_binding = self.apply_io_binding(self.sess, inputs, outputs) + + io_binding.synchronize_inputs() + self.sess.run_with_iobinding(io_binding) + io_binding.synchronize_outputs() + + # Sample with argmax (greedy search) + next_token_logits = outputs["logits"][:, -1, :] + next_tokens = torch.argmax(next_token_logits, dim=-1) + + # Check if we previously reached EOS token id or if generated token id is EOS token id + has_eos = has_eos | next_tokens == self.tokenizer.eos_token_id + + # Determine which new tokens to add to list of all token ids + # Add EOS token ids for batch entries that ended early (ragged batching scenario where some batch entries ended early and some haven't) + tokens_to_add = next_tokens.masked_fill(has_eos, self.tokenizer.eos_token_id).reshape([batch_size, 1]) + all_token_ids = torch.cat([all_token_ids, tokens_to_add], dim=-1) + + # Return early if all batch entries have reached EOS token id + if torch.all(has_eos): + break + + # Update inputs for next inference run + current_length += 1 + inputs["input_ids"] = tokens_to_add.to(torch.int32) + if self.use_step: + inputs["step"] = torch.tensor([current_length - 1], device=self.device, dtype=torch.int64) + inputs["attention_mask"] = torch.cat([inputs["attention_mask"], (~has_eos).reshape(batch_size, 1)], 1).to( + torch.int32 + ) + + # Set logits to zeros for next inference run and re-use memory buffer + if outputs["logits"].shape[1] != 1: + outputs["logits"] = outputs["logits"][:, :1, :].contiguous() + outputs["logits"].zero_() + + if not self.use_buffer_share: + for i in range(self.num_layers): + if not self.packed_kv: + inputs[f"past_key_{i}"] = outputs[f"present_key_{i}"] + inputs[f"past_value_{i}"] = outputs[f"present_value_{i}"] + else: + inputs[f"past_{i}"] = outputs[f"present_{i}"] + + new_sequence_length = inputs["attention_mask"].shape[1] + present_shape = ( + (2, batch_size, self.num_heads, new_sequence_length, self.head_size) + if self.packed_kv + else (batch_size, self.num_heads, new_sequence_length, self.head_size) + ) + for i in range(self.num_layers): + present = torch.zeros(present_shape, device=self.device, dtype=self.torch_dtype) + outputs.update( + {f"present_key_{i}": present.contiguous(), f"present_value_{i}": present.clone().contiguous()} + ) if not self.packed_kv else outputs.update({f"present_{i}": present.contiguous()}) + + texts = self.tokenizer.batch_decode(all_token_ids, skip_special_tokens=True) + return texts + + +def run_phi2(onnx_model_path, use_buffer_share, device_id, packed_kv=False, use_fp16=True, use_step=False): + prompt = [ + '''```python + def print_prime(n): + """ + Print all primes between 1 and n + """''' + ] + + generator = ORTGenerator(onnx_model_path) + generator.create_session(device_id, use_fp16, use_buffer_share, packed_kv, use_step) + texts = generator.generate(prompt, max_length=200) + + for i in range(len(texts)): + print("Prompt: ", prompt[i]) + print("Texts: ", texts[i]) diff --git a/onnxruntime/python/tools/transformers/models/phi2/requirements.txt b/onnxruntime/python/tools/transformers/models/phi2/requirements.txt new file mode 100644 index 0000000000..af6f441c14 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/phi2/requirements.txt @@ -0,0 +1,3 @@ +onnx>=1.15.0 +transformers>=4.36.2 +onnxscript>=0.1.0.dev20240126 diff --git a/onnxruntime/python/tools/transformers/onnx_model.py b/onnxruntime/python/tools/transformers/onnx_model.py index 9d1066b6e3..0e20b1f871 100644 --- a/onnxruntime/python/tools/transformers/onnx_model.py +++ b/onnxruntime/python/tools/transformers/onnx_model.py @@ -82,6 +82,10 @@ class OnnxModel: output_name_to_node[output_name] = node return output_name_to_node + def functions(self): + all_functions = [list(self.model.functions)] + return all_functions + def nodes(self): all_nodes = [] for graph in self.graphs(): @@ -733,6 +737,7 @@ class OnnxModel: "node_block_list", "force_fp16_initializers", "force_fp16_inputs", + "use_bfloat16_as_blocked_nodes_dtype", ] if key in kwargs } diff --git a/onnxruntime/python/tools/transformers/onnx_model_phi.py b/onnxruntime/python/tools/transformers/onnx_model_phi.py new file mode 100644 index 0000000000..df8830b0d0 --- /dev/null +++ b/onnxruntime/python/tools/transformers/onnx_model_phi.py @@ -0,0 +1,839 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +from logging import getLogger +from typing import List, Optional + +import numpy as np +from dynamo_onnx_helper import DynamoOnnxHelper +from fusion_base import Fusion +from fusion_options import AttentionOpType, FusionOptions +from fusion_skiplayernorm import FusionBiasSkipLayerNormalization, FusionSkipLayerNormalization +from fusion_utils import NumpyHelper +from onnx import ModelProto, NodeProto, TensorProto, helper, numpy_helper +from onnx_model import OnnxModel + +logger = getLogger(__name__) + + +class ProcessGemmWFunc: + def __call__(self, x): + return np.transpose(x, (1, 0)) + + +class ProcessMatMulQFunc: + def __call__(self, x): + return np.transpose(np.split(x, 3, 0)[0], (1, 0)) + + +class ProcessMatMulKFunc: + def __call__(self, x): + return np.transpose(np.split(x, 3, 0)[1], (1, 0)) + + +class ProcessMatMulVFunc: + def __call__(self, x): + return np.transpose(np.split(x, 3, 0)[2], (1, 0)) + + +class ProcessBiasQFunc: + def __call__(self, x): + x = np.split(x, 3, -1)[0] + return x + + +class ProcessBiasKFunc: + def __call__(self, x): + x = np.split(x, 3, -1)[1] + return x + + +class ProcessBiasVFunc: + def __call__(self, x): + x = np.split(x, 3, -1)[2] + return x + + +class ProcessRotCacheFunc: + def __call__(self, x): + # half rotary embedding + assert len(x.shape) == 2 + if x.shape[1] == 32: + return x[:, 0:16] + return x + + +# TODO: move to a seperate file +class Fission(Fusion): + def __init__( + self, + model: OnnxModel, + nodes_to_find: List[str], + ): + super().__init__(model, "DONOTUSE", nodes_to_find) + + def set_attention_op_type(self, attn_op_type: AttentionOpType): + self.attn_op_type = attn_op_type + + def get_uname(self, layer_id, name): + return name + "_" + str(layer_id) + + def get_io_by_name(self, node, name): + for input in node.input: + if input == name or input.endswith(name) or input.startswith(name): + return input + for output in node.output: + if output == name or output.endswith(name) or output.startswith(name): + return output + raise Exception(f"input {name} not found in node {node.name}") + + def process_initializer(self, initializer_name, functor, custom_name=None): + i = self.model.get_initializer(initializer_name) + i_np_array = NumpyHelper.to_array(i) + processed_i_np_array = functor(i_np_array) + new_tensor = helper.make_tensor( + initializer_name + "_processed" if custom_name is None else custom_name, + data_type=TensorProto.FLOAT, + dims=processed_i_np_array.shape, + vals=processed_i_np_array.flatten().tobytes(), + raw=True, + ) + self.model.add_initializer(new_tensor, self.this_graph_name) + return new_tensor.name + + def add_fp32_value_info(self, name): + new_value_info = self.model.graph().value_info.add() + new_value_info.name = name + new_value_info.type.tensor_type.elem_type = TensorProto.FLOAT + + def add_int64_value_info(self, name): + new_value_info = self.model.graph().value_info.add() + new_value_info.name = name + new_value_info.type.tensor_type.elem_type = TensorProto.INT64 + + def replace_fp32_value_info(self, name, shape): + for value_info in self.model.graph().value_info: + if value_info.name == name: + self.model.graph().value_info.remove(value_info) + break + new_value_info = helper.make_tensor_value_info( + name, + elem_type=TensorProto.FLOAT, + shape=shape, + ) + self.model.graph().value_info.extend([new_value_info]) + + def set_unique_name_and_add_nodes( + self, subgraph_nodes: List[NodeProto], layer_id: int, layer_known_edges_names: List[str] + ): + for new_node in subgraph_nodes: + for i, name in enumerate(new_node.input): + if name == "": + continue + elif name not in layer_known_edges_names: + new_node.input[i] = self.get_uname(layer_id, name) + self.add_fp32_value_info(new_node.input[i]) + for i, name in enumerate(new_node.output): + if name == "": + continue + elif name not in layer_known_edges_names: + new_node.output[i] = self.get_uname(layer_id, name) + self.add_fp32_value_info(new_node.output[i]) + new_node.name = self.get_uname(layer_id, new_node.name) + self.nodes_to_add.append(new_node) + self.node_name_to_graph_name[new_node.name] = self.this_graph_name + + def layernorm(self, inputs: List[str], outputs: List[str], prefix: str = ""): + assert len(inputs) == 3 + assert len(outputs) == 1 + node = helper.make_node( + "LayerNormalization", + inputs=inputs, + outputs=outputs, + name=prefix + "_LayerNormalization", + epsilon=9.999999747378752e-06, + ) + return [node] + + def gemm(self, inputs: List[str], outputs: List[str], prefix: str = ""): + assert len(inputs) == 3 + assert len(outputs) == 1 + matmul = helper.make_node( + "MatMul", + inputs=[inputs[0], inputs[1]], + outputs=[prefix + "matmul_out"], + name=prefix + "MatMul", + ) + add = helper.make_node( + "Add", + inputs=[prefix + "matmul_out", inputs[2]], + outputs=outputs, + name=prefix + "Bias", + ) + return [matmul, add] + + def rotary(self, inputs: List[str], outputs: List[str], prefix: str = "", rot_dim=32, num_heads=32): + assert len(inputs) == 4 + assert len(outputs) == 1 + node = helper.make_node( + "RotaryEmbedding", + inputs=inputs, + outputs=outputs, + name=prefix + "RotaryEmbedding", + domain="com.microsoft", + rotary_embedding_dim=rot_dim, + num_heads=num_heads, + ) + return [node] + + def fastgelu(self, inputs: List[str], outputs: List[str], prefix: str = ""): + assert len(inputs) == 1 + assert len(outputs) == 1 + node = helper.make_node( + "FastGelu", + inputs=inputs, + outputs=outputs, + name=prefix + "FastGelu", + domain="com.microsoft", + ) + return [node] + + def add(self, inputs: List[str], outputs: List[str], prefix: str = ""): + assert len(inputs) == 2 + assert len(outputs) == 1 + node = helper.make_node( + "Add", + inputs=inputs, + outputs=outputs, + name=prefix + "Add", + ) + return [node] + + def mha(self, inputs: List[str], outputs: List[str], prefix: str = "", num_heads=32): + assert len(inputs) == 8 + assert len(outputs) == 3 + node = helper.make_node( + "MultiHeadAttention", + inputs=inputs, + outputs=outputs, + name=prefix + "MultiHeadAttention", + domain="com.microsoft", + num_heads=num_heads, + unidirectional=1, + ) + return [node] + + def gqa(self, inputs: List[str], outputs: List[str], prefix: str = "", num_heads=32): + assert len(inputs) == 7 + assert len(outputs) == 3 + node = helper.make_node( + "GroupQueryAttention", + inputs=inputs, + outputs=outputs, + name=prefix + "GroupQueryAttention", + domain="com.microsoft", + num_heads=num_heads, + kv_num_heads=num_heads, + ) + return [node] + + def attention(self, inputs: List[str], outputs: List[str], prefix: str = "", num_heads=32): + assert len(inputs) == 5 + assert len(outputs) == 2 + node = helper.make_node( + "Attention", + inputs=inputs, + outputs=outputs, + name=prefix + "Attention", + domain="com.microsoft", + num_heads=num_heads, + unidirectional=1, + do_rotary=1, + rotary_embedding_dim=32, + ) + return [node] + + +class Phi2PreProcessor(DynamoOnnxHelper): + def __init__(self, model: ModelProto, num_heads: int, hidden_size: int): + super().__init__(model) + self.num_hidden_layers = 32 + self.num_attention_heads = num_heads + self.hidden_size = hidden_size + + self.phi2_edge_dict = self.get_phi2_edge_dict() + self.func_name = "modeling_phi_PhiModel_model_1" + + def get_phi2_edge_dict(self) -> dict: + edge_dict = {} + edge_dict["lm_head_1"] = "logits" + edge_dict["l_input_ids_"] = "input_ids" + edge_dict["key_states"] = "past_key_0" + edge_dict["value_states"] = "past_value_0" + for i in range(self.num_hidden_layers): + edge_dict[f"key_states_{i}"] = f"past_key_{i}" + edge_dict[f"value_states_{i}"] = f"past_value_{i}" + edge_dict[f"model_layers_{i}_1"] = f"present_key_{i}" + edge_dict[f"model_layers_{i}_1_1"] = f"present_value_{i}" + return edge_dict + + def simplify_phi2_op_type(self): + phi2_transformer_layer_name = "modeling_phi_PhiDecoderLayer_model_layers" + for node in self.model.graph.node: + index = node.op_type.find(phi2_transformer_layer_name) + if index != -1: + node.op_type = node.op_type[index:] + + def process_graph_io(self, attn_op_type: AttentionOpType): + self.use_attn = attn_op_type == AttentionOpType.Attention + graph = self.model.graph + new_inputs = [] + for vi in graph.input: + if "input_ids" in vi.name: + vi_iid = helper.make_tensor_value_info( + vi.name, + elem_type=TensorProto.INT32, + shape=["batch_size", "seq_len"], + ) + vi_pid = helper.make_tensor_value_info( + "step", + elem_type=TensorProto.INT64, + shape=[1], + ) + vi_mask = helper.make_tensor_value_info( + "attention_mask", + elem_type=TensorProto.INT32, + shape=["batch_size", "seq_len"], + ) + new_inputs.extend([vi_iid, vi_pid, vi_mask]) + if not self.use_attn: + if "past_key" in vi.name or "past_value" in vi.name: + vi_cache = helper.make_tensor_value_info( + vi.name, + elem_type=vi.type.tensor_type.elem_type, + shape=[ + "batch_size", + self.num_attention_heads, + "past_seq_len", + self.hidden_size // self.num_attention_heads, + ], + ) + new_inputs.extend([vi_cache]) + else: + if "past_key" in vi.name: + vi_cache = helper.make_tensor_value_info( + vi.name.replace("past_key", "past"), + elem_type=vi.type.tensor_type.elem_type, + shape=[ + 2, + "batch_size", + self.num_attention_heads, + "past_seq_len", + self.hidden_size // self.num_attention_heads, + ], + ) + new_inputs.extend([vi_cache]) + + graph.ClearField("input") + graph.input.extend(new_inputs) + + new_outputs = [] + for i, vi in enumerate(graph.output): + if i == 0: + new_outputs.extend([vi]) + else: + if not self.use_attn: + vi_cache = helper.make_tensor_value_info( + vi.name, + elem_type=vi.type.tensor_type.elem_type, + shape=[ + "batch_size", + self.num_attention_heads, + "total_seq_len", + self.hidden_size // self.num_attention_heads, + ], + ) + new_outputs.extend([vi_cache]) + else: + if "present_key" in vi.name: + vi_cache = helper.make_tensor_value_info( + vi.name.replace("present_key", "present"), + elem_type=vi.type.tensor_type.elem_type, + shape=[ + 2, + "batch_size", + self.num_attention_heads, + "total_seq_len", + self.hidden_size // self.num_attention_heads, + ], + ) + new_outputs.extend([vi_cache]) + + graph.ClearField("output") + graph.output.extend(new_outputs) + + def preprocess_onnx(self, attn_op_type: AttentionOpType): + function_name = None + for func in self.model.functions: + if func.name.endswith(self.func_name): + function_name = func.name + break + assert function_name is not None + self.unroll_function(function_name) + self.update_edges(self.phi2_edge_dict) + self.simplify_phi2_op_type() + self.remove_dropout_layer() + self.process_graph_io(attn_op_type) + + +class FissionTransformerEmbeddingPhi(Fission): + def __init__( + self, + model: OnnxModel, + ): + super().__init__(model, ["torch_nn_modules_sparse_Embedding_model_embed_tokens_1"]) + + def fuse(self, node, input_name_to_nodes, output_name_to_node): + logger.info("Optimizing %s...", node.name) + + assert len(node.input) == 2 + assert len(node.output) == 1 + + input = node.input[0] + output = node.output[0] + + embedding = self.get_io_by_name(node, "embed_tokens.weight") + + layer_known_edges_names = [input, output, embedding] + + subgraph_nodes = [ + helper.make_node( + "Gather", + inputs=[embedding, input], + outputs=[output], + name="Embedding_Gather", + ), + ] + + self.set_unique_name_and_add_nodes(subgraph_nodes, 0, layer_known_edges_names) + self.nodes_to_remove.append(node) + self.prune_graph = True + + +class FissionTransformerLayerNormPhi(Fission): + def __init__( + self, + model: OnnxModel, + ): + super().__init__(model, ["torch_nn_modules_normalization_LayerNorm_model_final_layernorm_1"]) + + def fuse(self, node, input_name_to_nodes, output_name_to_node): + logger.info("Optimizing %s...", node.name) + + assert len(node.input) == 3 + assert len(node.output) == 1 + + input = node.input[0] + output = node.output[0] + + ln_weight = self.get_io_by_name(node, "final_layernorm.weight") + ln_bias = self.get_io_by_name(node, "final_layernorm.bias") + + layer_known_edges_names = [input, output, ln_weight, ln_bias] + + subgraph_nodes = [] + subgraph_nodes.extend(self.layernorm([input, ln_weight, ln_bias], [output], "Final")) + + self.set_unique_name_and_add_nodes(subgraph_nodes, 99, layer_known_edges_names) + + self.replace_fp32_value_info(input, ["batch_size", "seq_len", "hidden_size"]) + self.replace_fp32_value_info(output, ["batch_size", "seq_len", "hidden_size"]) + + self.nodes_to_remove.append(node) + self.prune_graph = True + + +class FissionTransformerCausalLMHeadPhi(Fission): + def __init__( + self, + model: OnnxModel, + ): + super().__init__(model, ["torch_nn_modules_linear_Linear_lm_head_1"]) + + def fuse(self, node, input_name_to_nodes, output_name_to_node): + logger.info("Optimizing %s...", node.name) + + assert len(node.input) == 5 + assert len(node.output) == 1 + + input = node.input[2] + output = node.output[0] + + fc_weight = self.process_initializer(self.get_io_by_name(node, "lm_head.weight"), ProcessGemmWFunc()) + fc_bias = self.get_io_by_name(node, "lm_head.bias") + + layer_known_edges_names = [input, output, fc_weight, fc_bias] + + subgraph_nodes = [] + subgraph_nodes.extend(self.gemm([input, fc_weight, fc_bias], [output], "LMHead_")) + + self.set_unique_name_and_add_nodes(subgraph_nodes, 99, layer_known_edges_names) + + self.replace_fp32_value_info(input, ["batch_size", "seq_len", "hidden_size"]) + self.replace_fp32_value_info(output, ["batch_size", "seq_len", 51200]) + + self.nodes_to_remove.append(node) + self.prune_graph = True + + +class FissionTransformerBlockPhi(Fission): + def __init__( + self, + model: OnnxModel, + num_heads: int, + ): + self.num_heads = num_heads + max_num_layers = 32 + self.func_to_layer_id = {} + nodes_to_find = [] + for layer in range(max_num_layers): + func_name = f"modeling_phi_PhiDecoderLayer_model_layers_{layer}_1" + nodes_to_find.append(func_name) + self.func_to_layer_id[func_name] = layer + + super().__init__(model, nodes_to_find) + + def get_layer_id(self, node): + return self.func_to_layer_id[node.op_type] + + def get_gqa_aux_nodes(self): + gqa_aux_nodes = [ + helper.make_node( + "Cast", + inputs=["attention_mask"], + outputs=["mask_int64"], + name="Cast_gqa_aux_0", + to=TensorProto.INT64, + ), + helper.make_node( + "ReduceSum", + inputs=["mask_int64", "one"], + outputs=["mask_row_sums"], + name="ReduceSum_gqa_aux", + ), + helper.make_node( + "Sub", + inputs=["mask_row_sums", "one"], + outputs=["seqlens_k_int64"], + name="Sub_gqa_aux", + ), + helper.make_node( + "Cast", + inputs=["seqlens_k_int64"], + outputs=["seqlens_k"], + name="Cast_gqa_aux_1", + to=TensorProto.INT32, + ), + helper.make_node("Shape", inputs=["mask_int64"], outputs=["mask_shape"], name="Shape_gqa_aux_0"), + helper.make_node( + "Gather", + inputs=["mask_shape", "one"], + outputs=["total_seq_len_int64"], + name="Gather_gqa_aux_0", + axis=0, + ), + helper.make_node( + "Cast", + inputs=["total_seq_len_int64"], + outputs=["total_sequence_length"], + name="Cast_gqa_aux_2", + to=TensorProto.INT32, + ), + ] + return gqa_aux_nodes + + def pack_qkv_gemm(self, q_w, k_w, v_w, q_b, k_b, v_b, weight_name, bias_name): + q_weight = self.model.get_initializer(q_w) + k_weight = self.model.get_initializer(k_w) + v_weight = self.model.get_initializer(v_w) + qw = np.transpose(NumpyHelper.to_array(q_weight), (1, 0)) + kw = np.transpose(NumpyHelper.to_array(k_weight), (1, 0)) + vw = np.transpose(NumpyHelper.to_array(v_weight), (1, 0)) + qkv_weight = np.stack((qw, kw, vw), axis=1) + + q_bias = self.model.get_initializer(q_b) + k_bias = self.model.get_initializer(k_b) + v_bias = self.model.get_initializer(v_b) + qb = NumpyHelper.to_array(q_bias) + kb = NumpyHelper.to_array(k_bias) + vb = NumpyHelper.to_array(v_bias) + qkv_bias = np.stack((qb, kb, vb), axis=0) + + hidden_size = qkv_weight.shape[0] + + weight = helper.make_tensor( + weight_name, + data_type=TensorProto.FLOAT, + dims=[hidden_size, hidden_size * 3], + vals=qkv_weight.flatten().tobytes(), + raw=True, + ) + self.model.add_initializer(weight, self.this_graph_name) + + bias = helper.make_tensor( + bias_name, + data_type=TensorProto.FLOAT, + dims=[hidden_size * 3], + vals=qkv_bias.flatten().tobytes(), + raw=True, + ) + self.model.add_initializer(bias, self.this_graph_name) + + self.add_fp32_value_info(weight.name) + self.add_fp32_value_info(bias.name) + + return weight_name, bias_name + + def fuse( + self, + node, + input_name_to_nodes, + output_name_to_node, + ): + logger.info("Optimizing %s...", node.name) + + logger.info(f"AttentionOpType: {self.attn_op_type}") + + layer_id = self.get_layer_id(node) + + i_hidden_states = node.input[0] + i_key_cache = self.get_io_by_name(node, "past_key") + i_value_cache = self.get_io_by_name(node, "past_value") + + o_hidden_states = node.output[3] + o_key_cache = self.get_io_by_name(node, "present_key") + o_value_cache = self.get_io_by_name(node, "present_value") + + ln_weight = self.get_io_by_name(node, "input_layernorm.weight") + ln_bias = self.get_io_by_name(node, "input_layernorm.bias") + + attn_q_weight, attn_q_bias, attn_k_weight, attn_k_bias, attn_v_weight, attn_v_bias = ( + None, + None, + None, + None, + None, + None, + ) + attn_qkv_weight, attn_qkv_bias = None, None + cos_cache, sin_cache = None, None + + if self.attn_op_type != AttentionOpType.Attention: + attn_q_weight = self.process_initializer( + self.get_io_by_name(node, "self_attn.q_proj.weight"), ProcessGemmWFunc() + ) + attn_k_weight = self.process_initializer( + self.get_io_by_name(node, "self_attn.k_proj.weight"), ProcessGemmWFunc() + ) + attn_v_weight = self.process_initializer( + self.get_io_by_name(node, "self_attn.v_proj.weight"), ProcessGemmWFunc() + ) + attn_q_bias = self.get_io_by_name(node, "self_attn.q_proj.bias") + attn_k_bias = self.get_io_by_name(node, "self_attn.k_proj.bias") + attn_v_bias = self.get_io_by_name(node, "self_attn.v_proj.bias") + + cos_cache = self.process_initializer( + self.get_io_by_name(node, "rotary_emb.cos_cached"), ProcessRotCacheFunc() + ) + sin_cache = self.process_initializer( + self.get_io_by_name(node, "rotary_emb.sin_cached"), ProcessRotCacheFunc() + ) + else: + attn_qkv_weight, attn_qkv_bias = self.pack_qkv_gemm( + self.get_io_by_name(node, "self_attn.q_proj.weight"), + self.get_io_by_name(node, "self_attn.k_proj.weight"), + self.get_io_by_name(node, "self_attn.v_proj.weight"), + self.get_io_by_name(node, "self_attn.q_proj.bias"), + self.get_io_by_name(node, "self_attn.k_proj.bias"), + self.get_io_by_name(node, "self_attn.v_proj.bias"), + self.get_uname(layer_id, "attn_qkv_weight"), + self.get_uname(layer_id, "attn_qkv_bias"), + ) + + attn_out_weight = self.process_initializer( + self.get_io_by_name(node, "self_attn.dense.weight"), ProcessGemmWFunc() + ) + attn_out_bias = self.get_io_by_name(node, "self_attn.dense.bias") + + mlp_fc1_weight = self.process_initializer(self.get_io_by_name(node, "mlp.fc1.weight"), ProcessGemmWFunc()) + mlp_fc2_weight = self.process_initializer(self.get_io_by_name(node, "mlp.fc2.weight"), ProcessGemmWFunc()) + mlp_fc1_bias = self.get_io_by_name(node, "mlp.fc1.bias") + mlp_fc2_bias = self.get_io_by_name(node, "mlp.fc2.bias") + + layer_known_edges_names = [] + layer_known_edges_names.extend([i_hidden_states, i_key_cache, i_value_cache]) + layer_known_edges_names.extend([o_hidden_states, o_key_cache, o_value_cache]) + layer_known_edges_names.extend([ln_weight, ln_bias]) + if self.attn_op_type != AttentionOpType.Attention: + layer_known_edges_names.extend( + [ + attn_q_weight, + attn_q_bias, + attn_k_weight, + attn_k_bias, + attn_v_weight, + attn_v_bias, + cos_cache, + sin_cache, + ] + ) + else: + layer_known_edges_names.extend([attn_qkv_weight, attn_qkv_bias]) + layer_known_edges_names.extend( + [attn_out_weight, attn_out_bias, mlp_fc1_weight, mlp_fc1_bias, mlp_fc2_weight, mlp_fc2_bias] + ) + layer_known_edges_names.extend(["attention_mask", "step", "seqlens_k", "total_sequence_length"]) + + subgraph_nodes = [] + subgraph_nodes.extend(self.layernorm([i_hidden_states, ln_weight, ln_bias], ["ln_out"])) + subgraph_nodes.extend(self.gemm(["attn_out", attn_out_weight, attn_out_bias], ["attn_add_out"], "OutProj_")) + subgraph_nodes.extend(self.gemm(["ln_out", mlp_fc1_weight, mlp_fc1_bias], ["fc1_out"], "FC1_")) + subgraph_nodes.extend(self.fastgelu(["fc1_out"], ["gelu_out"])) + subgraph_nodes.extend(self.gemm(["gelu_out", mlp_fc2_weight, mlp_fc2_bias], ["fc2_out"], "FC2_")) + subgraph_nodes.extend(self.add(["attn_add_out", "fc2_out"], ["residual_1_out"], "Residual_1")) + subgraph_nodes.extend(self.add([i_hidden_states, "residual_1_out"], [o_hidden_states], "Residual_2")) + if self.attn_op_type != AttentionOpType.Attention: + subgraph_nodes.extend(self.gemm(["ln_out", attn_q_weight, attn_q_bias], ["query"], "Q_")) + subgraph_nodes.extend(self.gemm(["ln_out", attn_k_weight, attn_k_bias], ["key"], "K_")) + subgraph_nodes.extend(self.gemm(["ln_out", attn_v_weight, attn_v_bias], ["value"], "V_")) + subgraph_nodes.extend(self.rotary(["query", "step", cos_cache, sin_cache], ["query_rot"], "Q_")) + subgraph_nodes.extend(self.rotary(["key", "step", cos_cache, sin_cache], ["key_rot"], "K_")) + if self.attn_op_type == AttentionOpType.MultiHeadAttention: + subgraph_nodes.extend( + self.mha( + ["query_rot", "key_rot", "value", "", "attention_mask", "", i_key_cache, i_value_cache], + ["attn_out", o_key_cache, o_value_cache], + ) + ) + elif self.attn_op_type == AttentionOpType.GroupQueryAttention: + subgraph_nodes.extend( + self.gqa( + [ + "query_rot", + "key_rot", + "value", + i_key_cache, + i_value_cache, + "seqlens_k", + "total_sequence_length", + ], + ["attn_out", o_key_cache, o_value_cache], + ) + ) + if layer_id == 0: + gqa_aux_nodes = self.get_gqa_aux_nodes() + for new_node in gqa_aux_nodes: + self.nodes_to_add.append(new_node) + self.node_name_to_graph_name[new_node.name] = self.this_graph_name + self.model.add_initializer( + numpy_helper.from_array(np.array([1], dtype="int64"), name="one"), self.this_graph_name + ) + else: + past_name = f"past_{layer_id}" + present_name = f"present_{layer_id}" + layer_known_edges_names.extend([past_name, present_name]) + subgraph_nodes.extend( + self.attention( + ["ln_out", attn_qkv_weight, attn_qkv_bias, "attention_mask", past_name], ["attn_out", present_name] + ) + ) + + self.set_unique_name_and_add_nodes(subgraph_nodes, layer_id, layer_known_edges_names) + + self.replace_fp32_value_info(i_hidden_states, ["batch_size", "seq_len", "hidden_size"]) + self.replace_fp32_value_info(o_hidden_states, ["batch_size", "seq_len", "hidden_size"]) + + self.nodes_to_remove.append(node) + self.prune_graph = True + + +class PhiOnnxModel(OnnxModel): + def __init__(self, model: ModelProto, num_heads: int, hidden_size: int): + super().__init__(model) + self.phi2_preprocessor = Phi2PreProcessor(self.model, num_heads, hidden_size) + self.fission_transformer_block = FissionTransformerBlockPhi(self, num_heads) + self.fission_causal_lm_head = FissionTransformerCausalLMHeadPhi(self) + self.fission_transformer_layernorm = FissionTransformerLayerNormPhi(self) + self.fission_transformer_embedding = FissionTransformerEmbeddingPhi(self) + + def optimize(self, options: Optional[FusionOptions] = None, add_dynamic_axes: bool = False): + assert options is not None + attn_op_type = options.attention_op_type + + self.fission_transformer_block.set_attention_op_type(attn_op_type) + + self.phi2_preprocessor.preprocess_onnx(attn_op_type) + + self.fission_transformer_block.apply() + self.fission_transformer_layernorm.apply() + self.fission_causal_lm_head.apply() + self.fission_transformer_embedding.apply() + + super().prune_graph() + + # SLN ctor is placed here intentionally to delay the symbolic shape inference + self.fuse_sln = FusionSkipLayerNormalization(self) + self.fuse_bias_sln = FusionBiasSkipLayerNormalization(self) + self.fuse_sln.apply() + self.fuse_bias_sln.apply() + + def get_fused_operator_statistics(self): + """ + Returns node count of fused operators. + """ + op_count = {} + ops = [ + "Attention", + "MultiHeadAttention", + "GroupQueryAttention", + "Gelu", + "BiasGelu", + "FastGelu", + "LayerNormalization", + "SkipLayerNormalization", + ] + for op in ops: + nodes = self.get_nodes_by_op_type(op) + op_count[op] = len(nodes) + + logger.info(f"Optimized operators: {op_count}") + return op_count + + def is_fully_optimized(self, fused_op_count=None): + """ + Returns True when the model is fully optimized. + """ + if fused_op_count is None: + fused_op_count = self.get_fused_operator_statistics() + + def op_count(op_name: str): + return fused_op_count.get(op_name) or 0 + + attention = op_count("Attention") + op_count("MultiHeadAttention") + op_count("GroupQueryAttention") + gelu = op_count("Gelu") + op_count("BiasGelu") + op_count("FastGelu") + layer_norm = op_count("LayerNormalization") + op_count("SkipLayerNormalization") + + is_perfect = (attention > 0) and (attention == gelu) and (layer_norm >= attention) + + if layer_norm == 0: + logger.debug("Layer Normalization not fused") + + if gelu == 0: + logger.debug("Gelu (or FastGelu) not fused") + + if attention == 0: + logger.warning("Attention (or MultiHeadAttention) not fused") + + return is_perfect diff --git a/onnxruntime/python/tools/transformers/optimizer.py b/onnxruntime/python/tools/transformers/optimizer.py index ba61f4f6e4..ce0be6b344 100644 --- a/onnxruntime/python/tools/transformers/optimizer.py +++ b/onnxruntime/python/tools/transformers/optimizer.py @@ -34,6 +34,7 @@ from onnx_model_bert_tf import BertOnnxModelTF from onnx_model_clip import ClipOnnxModel from onnx_model_conformer import ConformerOnnxModel from onnx_model_gpt2 import Gpt2OnnxModel +from onnx_model_phi import PhiOnnxModel from onnx_model_t5 import T5OnnxModel from onnx_model_tnlr import TnlrOnnxModel from onnx_model_unet import UnetOnnxModel @@ -58,6 +59,7 @@ MODEL_TYPES = { "vae": (VaeOnnxModel, "pytorch", 1), # UAE in Stable Diffusion "vit": (BertOnnxModel, "pytorch", 1), "conformer": (ConformerOnnxModel, "pytorch", 1), + "phi": (PhiOnnxModel, "pytorch", 0), } diff --git a/setup.py b/setup.py index 67d34b065a..03e1cb75ba 100644 --- a/setup.py +++ b/setup.py @@ -419,6 +419,7 @@ packages = [ "onnxruntime.transformers.models.gpt2", "onnxruntime.transformers.models.llama", "onnxruntime.transformers.models.longformer", + "onnxruntime.transformers.models.phi2", "onnxruntime.transformers.models.t5", "onnxruntime.transformers.models.stable_diffusion", "onnxruntime.transformers.models.whisper",