mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-11 00:49:31 +00:00
Add script to convert phi2 to ort-vllm compatible (#19429)
### Description <!-- Describe your changes. --> 1. Add an option to export an ONNX model compatible with ort_vllm. This makes sure that the ONNX model relies only on paged attention from vLLM. It is intended for internal use, so it is not mentioned in the README. 2. Add details on ORT installation (https://github.com/microsoft/onnxruntime/pull/19338#discussion_r1476906190). ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> --------- Co-authored-by: wejoncy <wejoncy@163.com>
This commit is contained in:
parent
0d10c7f3c1
commit
19952c5b35
6 changed files with 183 additions and 43 deletions
|
|
@ -213,6 +213,7 @@ class SymbolicShapeInference:
|
|||
"NhwcConv": self._infer_NhwcConv,
|
||||
"PackedAttention": self._infer_PackedAttention,
|
||||
"PackedMultiHeadAttention": self._infer_PackedMultiHeadAttention,
|
||||
"PagedAttention": self._infer_PagedAttention,
|
||||
"PythonOp": self._infer_PythonOp,
|
||||
"QuantizeLinear": self._infer_QuantizeLinear,
|
||||
"QuickGelu": self._infer_FastGelu,
|
||||
|
|
@ -470,6 +471,7 @@ class SymbolicShapeInference:
|
|||
"SkipLayerNormalization",
|
||||
"SkipSimplifiedLayerNormalization",
|
||||
"PackedAttention",
|
||||
"PagedAttention",
|
||||
"PythonOp",
|
||||
"MultiHeadAttention",
|
||||
"GroupNorm",
|
||||
|
|
@ -2412,6 +2414,9 @@ class SymbolicShapeInference:
|
|||
def _infer_GroupNorm(self, node): # noqa: N802
|
||||
self._propagate_shape_and_type(node)
|
||||
|
||||
def _infer_PagedAttention(self, node):  # noqa: N802
    """Shape inference hook for the ``PagedAttention`` op.

    Performs no op-specific computation; the work is delegated entirely to
    ``self._propagate_shape_and_type``.

    Args:
        node: The ``PagedAttention`` graph node being inferred.
    """
    paged_attention_node = node
    self._propagate_shape_and_type(paged_attention_node)
|
||||
|
||||
def _infer_GroupQueryAttention(self, node): # noqa: N802
|
||||
output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
|
||||
|
||||
|
|
|
|||
|
|
@ -73,20 +73,32 @@ class DynamoOnnxHelper:
|
|||
|
||||
return self.update_edges(edge_mapping)
|
||||
|
||||
def remove_dropout_layer(self) -> None:
|
||||
def remove_function(self, func_name: str, input_id: int, output_id: int) -> None:
|
||||
"""
|
||||
Removes the dropout layer in the model.
|
||||
Removes the function in the model.
|
||||
"""
|
||||
logging.info("Removing dropout layer...")
|
||||
edge_mapping = {}
|
||||
nodes_to_remove = []
|
||||
for node in self.model.graph.node:
|
||||
if node.op_type.find("Dropout") != -1:
|
||||
assert len(node.input) == 1
|
||||
assert len(node.output) == 1
|
||||
edge_mapping[node.output[0]] = node.input[0]
|
||||
if node.op_type.find(func_name) != -1:
|
||||
edge_mapping[node.input[input_id]] = node.output[output_id]
|
||||
nodes_to_remove.append(node)
|
||||
for node in nodes_to_remove:
|
||||
self.model.graph.node.remove(node)
|
||||
|
||||
self.update_edges(edge_mapping)
|
||||
|
||||
def remove_dropout_layer(self) -> None:
    """
    Drop the dropout nodes from the model.

    Logs the operation, then delegates to ``remove_function`` with the
    op-type pattern "Dropout" and index 0 for both the input and the output.
    """
    op_pattern = "Dropout"
    input_index, output_index = 0, 0

    logging.info("Removing dropout layer...")
    self.remove_function(op_pattern, input_index, output_index)
|
||||
|
||||
def remove_lm_head_layer(self) -> None:
    """
    Drop the LM head nodes from the model.

    Logs the operation, then delegates to ``remove_function`` with the
    op-type pattern "Linear_lm_head", input index 2 and output index 0.
    """
    op_pattern = "Linear_lm_head"
    input_index, output_index = 2, 0

    logging.info("Removing LM head layer...")
    # TODO: flagged as "bugbug" by the original author -- the right
    # value_info (vi) still needs to be copied over after the removal.
    self.remove_function(op_pattern, input_index, output_index)
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ class AttentionOpType(Enum):
|
|||
Attention = "Attention"
|
||||
MultiHeadAttention = "MultiHeadAttention"
|
||||
GroupQueryAttention = "GroupQueryAttention"
|
||||
PagedAttention = "PagedAttention"
|
||||
|
||||
def __str__(self):
|
||||
return self.value
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ To export ONNX, PyTorch version 2.2.0 or higher is required. The [official websi
|
|||
**There are two options to run the conversion script:**\
|
||||
_From source:_
|
||||
```bash
|
||||
# Default onnxruntime package is built with CUDA 11.8. For CUDA 12.x, refer to https://onnxruntime.ai/docs/install/#python-installs
|
||||
pip install onnxruntime-gpu==1.17.0 # or onnxruntime==1.17.0 if using cpu
|
||||
git clone git@github.com:microsoft/onnxruntime.git
|
||||
cd onnxruntime/onnxruntime/python/tools/transformers
|
||||
|
|
|
|||
|
|
@ -136,14 +136,18 @@ class ConvertPhi2ToONNX:
|
|||
self.precision == Precision.FLOAT16 or self.precision == Precision.INT4
|
||||
) and self.attn_op_type != AttentionOpType.MultiHeadAttention:
|
||||
# We keep last three layers of Attention as float32 or bfloat16 to avoid overflow.
|
||||
node_block_list = [
|
||||
"GroupQueryAttention_29",
|
||||
"GroupQueryAttention_30",
|
||||
"GroupQueryAttention_31",
|
||||
"Attention_29",
|
||||
"Attention_30",
|
||||
"Attention_31",
|
||||
]
|
||||
node_block_list = (
|
||||
[
|
||||
"GroupQueryAttention_29",
|
||||
"GroupQueryAttention_30",
|
||||
"GroupQueryAttention_31",
|
||||
"Attention_29",
|
||||
"Attention_30",
|
||||
"Attention_31",
|
||||
]
|
||||
if self.attn_op_type != AttentionOpType.PagedAttention
|
||||
else []
|
||||
) # TODO: temp setting for paged attention
|
||||
logging.info("Converting onnx model to float16/bfloat16...")
|
||||
optimizer.convert_float_to_float16(
|
||||
keep_io_types=False,
|
||||
|
|
@ -220,6 +224,20 @@ def parse_arguments():
|
|||
help="Generate int4 ONNX model for Nvidia GPUs with CUDA architecture SM=80~89",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--fp16_vllm",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Generate fp16 ONNX model for ORT VLLM",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--int4_vllm",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Generate int4 ONNX model for ORT VLLM",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--overwrite",
|
||||
required=False,
|
||||
|
|
@ -336,6 +354,16 @@ def main():
|
|||
Precision.INT4,
|
||||
os.path.join(output_dir, "phi2_decoder_int4_gpu_sm8x.onnx"),
|
||||
),
|
||||
"fp16_vllm": (
|
||||
AttentionOpType.PagedAttention,
|
||||
Precision.FLOAT16,
|
||||
os.path.join(output_dir, "phi2_decoder_fp16_vllm.onnx"),
|
||||
),
|
||||
"int4_vllm": (
|
||||
AttentionOpType.PagedAttention,
|
||||
Precision.INT4,
|
||||
os.path.join(output_dir, "phi2_decoder_int4_vllm.onnx"),
|
||||
),
|
||||
}
|
||||
|
||||
if not args.skip_export:
|
||||
|
|
@ -403,6 +431,22 @@ def main():
|
|||
)
|
||||
)
|
||||
|
||||
if args.fp16_vllm:
|
||||
processes.append(
|
||||
Process(
|
||||
target=run_optimize_phi2_onnx,
|
||||
args=(converter, original_onnx_path, *model_type_to_args["fp16_vllm"]),
|
||||
)
|
||||
)
|
||||
|
||||
if args.int4_vllm:
|
||||
processes.append(
|
||||
Process(
|
||||
target=run_optimize_phi2_onnx,
|
||||
args=(converter, original_onnx_path, *model_type_to_args["int4_vllm"]),
|
||||
)
|
||||
)
|
||||
|
||||
[p.start() for p in processes]
|
||||
[p.join() for p in processes]
|
||||
|
||||
|
|
@ -450,8 +494,8 @@ def main():
|
|||
device_id=args.device_id,
|
||||
packed_kv=True,
|
||||
)
|
||||
if args.fp32_cpu or args.int4_cpu:
|
||||
raise NotImplementedError("CPU inference example is not implemented yet.")
|
||||
if args.fp32_cpu or args.int4_cpu or args.fp16_vllm or args.int4_vllm:
|
||||
raise NotImplementedError("CPU/vllm inference example is not implemented yet.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -255,6 +255,30 @@ class Fission(Fusion):
|
|||
)
|
||||
return [node]
|
||||
|
||||
def paged_attn(
    self,
    inputs: List[str],
    outputs: List[str],
    prefix: str = "",
    num_heads=32,
    head_size=80,
    scale=0.11180339753627777,
):
    """Create one ``PagedAttention`` node in the ``vllm.ort.ext`` domain.

    Args:
        inputs: Exactly six input edge names. The caller in this file passes
            rotated query, rotated key, value, key cache, value cache and
            ``input_metadata``, in that order.
        outputs: Exactly one output edge name.
        prefix: String prepended to the node name ``"PagedAttention"``.
        num_heads: Stored as both the ``num_heads`` and ``num_kv_heads``
            attributes.
        head_size: Stored as the ``head_size`` attribute.
        scale: Stored as the ``scale`` attribute. The default looks like
            1/sqrt(80), i.e. tied to the default ``head_size`` -- TODO confirm.

    Returns:
        A one-element list holding the new node.
    """
    assert len(inputs) == 6
    assert len(outputs) == 1

    # The KV head count is deliberately mirrored from num_heads.
    attributes = {
        "num_heads": num_heads,
        "num_kv_heads": num_heads,
        "head_size": head_size,
        "scale": scale,
    }
    paged_attention_node = helper.make_node(
        "PagedAttention",
        inputs=inputs,
        outputs=outputs,
        name=prefix + "PagedAttention",
        domain="vllm.ort.ext",
        **attributes,
    )
    return [paged_attention_node]
|
||||
|
||||
|
||||
class Phi2PreProcessor(DynamoOnnxHelper):
|
||||
def __init__(self, model: ModelProto, num_heads: int, hidden_size: int):
|
||||
|
|
@ -288,32 +312,46 @@ class Phi2PreProcessor(DynamoOnnxHelper):
|
|||
|
||||
def process_graph_io(self, attn_op_type: AttentionOpType):
|
||||
self.use_attn = attn_op_type == AttentionOpType.Attention
|
||||
self.use_vllm = attn_op_type == AttentionOpType.PagedAttention
|
||||
graph = self.model.graph
|
||||
new_inputs = []
|
||||
for vi in graph.input:
|
||||
if "input_ids" in vi.name:
|
||||
vi_iid = helper.make_tensor_value_info(
|
||||
vi.name,
|
||||
elem_type=TensorProto.INT32,
|
||||
elem_type=TensorProto.INT32 if not self.use_vllm else TensorProto.INT64,
|
||||
shape=["batch_size", "seq_len"],
|
||||
)
|
||||
vi_pid = helper.make_tensor_value_info(
|
||||
vi_step = helper.make_tensor_value_info(
|
||||
"step",
|
||||
elem_type=TensorProto.INT64,
|
||||
shape=[1],
|
||||
)
|
||||
vi_pid = helper.make_tensor_value_info(
|
||||
"position_ids",
|
||||
elem_type=TensorProto.INT64,
|
||||
shape=["batch_size", "seq_len"],
|
||||
)
|
||||
vi_mask = helper.make_tensor_value_info(
|
||||
"attention_mask",
|
||||
elem_type=TensorProto.INT32,
|
||||
shape=["batch_size", "seq_len"],
|
||||
)
|
||||
new_inputs.extend([vi_iid, vi_pid, vi_mask])
|
||||
if not self.use_attn:
|
||||
if "past_key" in vi.name or "past_value" in vi.name:
|
||||
vi_meta = helper.make_tensor_value_info(
|
||||
"input_metadata",
|
||||
elem_type=TensorProto.INT64,
|
||||
shape=[1],
|
||||
)
|
||||
new_inputs.extend([vi_iid, vi_step, vi_mask]) if not self.use_vllm else new_inputs.extend(
|
||||
[vi_iid, vi_pid, vi_meta]
|
||||
)
|
||||
if self.use_attn:
|
||||
if "past_key" in vi.name:
|
||||
vi_cache = helper.make_tensor_value_info(
|
||||
vi.name,
|
||||
vi.name.replace("past_key", "past"),
|
||||
elem_type=vi.type.tensor_type.elem_type,
|
||||
shape=[
|
||||
2,
|
||||
"batch_size",
|
||||
self.num_attention_heads,
|
||||
"past_seq_len",
|
||||
|
|
@ -321,13 +359,32 @@ class Phi2PreProcessor(DynamoOnnxHelper):
|
|||
],
|
||||
)
|
||||
new_inputs.extend([vi_cache])
|
||||
else:
|
||||
elif self.use_vllm:
|
||||
if "past_key" in vi.name:
|
||||
vi_cache = helper.make_tensor_value_info(
|
||||
vi.name.replace("past_key", "past"),
|
||||
vi.name,
|
||||
elem_type=vi.type.tensor_type.elem_type,
|
||||
shape=["num_blocks", "num_heads", "head_size_x", "block_size", "block_x"],
|
||||
)
|
||||
new_inputs.extend([vi_cache])
|
||||
if "past_value" in vi.name:
|
||||
vi_cache = helper.make_tensor_value_info(
|
||||
vi.name,
|
||||
elem_type=vi.type.tensor_type.elem_type,
|
||||
shape=[
|
||||
"num_blocks",
|
||||
"num_heads",
|
||||
"head_size",
|
||||
"block_size",
|
||||
],
|
||||
)
|
||||
new_inputs.extend([vi_cache])
|
||||
else:
|
||||
if "past_key" in vi.name or "past_value" in vi.name:
|
||||
vi_cache = helper.make_tensor_value_info(
|
||||
vi.name,
|
||||
elem_type=vi.type.tensor_type.elem_type,
|
||||
shape=[
|
||||
2,
|
||||
"batch_size",
|
||||
self.num_attention_heads,
|
||||
"past_seq_len",
|
||||
|
|
@ -344,19 +401,7 @@ class Phi2PreProcessor(DynamoOnnxHelper):
|
|||
if i == 0:
|
||||
new_outputs.extend([vi])
|
||||
else:
|
||||
if not self.use_attn:
|
||||
vi_cache = helper.make_tensor_value_info(
|
||||
vi.name,
|
||||
elem_type=vi.type.tensor_type.elem_type,
|
||||
shape=[
|
||||
"batch_size",
|
||||
self.num_attention_heads,
|
||||
"total_seq_len",
|
||||
self.hidden_size // self.num_attention_heads,
|
||||
],
|
||||
)
|
||||
new_outputs.extend([vi_cache])
|
||||
else:
|
||||
if self.use_attn:
|
||||
if "present_key" in vi.name:
|
||||
vi_cache = helper.make_tensor_value_info(
|
||||
vi.name.replace("present_key", "present"),
|
||||
|
|
@ -370,6 +415,20 @@ class Phi2PreProcessor(DynamoOnnxHelper):
|
|||
],
|
||||
)
|
||||
new_outputs.extend([vi_cache])
|
||||
elif self.use_vllm:
|
||||
pass
|
||||
else:
|
||||
vi_cache = helper.make_tensor_value_info(
|
||||
vi.name,
|
||||
elem_type=vi.type.tensor_type.elem_type,
|
||||
shape=[
|
||||
"batch_size",
|
||||
self.num_attention_heads,
|
||||
"total_seq_len",
|
||||
self.hidden_size // self.num_attention_heads,
|
||||
],
|
||||
)
|
||||
new_outputs.extend([vi_cache])
|
||||
|
||||
graph.ClearField("output")
|
||||
graph.output.extend(new_outputs)
|
||||
|
|
@ -385,6 +444,8 @@ class Phi2PreProcessor(DynamoOnnxHelper):
|
|||
self.update_edges(self.phi2_edge_dict)
|
||||
self.simplify_phi2_op_type()
|
||||
self.remove_dropout_layer()
|
||||
if attn_op_type == AttentionOpType.PagedAttention:
|
||||
self.remove_lm_head_layer()
|
||||
self.process_graph_io(attn_op_type)
|
||||
|
||||
|
||||
|
|
@ -694,7 +755,9 @@ class FissionTransformerBlockPhi(Fission):
|
|||
layer_known_edges_names.extend(
|
||||
[attn_out_weight, attn_out_bias, mlp_fc1_weight, mlp_fc1_bias, mlp_fc2_weight, mlp_fc2_bias]
|
||||
)
|
||||
layer_known_edges_names.extend(["attention_mask", "step", "seqlens_k", "total_sequence_length"])
|
||||
layer_known_edges_names.extend(
|
||||
["attention_mask", "step", "seqlens_k", "total_sequence_length", "input_metadata", "position_ids"]
|
||||
)
|
||||
|
||||
subgraph_nodes = []
|
||||
subgraph_nodes.extend(self.layernorm([i_hidden_states, ln_weight, ln_bias], ["ln_out"]))
|
||||
|
|
@ -708,8 +771,9 @@ class FissionTransformerBlockPhi(Fission):
|
|||
subgraph_nodes.extend(self.gemm(["ln_out", attn_q_weight, attn_q_bias], ["query"], "Q_"))
|
||||
subgraph_nodes.extend(self.gemm(["ln_out", attn_k_weight, attn_k_bias], ["key"], "K_"))
|
||||
subgraph_nodes.extend(self.gemm(["ln_out", attn_v_weight, attn_v_bias], ["value"], "V_"))
|
||||
subgraph_nodes.extend(self.rotary(["query", "step", cos_cache, sin_cache], ["query_rot"], "Q_"))
|
||||
subgraph_nodes.extend(self.rotary(["key", "step", cos_cache, sin_cache], ["key_rot"], "K_"))
|
||||
pos_ids_name = "position_ids" if self.attn_op_type == AttentionOpType.PagedAttention else "step"
|
||||
subgraph_nodes.extend(self.rotary(["query", pos_ids_name, cos_cache, sin_cache], ["query_rot"], "Q_"))
|
||||
subgraph_nodes.extend(self.rotary(["key", pos_ids_name, cos_cache, sin_cache], ["key_rot"], "K_"))
|
||||
if self.attn_op_type == AttentionOpType.MultiHeadAttention:
|
||||
subgraph_nodes.extend(
|
||||
self.mha(
|
||||
|
|
@ -740,6 +804,13 @@ class FissionTransformerBlockPhi(Fission):
|
|||
self.model.add_initializer(
|
||||
numpy_helper.from_array(np.array([1], dtype="int64"), name="one"), self.this_graph_name
|
||||
)
|
||||
elif self.attn_op_type == AttentionOpType.PagedAttention:
|
||||
subgraph_nodes.extend(
|
||||
self.paged_attn(
|
||||
["query_rot", "key_rot", "value", i_key_cache, i_value_cache, "input_metadata"],
|
||||
["attn_out"],
|
||||
)
|
||||
)
|
||||
else:
|
||||
past_name = f"past_{layer_id}"
|
||||
present_name = f"present_{layer_id}"
|
||||
|
|
@ -798,6 +869,7 @@ class PhiOnnxModel(OnnxModel):
|
|||
"Attention",
|
||||
"MultiHeadAttention",
|
||||
"GroupQueryAttention",
|
||||
"PagedAttention",
|
||||
"Gelu",
|
||||
"BiasGelu",
|
||||
"FastGelu",
|
||||
|
|
@ -821,7 +893,12 @@ class PhiOnnxModel(OnnxModel):
|
|||
def op_count(op_name: str):
|
||||
return fused_op_count.get(op_name) or 0
|
||||
|
||||
attention = op_count("Attention") + op_count("MultiHeadAttention") + op_count("GroupQueryAttention")
|
||||
attention = (
|
||||
op_count("Attention")
|
||||
+ op_count("MultiHeadAttention")
|
||||
+ op_count("GroupQueryAttention")
|
||||
+ op_count("PagedAttention")
|
||||
)
|
||||
gelu = op_count("Gelu") + op_count("BiasGelu") + op_count("FastGelu")
|
||||
layer_norm = op_count("LayerNormalization") + op_count("SkipLayerNormalization")
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue