mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-25 22:26:24 +00:00
Add flash attention v2 and INT4 CUDA for LLaMA E2E benchmarking (#20149)
### Description This PR adds flash attention v2 and support for INT4 CUDA benchmarking in PyTorch. ### Motivation and Context The [flash attention v2](https://github.com/Dao-AILab/flash-attention) algorithm helps improve model performance in PyTorch. Support for INT4 CUDA in PyTorch is done through the [`bitsandbytes`](https://github.com/TimDettmers/bitsandbytes) package.
This commit is contained in:
parent
00244ea143
commit
a0ebd5fee5
2 changed files with 42 additions and 11 deletions
|
|
@ -20,6 +20,14 @@
|
|||
# 4) Install the latest ONNX Runtime version
|
||||
#
|
||||
# $ pip install onnxruntime-gpu
|
||||
#
|
||||
# 5) Install flash attention v2
|
||||
#
|
||||
# $ pip install flash-attn --no-build-isolation
|
||||
#
|
||||
# 6) Install bitsandbytes
|
||||
#
|
||||
# $ pip install bitsandbytes
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
|
@ -38,22 +46,44 @@ import pandas as pd
|
|||
import torch
|
||||
from benchmark_helper import setup_logger
|
||||
from llama_inputs import add_io_bindings_as_tensors, get_initial_inputs_and_outputs
|
||||
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
|
||||
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
||||
|
||||
import onnxruntime as ort
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_model(args):
|
||||
def get_model(args: argparse.Namespace):
|
||||
if args.benchmark_type in {"pt-eager", "pt-compile"}:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
args.hf_dir_path if args.hf_dir_path != "" else args.model_name,
|
||||
cache_dir=args.cache_dir,
|
||||
torch_dtype=args.torch_dtype,
|
||||
use_auth_token=args.auth,
|
||||
use_cache=True,
|
||||
).to(args.target_device)
|
||||
model = None
|
||||
if args.onnx_precision == "int4" and args.device == "cuda":
|
||||
bnb_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_use_double_quant=True,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
bnb_4bit_compute_dtype=torch.float16,
|
||||
)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
args.hf_dir_path if args.hf_dir_path != "" else args.model_name,
|
||||
cache_dir=args.cache_dir,
|
||||
torch_dtype=args.torch_dtype,
|
||||
use_auth_token=args.auth,
|
||||
use_cache=True,
|
||||
attn_implementation="flash_attention_2",
|
||||
quantization_config=bnb_config,
|
||||
max_memory={args.device_id: "80GB"},
|
||||
)
|
||||
else:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
args.hf_dir_path if args.hf_dir_path != "" else args.model_name,
|
||||
cache_dir=args.cache_dir,
|
||||
torch_dtype=args.torch_dtype,
|
||||
use_auth_token=args.auth,
|
||||
use_cache=True,
|
||||
attn_implementation=("flash_attention_2" if args.device == "cuda" else "sdpa"),
|
||||
).to(args.target_device)
|
||||
|
||||
model.eval()
|
||||
|
||||
if args.benchmark_type == "pt-compile":
|
||||
|
|
@ -223,7 +253,7 @@ def get_args():
|
|||
parser.add_argument(
|
||||
"-s",
|
||||
"--prompt-lengths",
|
||||
default="32 64 128 256 512",
|
||||
default="16 64 256 1024",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
|
|
@ -277,6 +307,7 @@ def get_args():
|
|||
args.prompt_lengths = args.prompt_lengths.split(" ")
|
||||
|
||||
# Use FP32 precision for FP32, INT8, INT4 CPU models, use FP16 precision for FP16 and INT4 GPU models
|
||||
setattr(args, "onnx_precision", args.precision) # noqa: B010
|
||||
args.precision = (
|
||||
"fp32" if args.precision in {"int8", "fp32"} or (args.precision == "int4" and args.device == "cpu") else "fp16"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -170,7 +170,7 @@ def get_args(argv: list[str]):
|
|||
parser.add_argument(
|
||||
"-m",
|
||||
"--model_name",
|
||||
required=True,
|
||||
required=False,
|
||||
help="Model name in Hugging Face",
|
||||
)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue