mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-23 22:13:38 +00:00
Update LLaMA attention fusions (#19200)
### Description This PR updates the LLaMA-2 attention fusions by adding the following. - Loading the PyTorch model from Hugging Face with the `LlamaAttention` class before exporting - Updating the attention mask pattern matching to support another case This PR also fixes [this issue](https://github.com/microsoft/onnxruntime/issues/19040). ### Motivation and Context Recent changes to Hugging Face's `transformers` library break the existing pattern matching. Since the attention fusions aim to change the graph from `LayerNorm Op --> Set of Attention Nodes --> LayerNorm Op` to `LayerNorm Op --> Attention Op --> LayerNorm Op` per layer, ultimately it does not matter what nodes comprise the `Set of Attention Nodes` because they will all be removed and replaced by the `Attention Op` in the end. Therefore, it does not matter whether the `LlamaAttention` class or a different attention class is used to load the PyTorch model before exporting because the expected graphs after the attention fusions will look identical no matter the attention class chosen. By loading the PyTorch model with the `LlamaAttention` class instead of other attention classes (e.g. `LlamaFlashAttention2` or `LlamaSdpaAttention`) and then exporting it to ONNX, the existing pattern matching will continue to work.
This commit is contained in:
parent
eaf047c820
commit
a3ecb63267
5 changed files with 46 additions and 25 deletions
|
|
@ -539,6 +539,8 @@ class FusionRotaryAttention(FusionAttention):
|
|||
|
||||
# attn_mask_nodes_1, attn_mask_nodes_2 are for LLaMA-2 Microsoft's 3D attention mask
|
||||
# attn_mask_nodes_3, attn_mask_nodes_4 are for LLaMA-2 Hugging Face's 2D attention mask
|
||||
# attn_mask_nodes_5, attn_mask_nodes_6 are for LLaMA-2 Microsoft's model for the DML EP
|
||||
# attn_mask_nodes_7 is for LLaMA-2 Hugging Face's changes to the attention mask
|
||||
attn_mask, add_qk_str = "", ""
|
||||
attn_mask_nodes_1 = self.model.match_parent_path(
|
||||
add_qk,
|
||||
|
|
@ -570,6 +572,11 @@ class FusionRotaryAttention(FusionAttention):
|
|||
["Expand", "Where", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"],
|
||||
[1, 0, 2, 1, 0, 0, 0],
|
||||
)
|
||||
attn_mask_nodes_7 = self.model.match_parent_path(
|
||||
add_qk,
|
||||
["Where", "Cast", "Where", "Cast", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"],
|
||||
[1, 0, 0, 0, 0, 1, 0, 0, 0],
|
||||
)
|
||||
if attn_mask_nodes_1 is not None:
|
||||
_, slice_mask_1, slice_mask_2 = attn_mask_nodes_1
|
||||
attn_mask = slice_mask_1.output[0]
|
||||
|
|
@ -588,6 +595,9 @@ class FusionRotaryAttention(FusionAttention):
|
|||
elif attn_mask_nodes_6 is not None:
|
||||
# The mask has already been reshaped to (B,N,S,T)
|
||||
add_qk_str = attn_mask_nodes_6[0].output[0]
|
||||
elif attn_mask_nodes_7 is not None:
|
||||
# Reshape from (B,1,S,T) to (B,N,S,T)
|
||||
add_qk_str = self.reshape_add_qk(attn_mask_nodes_7[0].output[0])
|
||||
else:
|
||||
logger.debug("fuse_rotary_attention: failed to match attention mask nodes")
|
||||
return
|
||||
|
|
|
|||
|
|
@ -42,23 +42,6 @@ $ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama
|
|||
|
||||
To make this option compatible with [Hugging Face's Optimum](https://github.com/huggingface/optimum), you will need to create `config.json` and `generation_config.json` for your model and store them in the same directory as your ONNX models. For example, you can find those JSON files for LLaMA-2 7B on Hugging Face [here](https://huggingface.co/meta-llama/Llama-2-7b-hf).
|
||||
|
||||
As indicated in `requirements.txt`, you will also need to install Optimum from source. Once installed, you will need to modify `ORTModelForCausalLM.forward` in `optimum/optimum/onnxruntime/modeling_decoder.py` as follows:
|
||||
|
||||
```
|
||||
# Before
|
||||
if self.use_cache:
|
||||
if past_key_values is not None:
|
||||
input_ids = input_ids[:, -1:]
|
||||
# Flatten the past_key_values (no need to flatten for models using multi-query attn)
|
||||
|
||||
|
||||
# After
|
||||
if self.use_cache:
|
||||
if past_key_values is not None:
|
||||
input_ids = input_ids[:, -1:] if past_key_values[0][0].shape[2] != 0 else input_ids
|
||||
# Flatten the past_key_values (no need to flatten for models using multi-query attn)
|
||||
```
|
||||
|
||||
### Option 2: from [Microsoft's custom export](https://github.com/microsoft/Llama-2-Onnx)
|
||||
|
||||
Please follow the [README instructions](https://github.com/microsoft/Llama-2-Onnx#before-you-start) in the custom export of LLaMA-2.
|
||||
|
|
@ -254,7 +237,7 @@ Here are some examples of how you can benchmark LLaMA-2.
|
|||
|
||||
1. PyTorch without `torch.compile`, FP32
|
||||
```
|
||||
python3 -m models.llama.benchmark \
|
||||
CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark \
|
||||
--benchmark-type hf-pt-eager \
|
||||
--model-name meta-llama/Llama-2-7b-hf \
|
||||
--precision fp32 \
|
||||
|
|
@ -266,7 +249,7 @@ python3 -m models.llama.benchmark \
|
|||
|
||||
2. PyTorch with `torch.compile`, FP16
|
||||
```
|
||||
python3 -m models.llama.benchmark \
|
||||
CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark \
|
||||
--benchmark-type hf-pt-compile \
|
||||
--model-name meta-llama/Llama-2-7b-hf \
|
||||
--precision fp16 \
|
||||
|
|
@ -278,7 +261,7 @@ python3 -m models.llama.benchmark \
|
|||
|
||||
3. Optimum + ONNX Runtime, FP32, export via Optimum or convert_to_onnx
|
||||
```
|
||||
python3 -m models.llama.benchmark \
|
||||
CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark \
|
||||
--benchmark-type hf-ort \
|
||||
--hf-ort-dir-path ./Llama-2-7b-hf-onnx/ \
|
||||
--model-name meta-llama/Llama-2-7b-hf \
|
||||
|
|
@ -291,7 +274,7 @@ python3 -m models.llama.benchmark \
|
|||
|
||||
4. Optimum + ONNX Runtime, FP16, export via Optimum or convert_to_onnx
|
||||
```
|
||||
python3 -m models.llama.benchmark \
|
||||
CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark \
|
||||
--benchmark-type hf-ort \
|
||||
--hf-ort-dir-path ./Llama-2-7b-hf-onnx/ \
|
||||
--model-name meta-llama/Llama-2-7b-hf \
|
||||
|
|
@ -304,7 +287,7 @@ python3 -m models.llama.benchmark \
|
|||
|
||||
5. ONNX Runtime, FP32, Microsoft custom export
|
||||
```
|
||||
python3 -m models.llama.benchmark \
|
||||
CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark \
|
||||
--benchmark-type ort-msft \
|
||||
--ort-model-path ./llama-2-onnx/7B_float32/ONNX/LlamaV2_7B_float32.onnx \
|
||||
--model-name meta-llama/Llama-2-7b-hf \
|
||||
|
|
@ -316,7 +299,7 @@ python3 -m models.llama.benchmark \
|
|||
|
||||
6. ONNX Runtime, FP16, Microsoft custom export
|
||||
```
|
||||
python3 -m models.llama.benchmark \
|
||||
CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark \
|
||||
--benchmark-type ort-msft \
|
||||
--ort-model-path ./llama-2-onnx/7B_float16/ONNX/LlamaV2_7B_float16.onnx \
|
||||
--model-name meta-llama/Llama-2-7b-hf \
|
||||
|
|
@ -367,7 +350,7 @@ You can profile a variant by adding the `--profile` flag and providing one batch
|
|||
### Benchmark All
|
||||
You can use `benchmark_all.py` to benchmark across various options and automatically store the results in a CSV file. Here is an example.
|
||||
```
|
||||
python3 -m models.llama.benchmark_all \
|
||||
CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark_all \
|
||||
--hf-pt-eager \
|
||||
--hf-pt-compile \
|
||||
--hf-ort-dir-path ./llama2-7b-fp16/ \
|
||||
|
|
|
|||
|
|
@ -4,6 +4,8 @@ import argparse
|
|||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
from itertools import chain
|
||||
|
||||
import onnx
|
||||
|
|
@ -408,6 +410,31 @@ def optimize_export(config: AutoConfig, input_path: str, output_path: str, remov
|
|||
only_onnxruntime=False,
|
||||
)
|
||||
model_opt.save_model_to_file(output_path, use_external_data_format=True)
|
||||
|
||||
# Run symbolic shape inference on optimized model to avoid shape errors during runtime
|
||||
# Ex: Before attention fusion, RotaryEmbedding assumes a 4D input and produces a 4D output.
|
||||
# After attention fusion, RotaryEmbedding expects a 3D input and produces a 3D output.
|
||||
wheel_cmd = [sys.executable, "-m", "onnxruntime.tools.symbolic_shape_infer"]
|
||||
source_cmd = [sys.executable, "../symbolic_shape_infer.py"]
|
||||
symbolic_shape_infer_args = [
|
||||
"--input",
|
||||
output_path,
|
||||
"--output",
|
||||
output_path,
|
||||
"--auto_merge",
|
||||
"--save_as_external_data",
|
||||
"--all_tensors_to_one_file",
|
||||
"--external_data_location",
|
||||
os.path.basename(output_path) + ".data",
|
||||
]
|
||||
|
||||
file_path = os.path.dirname(__file__)
|
||||
if os.path.exists(os.path.join(file_path, "../../../tools/symbolic_shape_infer.py")):
|
||||
main_cmd = wheel_cmd
|
||||
else:
|
||||
main_cmd = source_cmd
|
||||
subprocess.run(main_cmd + symbolic_shape_infer_args) # noqa: PLW1510
|
||||
|
||||
logger.info(f"The ONNX model at {input_path} has been successfully optimized and saved at {output_path}!")
|
||||
if remove_model:
|
||||
remove_existing_model(input_path)
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ def setup_torch_model(args, location, use_auth_token, torch_dtype=torch.float32,
|
|||
if i == rank % (world_size):
|
||||
l_config = AutoConfig.from_pretrained(location, use_auth_token=use_auth_token, cache_dir=args.cache_dir)
|
||||
l_config.use_cache = True
|
||||
l_config._attn_implementation = "eager" # "eager" uses LlamaAttention for attention layer
|
||||
llama = AutoModelForCausalLM.from_pretrained(
|
||||
location,
|
||||
use_auth_token=use_auth_token,
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
git+https://github.com/huggingface/optimum.git
|
||||
optimum>=1.14.1
|
||||
transformers>=4.33.2
|
||||
torch>=2.2.0.dev20230920
|
||||
onnx>=1.14.0
|
||||
|
|
|
|||
Loading…
Reference in a new issue