Add script to convert phi2 to ort-vllm compatible (#19429)

### Description
<!-- Describe your changes. -->

1. add option to export onnx compatiable with ort_vllm. This makes sure
that onnx model only leverages on paged attn from vllm. It's intended to
use internally so not mentioned in readme.
2. add details in ORT
installation(https://github.com/microsoft/onnxruntime/pull/19338#discussion_r1476906190)


### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->

---------

Co-authored-by: wejoncy <wejoncy@163.com>
This commit is contained in:
Ye Wang 2024-02-08 01:03:06 +00:00 committed by GitHub
parent 0d10c7f3c1
commit 19952c5b35
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 183 additions and 43 deletions

View file

@ -213,6 +213,7 @@ class SymbolicShapeInference:
"NhwcConv": self._infer_NhwcConv,
"PackedAttention": self._infer_PackedAttention,
"PackedMultiHeadAttention": self._infer_PackedMultiHeadAttention,
"PagedAttention": self._infer_PagedAttention,
"PythonOp": self._infer_PythonOp,
"QuantizeLinear": self._infer_QuantizeLinear,
"QuickGelu": self._infer_FastGelu,
@ -470,6 +471,7 @@ class SymbolicShapeInference:
"SkipLayerNormalization",
"SkipSimplifiedLayerNormalization",
"PackedAttention",
"PagedAttention",
"PythonOp",
"MultiHeadAttention",
"GroupNorm",
@ -2412,6 +2414,9 @@ class SymbolicShapeInference:
def _infer_GroupNorm(self, node): # noqa: N802
self._propagate_shape_and_type(node)
def _infer_PagedAttention(self, node): # noqa: N802
self._propagate_shape_and_type(node)
def _infer_GroupQueryAttention(self, node): # noqa: N802
output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type

View file

@ -73,20 +73,32 @@ class DynamoOnnxHelper:
return self.update_edges(edge_mapping)
def remove_dropout_layer(self) -> None:
def remove_function(self, func_name: str, input_id: int, output_id: int) -> None:
"""
Removes the dropout layer in the model.
Removes the function in the model.
"""
logging.info("Removing dropout layer...")
edge_mapping = {}
nodes_to_remove = []
for node in self.model.graph.node:
if node.op_type.find("Dropout") != -1:
assert len(node.input) == 1
assert len(node.output) == 1
edge_mapping[node.output[0]] = node.input[0]
if node.op_type.find(func_name) != -1:
edge_mapping[node.input[input_id]] = node.output[output_id]
nodes_to_remove.append(node)
for node in nodes_to_remove:
self.model.graph.node.remove(node)
self.update_edges(edge_mapping)
def remove_dropout_layer(self) -> None:
"""
Removes the dropout layer in the model.
"""
logging.info("Removing dropout layer...")
self.remove_function("Dropout", 0, 0)
def remove_lm_head_layer(self) -> None:
"""
Removes the LM head layer in the model.
"""
logging.info("Removing LM head layer...")
# bugbug: need to copy the right vi over
self.remove_function("Linear_lm_head", 2, 0)

View file

@ -24,6 +24,7 @@ class AttentionOpType(Enum):
Attention = "Attention"
MultiHeadAttention = "MultiHeadAttention"
GroupQueryAttention = "GroupQueryAttention"
PagedAttention = "PagedAttention"
def __str__(self):
return self.value

View file

@ -11,6 +11,7 @@ To export ONNX, PyTorch version 2.2.0 or higher is required. The [official websi
**There are two options to run the conversion script:**\
_From source:_
```bash
# Default onnxruntime package is built with CUDA 11.8. For CUDA 12.x, refer to https://onnxruntime.ai/docs/install/#python-installs
pip install onnxruntime-gpu==1.17.0 # or onnxruntime==1.17.0 if using cpu
git clone git@github.com:microsoft/onnxruntime.git
cd onnxruntime/onnxruntime/python/tools/transformers

View file

@ -136,14 +136,18 @@ class ConvertPhi2ToONNX:
self.precision == Precision.FLOAT16 or self.precision == Precision.INT4
) and self.attn_op_type != AttentionOpType.MultiHeadAttention:
# We keep last three layers of Attention as float32 or bfloat16 to avoid overflow.
node_block_list = [
"GroupQueryAttention_29",
"GroupQueryAttention_30",
"GroupQueryAttention_31",
"Attention_29",
"Attention_30",
"Attention_31",
]
node_block_list = (
[
"GroupQueryAttention_29",
"GroupQueryAttention_30",
"GroupQueryAttention_31",
"Attention_29",
"Attention_30",
"Attention_31",
]
if self.attn_op_type != AttentionOpType.PagedAttention
else []
) # TODO: temp setting for paged attention
logging.info("Converting onnx model to float16/bfloat16...")
optimizer.convert_float_to_float16(
keep_io_types=False,
@ -220,6 +224,20 @@ def parse_arguments():
help="Generate int4 ONNX model for Nvidia GPUs with CUDA architecture SM=80~89",
)
parser.add_argument(
"--fp16_vllm",
required=False,
action="store_true",
help="Generate fp16 ONNX model for ORT VLLM",
)
parser.add_argument(
"--int4_vllm",
required=False,
action="store_true",
help="Generate int4 ONNX model for ORT VLLM",
)
parser.add_argument(
"--overwrite",
required=False,
@ -336,6 +354,16 @@ def main():
Precision.INT4,
os.path.join(output_dir, "phi2_decoder_int4_gpu_sm8x.onnx"),
),
"fp16_vllm": (
AttentionOpType.PagedAttention,
Precision.FLOAT16,
os.path.join(output_dir, "phi2_decoder_fp16_vllm.onnx"),
),
"int4_vllm": (
AttentionOpType.PagedAttention,
Precision.INT4,
os.path.join(output_dir, "phi2_decoder_int4_vllm.onnx"),
),
}
if not args.skip_export:
@ -403,6 +431,22 @@ def main():
)
)
if args.fp16_vllm:
processes.append(
Process(
target=run_optimize_phi2_onnx,
args=(converter, original_onnx_path, *model_type_to_args["fp16_vllm"]),
)
)
if args.int4_vllm:
processes.append(
Process(
target=run_optimize_phi2_onnx,
args=(converter, original_onnx_path, *model_type_to_args["int4_vllm"]),
)
)
[p.start() for p in processes]
[p.join() for p in processes]
@ -450,8 +494,8 @@ def main():
device_id=args.device_id,
packed_kv=True,
)
if args.fp32_cpu or args.int4_cpu:
raise NotImplementedError("CPU inference example is not implemented yet.")
if args.fp32_cpu or args.int4_cpu or args.fp16_vllm or args.int4_vllm:
raise NotImplementedError("CPU/vllm inference example is not implemented yet.")
if __name__ == "__main__":

View file

@ -255,6 +255,30 @@ class Fission(Fusion):
)
return [node]
def paged_attn(
self,
inputs: List[str],
outputs: List[str],
prefix: str = "",
num_heads=32,
head_size=80,
scale=0.11180339753627777,
):
assert len(inputs) == 6
assert len(outputs) == 1
node = helper.make_node(
"PagedAttention",
inputs=inputs,
outputs=outputs,
name=prefix + "PagedAttention",
domain="vllm.ort.ext",
num_heads=num_heads,
num_kv_heads=num_heads,
head_size=head_size,
scale=scale,
)
return [node]
class Phi2PreProcessor(DynamoOnnxHelper):
def __init__(self, model: ModelProto, num_heads: int, hidden_size: int):
@ -288,32 +312,46 @@ class Phi2PreProcessor(DynamoOnnxHelper):
def process_graph_io(self, attn_op_type: AttentionOpType):
self.use_attn = attn_op_type == AttentionOpType.Attention
self.use_vllm = attn_op_type == AttentionOpType.PagedAttention
graph = self.model.graph
new_inputs = []
for vi in graph.input:
if "input_ids" in vi.name:
vi_iid = helper.make_tensor_value_info(
vi.name,
elem_type=TensorProto.INT32,
elem_type=TensorProto.INT32 if not self.use_vllm else TensorProto.INT64,
shape=["batch_size", "seq_len"],
)
vi_pid = helper.make_tensor_value_info(
vi_step = helper.make_tensor_value_info(
"step",
elem_type=TensorProto.INT64,
shape=[1],
)
vi_pid = helper.make_tensor_value_info(
"position_ids",
elem_type=TensorProto.INT64,
shape=["batch_size", "seq_len"],
)
vi_mask = helper.make_tensor_value_info(
"attention_mask",
elem_type=TensorProto.INT32,
shape=["batch_size", "seq_len"],
)
new_inputs.extend([vi_iid, vi_pid, vi_mask])
if not self.use_attn:
if "past_key" in vi.name or "past_value" in vi.name:
vi_meta = helper.make_tensor_value_info(
"input_metadata",
elem_type=TensorProto.INT64,
shape=[1],
)
new_inputs.extend([vi_iid, vi_step, vi_mask]) if not self.use_vllm else new_inputs.extend(
[vi_iid, vi_pid, vi_meta]
)
if self.use_attn:
if "past_key" in vi.name:
vi_cache = helper.make_tensor_value_info(
vi.name,
vi.name.replace("past_key", "past"),
elem_type=vi.type.tensor_type.elem_type,
shape=[
2,
"batch_size",
self.num_attention_heads,
"past_seq_len",
@ -321,13 +359,32 @@ class Phi2PreProcessor(DynamoOnnxHelper):
],
)
new_inputs.extend([vi_cache])
else:
elif self.use_vllm:
if "past_key" in vi.name:
vi_cache = helper.make_tensor_value_info(
vi.name.replace("past_key", "past"),
vi.name,
elem_type=vi.type.tensor_type.elem_type,
shape=["num_blocks", "num_heads", "head_size_x", "block_size", "block_x"],
)
new_inputs.extend([vi_cache])
if "past_value" in vi.name:
vi_cache = helper.make_tensor_value_info(
vi.name,
elem_type=vi.type.tensor_type.elem_type,
shape=[
"num_blocks",
"num_heads",
"head_size",
"block_size",
],
)
new_inputs.extend([vi_cache])
else:
if "past_key" in vi.name or "past_value" in vi.name:
vi_cache = helper.make_tensor_value_info(
vi.name,
elem_type=vi.type.tensor_type.elem_type,
shape=[
2,
"batch_size",
self.num_attention_heads,
"past_seq_len",
@ -344,19 +401,7 @@ class Phi2PreProcessor(DynamoOnnxHelper):
if i == 0:
new_outputs.extend([vi])
else:
if not self.use_attn:
vi_cache = helper.make_tensor_value_info(
vi.name,
elem_type=vi.type.tensor_type.elem_type,
shape=[
"batch_size",
self.num_attention_heads,
"total_seq_len",
self.hidden_size // self.num_attention_heads,
],
)
new_outputs.extend([vi_cache])
else:
if self.use_attn:
if "present_key" in vi.name:
vi_cache = helper.make_tensor_value_info(
vi.name.replace("present_key", "present"),
@ -370,6 +415,20 @@ class Phi2PreProcessor(DynamoOnnxHelper):
],
)
new_outputs.extend([vi_cache])
elif self.use_vllm:
pass
else:
vi_cache = helper.make_tensor_value_info(
vi.name,
elem_type=vi.type.tensor_type.elem_type,
shape=[
"batch_size",
self.num_attention_heads,
"total_seq_len",
self.hidden_size // self.num_attention_heads,
],
)
new_outputs.extend([vi_cache])
graph.ClearField("output")
graph.output.extend(new_outputs)
@ -385,6 +444,8 @@ class Phi2PreProcessor(DynamoOnnxHelper):
self.update_edges(self.phi2_edge_dict)
self.simplify_phi2_op_type()
self.remove_dropout_layer()
if attn_op_type == AttentionOpType.PagedAttention:
self.remove_lm_head_layer()
self.process_graph_io(attn_op_type)
@ -694,7 +755,9 @@ class FissionTransformerBlockPhi(Fission):
layer_known_edges_names.extend(
[attn_out_weight, attn_out_bias, mlp_fc1_weight, mlp_fc1_bias, mlp_fc2_weight, mlp_fc2_bias]
)
layer_known_edges_names.extend(["attention_mask", "step", "seqlens_k", "total_sequence_length"])
layer_known_edges_names.extend(
["attention_mask", "step", "seqlens_k", "total_sequence_length", "input_metadata", "position_ids"]
)
subgraph_nodes = []
subgraph_nodes.extend(self.layernorm([i_hidden_states, ln_weight, ln_bias], ["ln_out"]))
@ -708,8 +771,9 @@ class FissionTransformerBlockPhi(Fission):
subgraph_nodes.extend(self.gemm(["ln_out", attn_q_weight, attn_q_bias], ["query"], "Q_"))
subgraph_nodes.extend(self.gemm(["ln_out", attn_k_weight, attn_k_bias], ["key"], "K_"))
subgraph_nodes.extend(self.gemm(["ln_out", attn_v_weight, attn_v_bias], ["value"], "V_"))
subgraph_nodes.extend(self.rotary(["query", "step", cos_cache, sin_cache], ["query_rot"], "Q_"))
subgraph_nodes.extend(self.rotary(["key", "step", cos_cache, sin_cache], ["key_rot"], "K_"))
pos_ids_name = "position_ids" if self.attn_op_type == AttentionOpType.PagedAttention else "step"
subgraph_nodes.extend(self.rotary(["query", pos_ids_name, cos_cache, sin_cache], ["query_rot"], "Q_"))
subgraph_nodes.extend(self.rotary(["key", pos_ids_name, cos_cache, sin_cache], ["key_rot"], "K_"))
if self.attn_op_type == AttentionOpType.MultiHeadAttention:
subgraph_nodes.extend(
self.mha(
@ -740,6 +804,13 @@ class FissionTransformerBlockPhi(Fission):
self.model.add_initializer(
numpy_helper.from_array(np.array([1], dtype="int64"), name="one"), self.this_graph_name
)
elif self.attn_op_type == AttentionOpType.PagedAttention:
subgraph_nodes.extend(
self.paged_attn(
["query_rot", "key_rot", "value", i_key_cache, i_value_cache, "input_metadata"],
["attn_out"],
)
)
else:
past_name = f"past_{layer_id}"
present_name = f"present_{layer_id}"
@ -798,6 +869,7 @@ class PhiOnnxModel(OnnxModel):
"Attention",
"MultiHeadAttention",
"GroupQueryAttention",
"PagedAttention",
"Gelu",
"BiasGelu",
"FastGelu",
@ -821,7 +893,12 @@ class PhiOnnxModel(OnnxModel):
def op_count(op_name: str):
return fused_op_count.get(op_name) or 0
attention = op_count("Attention") + op_count("MultiHeadAttention") + op_count("GroupQueryAttention")
attention = (
op_count("Attention")
+ op_count("MultiHeadAttention")
+ op_count("GroupQueryAttention")
+ op_count("PagedAttention")
)
gelu = op_count("Gelu") + op_count("BiasGelu") + op_count("FastGelu")
layer_norm = op_count("LayerNormalization") + op_count("SkipLayerNormalization")