Add flash attention v2 and INT4 CUDA for LLaMA E2E benchmarking (#20149)

### Description
This PR adds flash attention v2 and INT4 CUDA support to the PyTorch
benchmarks used for LLaMA end-to-end (E2E) benchmarking.

### Motivation and Context
The [flash attention v2](https://github.com/Dao-AILab/flash-attention)
algorithm improves the performance of the model in PyTorch. Support for
INT4 CUDA in PyTorch is provided through the
[`bitsandbytes`](https://github.com/TimDettmers/bitsandbytes) package.
This commit is contained in:
kunal-vaishnavi 2024-03-29 23:09:37 -07:00 committed by GitHub
parent 00244ea143
commit a0ebd5fee5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 42 additions and 11 deletions

View file

@ -20,6 +20,14 @@
# 4) Install the latest ONNX Runtime version
#
# $ pip install onnxruntime-gpu
#
# 5) Install flash attention v2
#
# $ pip install flash-attn --no-build-isolation
#
# 6) Install bitsandbytes
#
# $ pip install bitsandbytes
from __future__ import annotations
@ -38,22 +46,44 @@ import pandas as pd
import torch
from benchmark_helper import setup_logger
from llama_inputs import add_io_bindings_as_tensors, get_initial_inputs_and_outputs
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import onnxruntime as ort
logger = logging.getLogger(__name__)
def get_model(args):
def get_model(args: argparse.Namespace):
if args.benchmark_type in {"pt-eager", "pt-compile"}:
model = AutoModelForCausalLM.from_pretrained(
args.hf_dir_path if args.hf_dir_path != "" else args.model_name,
cache_dir=args.cache_dir,
torch_dtype=args.torch_dtype,
use_auth_token=args.auth,
use_cache=True,
).to(args.target_device)
model = None
if args.onnx_precision == "int4" and args.device == "cuda":
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.float16,
)
model = AutoModelForCausalLM.from_pretrained(
args.hf_dir_path if args.hf_dir_path != "" else args.model_name,
cache_dir=args.cache_dir,
torch_dtype=args.torch_dtype,
use_auth_token=args.auth,
use_cache=True,
attn_implementation="flash_attention_2",
quantization_config=bnb_config,
max_memory={args.device_id: "80GB"},
)
else:
model = AutoModelForCausalLM.from_pretrained(
args.hf_dir_path if args.hf_dir_path != "" else args.model_name,
cache_dir=args.cache_dir,
torch_dtype=args.torch_dtype,
use_auth_token=args.auth,
use_cache=True,
attn_implementation=("flash_attention_2" if args.device == "cuda" else "sdpa"),
).to(args.target_device)
model.eval()
if args.benchmark_type == "pt-compile":
@ -223,7 +253,7 @@ def get_args():
parser.add_argument(
"-s",
"--prompt-lengths",
default="32 64 128 256 512",
default="16 64 256 1024",
)
parser.add_argument(
@ -277,6 +307,7 @@ def get_args():
args.prompt_lengths = args.prompt_lengths.split(" ")
# Use FP32 precision for FP32, INT8, and INT4 CPU models; use FP16 precision for FP16 and INT4 GPU models
setattr(args, "onnx_precision", args.precision) # noqa: B010
args.precision = (
"fp32" if args.precision in {"int8", "fp32"} or (args.precision == "int4" and args.device == "cpu") else "fp16"
)

View file

@ -170,7 +170,7 @@ def get_args(argv: list[str]):
parser.add_argument(
"-m",
"--model_name",
required=True,
required=False,
help="Model name in Hugging Face",
)