Update LLaMA attention fusions (#19200)

### Description
This PR updates the LLaMA-2 attention fusions by adding the following.

- Loading the PyTorch model from Hugging Face with the `LlamaAttention`
class before exporting
- Updating the attention mask pattern matching to support another case

This PR also fixes [this
issue](https://github.com/microsoft/onnxruntime/issues/19040).

### Motivation and Context
Recent changes to Hugging Face's `transformers` library break the
existing pattern matching. Since the attention fusions aim to change the
graph from `LayerNorm Op --> Set of Attention Nodes --> LayerNorm Op` to
`LayerNorm Op --> Attention Op --> LayerNorm Op` per layer, ultimately
it does not matter what nodes comprise the `Set of Attention Nodes`
because they will all be removed and replaced by the `Attention Op` in
the end.

Therefore, it does not matter whether the `LlamaAttention` class or a
different attention class is used to load the PyTorch model before
exporting because the expected graphs after the attention fusions will
look identical no matter the attention class chosen. By loading the
PyTorch model with the `LlamaAttention` class instead of other attention
classes (e.g. `LlamaFlashAttention2` or `LlamaSdpaAttention`) and then
exporting it to ONNX, the existing pattern matching will continue to
work.
This commit is contained in:
kunal-vaishnavi 2024-01-19 11:09:24 -08:00 committed by GitHub
parent eaf047c820
commit a3ecb63267
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 46 additions and 25 deletions

View file

@ -539,6 +539,8 @@ class FusionRotaryAttention(FusionAttention):
# attn_mask_nodes_1, attn_mask_nodes_2 are for LLaMA-2 Microsoft's 3D attention mask
# attn_mask_nodes_3, attn_mask_nodes_4 are for LLaMA-2 Hugging Face's 2D attention mask
# attn_mask_nodes_5, attn_mask_nodes_6 are for LLaMA-2 Microsoft's model for the DML EP
# attn_mask_nodes_7 is for LLaMA-2 Hugging Face's changes to the attention mask
attn_mask, add_qk_str = "", ""
attn_mask_nodes_1 = self.model.match_parent_path(
add_qk,
@ -570,6 +572,11 @@ class FusionRotaryAttention(FusionAttention):
["Expand", "Where", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"],
[1, 0, 2, 1, 0, 0, 0],
)
attn_mask_nodes_7 = self.model.match_parent_path(
add_qk,
["Where", "Cast", "Where", "Cast", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"],
[1, 0, 0, 0, 0, 1, 0, 0, 0],
)
if attn_mask_nodes_1 is not None:
_, slice_mask_1, slice_mask_2 = attn_mask_nodes_1
attn_mask = slice_mask_1.output[0]
@ -588,6 +595,9 @@ class FusionRotaryAttention(FusionAttention):
elif attn_mask_nodes_6 is not None:
# The mask has already been reshaped to (B,N,S,T)
add_qk_str = attn_mask_nodes_6[0].output[0]
elif attn_mask_nodes_7 is not None:
# Reshape from (B,1,S,T) to (B,N,S,T)
add_qk_str = self.reshape_add_qk(attn_mask_nodes_7[0].output[0])
else:
logger.debug("fuse_rotary_attention: failed to match attention mask nodes")
return

View file

@ -42,23 +42,6 @@ $ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama
To make this option compatible with [Hugging Face's Optimum](https://github.com/huggingface/optimum), you will need to create `config.json` and `generation_config.json` for your model and store them in the same directory as your ONNX models. For example, you can find those JSON files for LLaMA-2 7B on Hugging Face [here](https://huggingface.co/meta-llama/Llama-2-7b-hf).
As indicated in `requirements.txt`, you will also need to install Optimum from source. Once installed, you will need to modify `ORTModelForCausalLM.forward` in `optimum/optimum/onnxruntime/modeling_decoder.py` as follows:
```
# Before
if self.use_cache:
if past_key_values is not None:
input_ids = input_ids[:, -1:]
# Flatten the past_key_values (no need to flatten for models using multi-query attn)
# After
if self.use_cache:
if past_key_values is not None:
input_ids = input_ids[:, -1:] if past_key_values[0][0].shape[2] != 0 else input_ids
# Flatten the past_key_values (no need to flatten for models using multi-query attn)
```
### Option 2: from [Microsoft's custom export](https://github.com/microsoft/Llama-2-Onnx)
Please follow the [README instructions](https://github.com/microsoft/Llama-2-Onnx#before-you-start) in the custom export of LLaMA-2.
@ -254,7 +237,7 @@ Here are some examples of how you can benchmark LLaMA-2.
1. PyTorch without `torch.compile`, FP32
```
python3 -m models.llama.benchmark \
CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark \
--benchmark-type hf-pt-eager \
--model-name meta-llama/Llama-2-7b-hf \
--precision fp32 \
@ -266,7 +249,7 @@ python3 -m models.llama.benchmark \
2. PyTorch with `torch.compile`, FP16
```
python3 -m models.llama.benchmark \
CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark \
--benchmark-type hf-pt-compile \
--model-name meta-llama/Llama-2-7b-hf \
--precision fp16 \
@ -278,7 +261,7 @@ python3 -m models.llama.benchmark \
3. Optimum + ONNX Runtime, FP32, export via Optimum or convert_to_onnx
```
python3 -m models.llama.benchmark \
CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark \
--benchmark-type hf-ort \
--hf-ort-dir-path ./Llama-2-7b-hf-onnx/ \
--model-name meta-llama/Llama-2-7b-hf \
@ -291,7 +274,7 @@ python3 -m models.llama.benchmark \
4. Optimum + ONNX Runtime, FP16, export via Optimum or convert_to_onnx
```
python3 -m models.llama.benchmark \
CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark \
--benchmark-type hf-ort \
--hf-ort-dir-path ./Llama-2-7b-hf-onnx/ \
--model-name meta-llama/Llama-2-7b-hf \
@ -304,7 +287,7 @@ python3 -m models.llama.benchmark \
5. ONNX Runtime, FP32, Microsoft custom export
```
python3 -m models.llama.benchmark \
CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark \
--benchmark-type ort-msft \
--ort-model-path ./llama-2-onnx/7B_float32/ONNX/LlamaV2_7B_float32.onnx \
--model-name meta-llama/Llama-2-7b-hf \
@ -316,7 +299,7 @@ python3 -m models.llama.benchmark \
6. ONNX Runtime, FP16, Microsoft custom export
```
python3 -m models.llama.benchmark \
CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark \
--benchmark-type ort-msft \
--ort-model-path ./llama-2-onnx/7B_float16/ONNX/LlamaV2_7B_float16.onnx \
--model-name meta-llama/Llama-2-7b-hf \
@ -367,7 +350,7 @@ You can profile a variant by adding the `--profile` flag and providing one batch
### Benchmark All
You can use `benchmark_all.py` to benchmark across various options and automatically store the results in a CSV file. Here is an example.
```
python3 -m models.llama.benchmark_all \
CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark_all \
--hf-pt-eager \
--hf-pt-compile \
--hf-ort-dir-path ./llama2-7b-fp16/ \

View file

@ -4,6 +4,8 @@ import argparse
import logging
import os
import shutil
import subprocess
import sys
from itertools import chain
import onnx
@ -408,6 +410,31 @@ def optimize_export(config: AutoConfig, input_path: str, output_path: str, remov
only_onnxruntime=False,
)
model_opt.save_model_to_file(output_path, use_external_data_format=True)
# Run symbolic shape inference on optimized model to avoid shape errors during runtime
# Ex: Before attention fusion, RotaryEmbedding assumes a 4D input and produces a 4D output.
# After attention fusion, RotaryEmbedding expects a 3D input and produces a 3D output.
wheel_cmd = [sys.executable, "-m", "onnxruntime.tools.symbolic_shape_infer"]
source_cmd = [sys.executable, "../symbolic_shape_infer.py"]
symbolic_shape_infer_args = [
"--input",
output_path,
"--output",
output_path,
"--auto_merge",
"--save_as_external_data",
"--all_tensors_to_one_file",
"--external_data_location",
os.path.basename(output_path) + ".data",
]
file_path = os.path.dirname(__file__)
if os.path.exists(os.path.join(file_path, "../../../tools/symbolic_shape_infer.py")):
main_cmd = wheel_cmd
else:
main_cmd = source_cmd
subprocess.run(main_cmd + symbolic_shape_infer_args) # noqa: PLW1510
logger.info(f"The ONNX model at {input_path} has been successfully optimized and saved at {output_path}!")
if remove_model:
remove_existing_model(input_path)

View file

@ -21,6 +21,7 @@ def setup_torch_model(args, location, use_auth_token, torch_dtype=torch.float32,
if i == rank % (world_size):
l_config = AutoConfig.from_pretrained(location, use_auth_token=use_auth_token, cache_dir=args.cache_dir)
l_config.use_cache = True
l_config._attn_implementation = "eager" # "eager" uses LlamaAttention for attention layer
llama = AutoModelForCausalLM.from_pretrained(
location,
use_auth_token=use_auth_token,

View file

@ -1,4 +1,4 @@
git+https://github.com/huggingface/optimum.git
optimum>=1.14.1
transformers>=4.33.2
torch>=2.2.0.dev20230920
onnx>=1.14.0