Add LLaMA scripts (#17020)

### Description
This PR adds the following scripts for LLaMA:
- LLaMA conversion (support for TorchScript and Dynamo exporters)
- LLaMA parity
- LLaMA benchmark
- LLaMA quantization
- LLaMA integration with [Hugging Face
Optimum](https://github.com/huggingface/optimum)



### Motivation and Context
This PR adds scripts for using LLaMA. There is a [follow-up
PR](https://github.com/microsoft/onnxruntime/pull/17043) for adding
scripts for Whisper.
kunal-vaishnavi 2023-08-22 18:05:11 -07:00 committed by GitHub
parent d3d3dde844
commit edac3ef150
15 changed files with 2035 additions and 1 deletion


@@ -466,6 +466,9 @@ file(GLOB onnxruntime_python_transformers_models_bert_src CONFIGURE_DEPENDS
file(GLOB onnxruntime_python_transformers_models_gpt2_src CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/python/tools/transformers/models/gpt2/*.py"
)
file(GLOB onnxruntime_python_transformers_models_llama_src CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/python/tools/transformers/models/llama/*.py"
)
file(GLOB onnxruntime_python_transformers_models_longformer_src CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/python/tools/transformers/models/longformer/*.py"
)
@@ -537,6 +540,7 @@ add_custom_command(
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/bart
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/bert
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/gpt2
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/llama
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/longformer
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/stable_diffusion
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/t5
@@ -628,6 +632,9 @@ add_custom_command(
COMMAND ${CMAKE_COMMAND} -E copy
${onnxruntime_python_transformers_models_gpt2_src}
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/gpt2/
COMMAND ${CMAKE_COMMAND} -E copy
${onnxruntime_python_transformers_models_llama_src}
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/llama/
COMMAND ${CMAKE_COMMAND} -E copy
${onnxruntime_python_transformers_models_longformer_src}
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/longformer/


@@ -170,7 +170,7 @@ def prepare_environment(cache_dir, output_dir, use_gpu, provider=None):
logger.info(f"PyTorch Version:{torch.__version__}")
logger.info(f"Transformers Version:{transformers.__version__}")
logger.info(f"Onnxruntime Version:{onnxruntime.__version__}")
logger.info(f"OnnxRuntime Version:{onnxruntime.__version__}")
# Support three major versions of PyTorch and OnnxRuntime, and up to 9 months of transformers.
assert version.parse(torch.__version__) >= version.parse("1.10.0")


@@ -0,0 +1,187 @@
# LLaMA-2
## Exporting LLaMA-2
There are several ways to export LLaMA-2 models (using LLaMA-2 7B as an example).
### Option 1: from convert_to_onnx
```
# From source:
$ git clone https://github.com/microsoft/onnxruntime
$ cd onnxruntime/onnxruntime/python/tools/transformers/
$ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b
# From wheel:
$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b
```
To make this option compatible with [Hugging Face's Optimum](https://github.com/huggingface/optimum), you will need to create `config.json` and `generation_config.json` for your model and store them in the same directory as your ONNX models. For example, you can find those JSON files for LLaMA-2 7B on Hugging Face [here](https://huggingface.co/meta-llama/Llama-2-7b-hf).
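If you prefer not to copy those files by hand, one way is to pull them with `transformers` and save them next to the ONNX models. The snippet below is a minimal sketch (the output directory name is illustrative, and it assumes you have access to the gated `meta-llama` repository after `huggingface-cli login`):
```
from transformers import AutoConfig, GenerationConfig

name = "meta-llama/Llama-2-7b-hf"
onnx_dir = "llama2-7b"  # illustrative: the directory convert_to_onnx wrote the ONNX models to

# Saves config.json and generation_config.json into onnx_dir
AutoConfig.from_pretrained(name, use_auth_token=True).save_pretrained(onnx_dir)
GenerationConfig.from_pretrained(name, use_auth_token=True).save_pretrained(onnx_dir)
```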
### Option 2: from [Microsoft's custom export](https://github.com/microsoft/Llama-2-Onnx)
Please follow the [README instructions](https://github.com/microsoft/Llama-2-Onnx#before-you-start) in the custom export of LLaMA-2.
### Option 3: from [Hugging Face Optimum](https://github.com/huggingface/optimum)
First, log into the Hugging Face CLI in your terminal:
```
$ huggingface-cli login
```
Once authenticated, run the following Python code to export:
```
from optimum.onnxruntime import ORTModelForCausalLM
name = "meta-llama/Llama-2-7b-hf"
model = ORTModelForCausalLM.from_pretrained(
name,
export=True,
use_auth_token=True,
)
model.save_pretrained(name.split("/")[-1] + "-onnx")
```
## Examples of Exporting LLaMA-2
Here are some additional examples for exporting LLaMA-2.
Export Saved Model on Disk
```
# From source:
$ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --input ./Llama-2-7b-hf --output ./llama2-7b
# From wheel:
$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --input ./Llama-2-7b-hf --output ./llama2-7b
```
Export for FP16
```
# From source:
$ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp16 --precision fp16
# From wheel:
$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp16 --precision fp16
```
Export for INT8
```
# From source:
$ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int8 --precision int8 --quantization_method smooth_quant
# From wheel:
$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int8 --precision int8 --quantization_method smooth_quant
```
Note: [Intel's Neural Compressor](https://github.com/intel/neural-compressor) takes time to run the SmoothQuant quantization algorithm on LLMs. On an [Azure Standard_NC24s_v3 VM](https://learn.microsoft.com/en-us/azure/virtual-machines/ncv3-series), it takes roughly 30-45 minutes for each of the exported ONNX models.
## Benchmark LLaMA-2
Here are some examples of how you can benchmark LLaMA-2.
Note: In the examples below, `PyTorch` refers to running in PyTorch without `torch.compile` and `PyTorch 2.0` refers to running in PyTorch with `torch.compile`.
### Variants
1. PyTorch (without `torch.compile`), FP32
```
python3 -m models.llama.benchmark \
--benchmark-type hf-pt \
--model-name meta-llama/Llama-2-7b-hf \
--precision fp32 \
--batch-sizes "1 2" \
--sequence-lengths "8 16" \
--device cpu \
--auth
```
2. PyTorch 2.0 (with `torch.compile`), FP16
```
python3 -m models.llama.benchmark \
--benchmark-type hf-pt2 \
--model-name meta-llama/Llama-2-7b-hf \
--precision fp16 \
--batch-sizes "1 2" \
--sequence-lengths "8 16" \
--device cuda \
--auth
```
3. Optimum + ONNX Runtime, FP32, export via Optimum or convert_to_onnx
```
python3 -m models.llama.benchmark \
--benchmark-type hf-ort \
--hf-ort-model-path ./Llama-2-7b-hf-onnx/ \
--model-name meta-llama/Llama-2-7b-hf \
--precision fp32 \
--batch-sizes "1 2" \
--sequence-lengths "8 16" \
--device cpu \
--auth
```
4. Optimum + ONNX Runtime, FP16, export via convert_to_onnx
```
python3 -m models.llama.benchmark \
--benchmark-type hf-ort \
--hf-ort-model-path ./llama2-7b-fp16/ \
--model-name meta-llama/Llama-2-7b-hf \
--precision fp16 \
--batch-sizes "1 2" \
--sequence-lengths "8 16" \
--device cuda \
--auth
```
5. Optimum + ONNX Runtime, INT8, export via convert_to_onnx
```
python3 -m models.llama.benchmark \
--benchmark-type hf-ort \
--hf-ort-model-path ./llama2-7b-int8/ \
--model-name meta-llama/Llama-2-7b-hf \
--precision int8 \
--batch-sizes "1 2" \
--sequence-lengths "8 16" \
--device cpu \
--auth
```
6. ONNX Runtime, FP32, Microsoft custom export
```
python3 -m models.llama.benchmark \
--benchmark-type ort \
--ort-model-path llama-2-onnx/7B_float32/ONNX/LlamaV2_7B_float32.onnx \
--model-name meta-llama/Llama-2-7b-hf \
--precision fp32 \
--batch-sizes "1 2" \
--sequence-lengths "8 16" \
--device cpu
```
7. ONNX Runtime, FP16, Microsoft custom export
```
python3 -m models.llama.benchmark \
--benchmark-type ort \
--ort-model-path ./llama-2-onnx/7B_float16/ONNX/LlamaV2_7B_float16.onnx \
--model-name meta-llama/Llama-2-7b-hf \
--precision fp16 \
--batch-sizes "1 2" \
--sequence-lengths "8 16" \
--device cuda
```
You can profile a variant by adding the `--profile` flag and providing one batch size and sequence length combination.
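For example, a profiling run for the FP16 ONNX Runtime variant might look like the following (paths are illustrative):
```
python3 -m models.llama.benchmark \
--benchmark-type ort \
--ort-model-path ./llama-2-onnx/7B_float16/ONNX/LlamaV2_7B_float16.onnx \
--model-name meta-llama/Llama-2-7b-hf \
--precision fp16 \
--batch-sizes "1" \
--sequence-lengths "32" \
--device cuda \
--profile
```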
### Benchmark All
You can use `benchmark_all.py` to benchmark across various platforms and automatically store the results in a CSV file. Here is an example.
```
python3 -m models.llama.benchmark_all \
--hf-ort-model-path ./llama2-7b-fp16/ \
--ort-model-path ./llama-2-onnx/7B_float16/ONNX/LlamaV2_7B_float16.onnx \
--model-name meta-llama/Llama-2-7b-hf \
--precision fp16 \
--batch-sizes "1 2" \
--sequence-lengths "8 16" \
--device cuda
```
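The resulting CSV has one row per engine, batch size, sequence length, and inference step, with columns such as `Engine`, `Precision`, `Device`, `Batch Size`, `Sequence Length`, `Step`, `Latency (ms)`, `Throughput (qps)`, and `Memory (GB)`. A short sketch for comparing engines afterwards (the file name is illustrative):
```
import pandas as pd

df = pd.read_csv("Llama-2-7b-hf_fp16_2023-08-22_18:05:11.csv")  # illustrative file name
summary = df.pivot_table(
    index=["Batch Size", "Sequence Length", "Step"],
    columns="Engine",
    values="Latency (ms)",
)
print(summary)
```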


@@ -0,0 +1,12 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import os
import sys
sys.path.append(os.path.dirname(__file__))
transformers_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", ".."))
if transformers_dir not in sys.path:
sys.path.append(transformers_dir)


@@ -0,0 +1,524 @@
import argparse
import datetime
import gc
import itertools
import logging
import os
import sys
import time
import numpy as np
import psutil
import torch
from benchmark_helper import setup_logger
from llama_inputs import get_msft_sample_inputs, get_sample_inputs, get_sample_with_past_kv_inputs
from optimum.onnxruntime import ORTModelForCausalLM
from torch.profiler import ProfilerActivity, profile, record_function
from tqdm import trange
from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
import onnxruntime as ort
from onnxruntime.transformers.benchmark_helper import measure_memory
logger = logging.getLogger(__name__)
def get_inputs(args: argparse.Namespace):
if args.benchmark_type in {"hf-pt", "hf-pt2", "hf-ort"}:
init_inputs = get_sample_inputs(
args.config,
args.target_device,
args.batch_size,
args.sequence_length,
return_dict=True,
)
iter_inputs = get_sample_with_past_kv_inputs(
args.config,
args.target_device,
args.batch_size,
args.sequence_length,
use_fp16=args.use_fp16,
return_dict=True,
)
elif args.benchmark_type == "ort":
# Microsoft export from https://github.com/microsoft/Llama-2-Onnx
init_inputs = get_msft_sample_inputs(
args.config,
args.batch_size,
past_seq_len=0,
seq_len=args.sequence_length,
use_fp16=args.use_fp16,
)
iter_inputs = get_msft_sample_inputs(
args.config,
args.batch_size,
past_seq_len=args.sequence_length,
seq_len=1,
use_fp16=args.use_fp16,
)
else:
raise Exception("Unable to auto-detect inputs for provided model")
return init_inputs, iter_inputs
def get_model(args: argparse.Namespace):
model, sess_options = None, None
start_time, end_time = None, None
# There are multiple sources that the model could come from:
# 1) Benchmark LLaMA from unofficial source on Hugging Face
# 2) Benchmark LLaMA from official source on Hugging Face, which requires an authentication token
# 3) Benchmark LLaMA from local download of model
if args.benchmark_type in {"hf-pt", "hf-pt2"}:
source = args.hf_pt_model_path if args.hf_pt_model_path else args.model_name
start_time = time.time()
model = LlamaForCausalLM.from_pretrained(
source,
torch_dtype=torch.float16 if args.use_fp16 else torch.float32,
use_auth_token=args.auth,
use_cache=True,
).to(args.target_device)
end_time = time.time()
if args.benchmark_type == "hf-pt2":
model = torch.compile(model)
elif args.benchmark_type in {"hf-ort", "ort"}:
sess_options = ort.SessionOptions()
sess_options.enable_profiling = args.profile
if args.verbose:
sess_options.log_verbosity_level = 1
sess_options.log_severity_level = 1
else:
raise Exception(f"Cannot recognize {args.benchmark_type}")
if args.benchmark_type == "hf-ort":
# Optimum export or convert_to_onnx.py export
provider = args.execution_provider[0] if type(args.execution_provider) is tuple else args.execution_provider
provider_options = args.execution_provider[1] if type(args.execution_provider) is tuple else None
decoder_file_name = None
decoder_with_past_file_name = None
for filename in os.listdir(args.hf_ort_model_path):
if ".onnx" not in filename or ".onnx_data" in filename or ".onnx.data" in filename:
continue
if "decoder_model.onnx" in filename or f"decoder_model_{args.precision}.onnx" in filename:
decoder_file_name = filename
if (
"decoder_with_past_model.onnx" in filename
or f"decoder_with_past_model_{args.precision}.onnx" in filename
):
decoder_with_past_file_name = filename
start_time = time.time()
model = ORTModelForCausalLM.from_pretrained(
args.hf_ort_model_path,
decoder_file_name=decoder_file_name,
decoder_with_past_file_name=decoder_with_past_file_name,
use_auth_token=args.auth,
use_io_binding=(args.device != "cpu"),
provider=provider,
provider_options=provider_options,
session_options=sess_options,
)
end_time = time.time()
if args.benchmark_type == "ort":
# Microsoft export from https://github.com/microsoft/Llama-2-Onnx
logger.info(f"Loading model from {args.ort_model_path}")
start_time = time.time()
model = ort.InferenceSession(
args.ort_model_path,
sess_options,
providers=[args.execution_provider],
)
end_time = time.time()
logger.info(f"Loaded model in {end_time - start_time} s")
return model
def time_fn(args, fn, inputs):
# Warm up
warmup_range = (
range(args.warmup_runs)
if args.benchmark_type == "ort"
else trange(args.warmup_runs, file=sys.stdout, desc="Warm up")
)
if args.verbose:
outputs = fn(inputs)
logger.info(outputs)
for _ in warmup_range:
fn(inputs)
# Benchmark
if args.device != "cpu":
torch.cuda.synchronize()
start_time = time.time()
bench_range = (
range(args.num_runs)
if args.benchmark_type == "ort"
else trange(args.num_runs, file=sys.stdout, desc="Benchmark")
)
for _ in bench_range:
fn(inputs)
if args.device != "cpu":
torch.cuda.synchronize()
end_time = time.time()
# Newline print after trange in order to print metrics on new lines without progress bar on same line
if args.benchmark_type != "ort":
logger.info("")
latency = (end_time - start_time) / args.num_runs
throughput = args.batch_size / latency
logger.info(f"Batch Size: {args.batch_size}")
logger.info(f"Sequence Length: {args.sequence_length}")
logger.info(f"Latency: {latency} s")
logger.info(f"Throughput: {throughput} qps")
return
def profile_fn(args, fn, inputs, inputs_type):
# Filename prefix format:
# "b<batch-size>_s<sequence-length>_<benchmark-type>-<precision>-<device>_<inference-step>_<inputs-type>_<current-time>"
prefix = f"b{args.batch_size}_s{args.sequence_length}_{args.benchmark_type.lower()}-{args.precision}-{args.device}_{fn.__name__.replace('_', '-')}_{inputs_type}_{datetime.datetime.now():%Y-%m-%d_%H:%M:%S}"
filename = None
if args.benchmark_type in {"hf-pt", "hf-pt2"}:
# Profile PyTorch kernels
with profile( # noqa: SIM117
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True, profile_memory=True
) as prof:
with record_function("model_inference"):
fn(inputs)
prof_data = prof.key_averages(group_by_stack_n=5).table(sort_by=args.pt_filter_by, row_limit=args.pt_num_rows)
filename = os.path.join(args.log_folder, f"{prefix}.log")
with open(filename, "w") as f:
f.write(prof_data)
else:
# Profile ORT kernels
fn(inputs)
# Set new log name for ORT profile log generated
filename = f"{prefix}.json"
return filename
def measure_fn(args, fn, inputs):
# Measure CPU usage
pid = os.getpid()
process = psutil.Process(pid)
process.cpu_percent(interval=0.1)
fn(inputs)
logger.info(f"CPU usage: {process.cpu_percent(interval=None)}%")
# Measure memory usage
gc.collect()
torch.cuda.empty_cache()
measure_memory(is_gpu=(args.device != "cpu"), func=lambda: fn(inputs))
# Flush output so memory usage is printed
sys.stdout.flush()
def run_hf_inference(args, init_inputs, iter_inputs, model):
# Inference steps to measure
def get_logits(inputs):
# Inference pass without decoding
outputs = model(**inputs)
return outputs
# Examples of other inference steps that can be measured:
# To use, uncomment the function and assign it to `generate_fn`
# def get_pred_ids(inputs):
# # Inference pass with predicted token ids generation
# predicted_ids = model.generate(**inputs)
# return predicted_ids
# def gen_and_dec(inputs):
# # Inference pass with generation and decoding
# predicted_ids = get_pred_ids(inputs)
# transcription = []
# for bs in range(args.batch_size):
# for rs in range(args.num_return_sequences):
# transcription.append(
# args.tokenizer.batch_decode(
# predicted_ids[bs * args.num_return_sequences + rs], skip_special_tokens=True
# )[0]
# )
# return transcription
generate_fn = get_logits
if args.benchmark_type == "hf-pt2":
# Run forward pass once with each set of inputs to process through Dynamo
generate_fn(init_inputs)
generate_fn(iter_inputs)
if args.profile:
new_logname = profile_fn(args, generate_fn, init_inputs, "prompt")
if args.benchmark_type == "hf-ort":
# Turn profiling off to stop appending to log
old_logname = model.decoder.session.end_profiling()
logger.warning(f"Renaming {old_logname} to {new_logname}")
os.rename(old_logname, os.path.join(args.log_folder, new_logname))
new_logname = profile_fn(args, generate_fn, iter_inputs, "per-token")
if args.benchmark_type == "hf-ort":
# Turn profiling off to stop appending to log
old_logname = model.decoder_with_past.session.end_profiling()
logger.warning(f"Renaming {old_logname} to {new_logname}")
os.rename(old_logname, os.path.join(args.log_folder, new_logname))
return
# PyTorch evaluations
logger.info("\nEvaluating `model(inputs)` step to get past_key_values")
time_fn(args, generate_fn, init_inputs)
measure_fn(args, generate_fn, init_inputs)
logger.info("\nEvaluating `model(inputs)` step with past_key_values")
time_fn(args, generate_fn, iter_inputs)
measure_fn(args, generate_fn, iter_inputs)
def run_ort_inference(args, init_inputs, iter_inputs, model):
def prepare_ort_inputs(inputs):
# Check that all model inputs will be provided
model_inputs = set(map(lambda model_input: model_input.name, model.get_inputs()))
user_inputs = set(inputs.keys())
missing_inputs = model_inputs - user_inputs
if len(missing_inputs):
logger.error(f"The following model inputs are missing: {missing_inputs}")
raise Exception("There are missing inputs to the model. Please add them and try again.")
# Remove unnecessary inputs from model inputs
unnecessary_inputs = user_inputs - model_inputs
if len(unnecessary_inputs):
for unnecessary_input in unnecessary_inputs:
logger.info(f"Removing unnecessary input '{unnecessary_input}' from user provided inputs")
del inputs[unnecessary_input]
# Add IO bindings for non-CPU execution providers
if args.device != "cpu":
io_binding = model.io_binding()
for k, v in inputs.items():
io_binding.bind_cpu_input(k, v)
for output in model.get_outputs():
io_binding.bind_output(output.name)
return io_binding
return inputs
def with_io_binding(io_binding):
# Inference pass with IO binding
model.run_with_iobinding(io_binding)
def without_io_binding(inputs):
# Inference pass without IO binding
outputs = model.run(None, inputs)
return outputs
generate_fn = with_io_binding if args.device != "cpu" else without_io_binding
if args.profile:
ort_init_inputs = prepare_ort_inputs(init_inputs)
new_logname = profile_fn(args, generate_fn, ort_init_inputs, "prompt")
# Turn profiling off to stop appending to log file
old_logname = model.end_profiling()
logger.warning(f"Renaming {old_logname} to {new_logname}")
os.rename(old_logname, os.path.join(args.log_folder, new_logname))
# Re-initialize model for new log file instead of appending to old log file
model = get_model(args)
ort_iter_inputs = prepare_ort_inputs(iter_inputs)
new_logname = profile_fn(args, generate_fn, ort_iter_inputs, "per-token")
# Turn profiling off to stop appending to log
old_logname = model.end_profiling()
logger.warning(f"Renaming {old_logname} to {new_logname}")
os.rename(old_logname, os.path.join(args.log_folder, new_logname))
return
# ORT evaluations
logger.info("\nEvaluating `model(inputs)` step to get past_key_values")
ort_init_inputs = prepare_ort_inputs(init_inputs)
time_fn(args, generate_fn, ort_init_inputs)
measure_fn(args, generate_fn, ort_init_inputs)
logger.info("\nEvaluating `model(inputs)` step with past_key_values")
ort_iter_inputs = prepare_ort_inputs(iter_inputs)
time_fn(args, generate_fn, ort_iter_inputs)
measure_fn(args, generate_fn, ort_iter_inputs)
def run_inference(args, init_inputs, iter_inputs, model):
if args.benchmark_type in {"hf-pt", "hf-pt2", "hf-ort"}:
run_hf_inference(args, init_inputs, iter_inputs, model)
elif args.benchmark_type == "ort":
run_ort_inference(args, init_inputs, iter_inputs, model)
else:
raise Exception(f"Cannot recognize {args.benchmark_type}")
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"-bt", "--benchmark-type", type=str, required=True, choices=["hf-pt", "hf-pt2", "hf-ort", "ort"]
)
parser.add_argument(
"-m",
"--model-name",
type=str,
required=True,
help="Hugging Face name of model (e.g. 'meta-llama/Llama-2-7b-hf')",
)
parser.add_argument(
"-a", "--auth", default=False, action="store_true", help="Use Hugging Face authentication token to access model"
)
# Args for choosing the model
parser.add_argument(
"-p",
"--precision",
required=True,
type=str,
default="fp32",
choices=["int8", "fp16", "fp32"],
help="Precision for model. For ONNX models, the model's precision should be set before running this script.",
)
parser.add_argument(
"--hf-pt-model-path",
type=str,
default="",
help="Path to directory containing all PyTorch files (e.g. tokenizer, PyTorch model)",
)
parser.add_argument(
"--hf-ort-model-path",
type=str,
default="",
help="Path to directory containing all ONNX files (e.g. tokenizer, encoder, decoder, decoder_with_past)",
)
parser.add_argument(
"--ort-model-path",
type=str,
default="",
help="Path to ONNX model",
)
# Args for running and evaluating the model
parser.add_argument(
"-b",
"--batch-sizes",
default="1 2",
)
parser.add_argument(
"-s",
"--sequence-lengths",
default="8 16 32 64 128 256 512",
)
parser.add_argument(
"-d",
"--device",
type=str,
default="cuda" if torch.cuda.is_available() else "cpu",
choices=["cpu", "cuda", "rocm"],
)
parser.add_argument("-id", "--device-id", type=int, default=0)
parser.add_argument("-w", "--warmup-runs", type=int, default=5)
parser.add_argument("-n", "--num-runs", type=int, default=10)
parser.add_argument("--seed", type=int, default=2)
# Args for decoding logic
parser.add_argument("--max-length", type=int, default=32)
parser.add_argument("--num-return-sequences", type=int, default=1)
# Args for accessing detailed info
parser.add_argument("--profile", default=False, action="store_true")
parser.add_argument(
"--pt-filter-by", type=str, default="self_cpu_time_total", help="What to filter PyTorch profiler by"
)
parser.add_argument("--pt-num-rows", type=int, default=1000, help="Number of rows for PyTorch profiler to display")
parser.add_argument("--verbose", default=False, action="store_true")
parser.add_argument("--log-folder", type=str, default=os.path.join("."), help="Folder to cache log files")
args = parser.parse_args()
# Set seed properties
np.random.seed(args.seed)
torch.manual_seed(args.seed)
# Set runtime properties
if "ort" in args.benchmark_type:
setattr(args, "execution_provider", f"{args.device.upper()}ExecutionProvider") # noqa: B010
if args.execution_provider == "CUDAExecutionProvider":
args.execution_provider = (args.execution_provider, {"device_id": args.device_id})
elif args.execution_provider == "ROCMExecutionProvider":
args.execution_provider = (args.execution_provider, {"device_id": args.device_id})
args.device = "cuda"
# Check that model paths have been specified for any benchmarking with ORT
if args.benchmark_type == "hf-ort":
assert args.hf_ort_model_path, "Please specify a path to `--hf-ort-model-path`"
if args.benchmark_type == "ort":
assert args.ort_model_path, "Please specify a path to `--ort-model-path`"
args.batch_sizes = args.batch_sizes.split(" ")
args.sequence_lengths = args.sequence_lengths.split(" ")
# Check that only one (batch_size, sequence_length) combination is set for profiling
if args.profile:
assert (
len(args.batch_sizes) == 1 and len(args.sequence_lengths) == 1
), "Please provide only one (batch_size, sequence_length) combination for profiling"
return args
def main():
args = get_args()
setup_logger(args.verbose)
logger.info(args.__dict__)
torch.backends.cudnn.benchmark = True
tokenizer = LlamaTokenizer.from_pretrained(args.model_name)
config = LlamaConfig.from_pretrained(args.model_name)
target_device = f"cuda:{args.device_id}" if args.device != "cpu" else args.device
use_fp16 = args.precision == "fp16"
setattr(args, "tokenizer", tokenizer) # noqa: B010
setattr(args, "config", config) # noqa: B010
setattr(args, "target_device", target_device) # noqa: B010
setattr(args, "use_fp16", use_fp16) # noqa: B010
# Measure prompt cost (init_inputs) and generated token cost (iter_inputs)
model = get_model(args)
for batch_size, sequence_length in itertools.product(args.batch_sizes, args.sequence_lengths):
logger.info(f"\nBatch size = {batch_size} and sequence length = {sequence_length}...")
setattr(args, "batch_size", int(batch_size)) # noqa: B010
setattr(args, "sequence_length", int(sequence_length)) # noqa: B010
init_inputs, iter_inputs = get_inputs(args)
run_inference(args, init_inputs, iter_inputs, model)
if __name__ == "__main__":
main()


@@ -0,0 +1,359 @@
import argparse
import datetime
import json
import logging
import os
import subprocess
import torch
from benchmark_helper import setup_logger
logger = logging.getLogger(__name__)
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"-b",
"--batch-sizes",
type=str,
default="1 2",
)
parser.add_argument(
"-s",
"--sequence-lengths",
type=str,
default="8 16 32 64 128 256 512",
)
parser.add_argument(
"-w",
"--warmup-runs",
type=int,
default=5,
)
parser.add_argument(
"-n",
"--num-runs",
type=int,
default=1000,
)
parser.add_argument(
"--hf-ort-model-path",
type=str,
help="Path to folder containing ONNX models for Optimum + ORT benchmarking",
)
parser.add_argument(
"--ort-model-path",
type=str,
help="Path to ONNX model for ORT benchmarking",
)
parser.add_argument(
"--model-name",
type=str,
required=True,
help="Model name in Hugging Face",
)
parser.add_argument(
"--precision",
type=str,
required=True,
choices=["int8", "fp16", "fp32"],
help="Precision to run model",
)
parser.add_argument(
"--device",
type=str,
required=True,
choices=["cpu", "cuda", "rocm"],
help="Device to benchmark models",
)
parser.add_argument(
"--device-id",
type=int,
default=0,
help="GPU device ID",
)
parser.add_argument(
"--verbose",
default=False,
action="store_true",
help="Print detailed logs",
)
parser.add_argument(
"--timeout",
type=int,
default=10,
help="Number of mins to attempt the benchmark before moving on",
)
args = parser.parse_args()
setattr(args, "model_size", args.model_name.split("/")[-1].replace(".", "-")) # noqa: B010
log_folder_name = f"./{args.model_size}_{args.precision}"
setattr(args, "log_folder", log_folder_name) # noqa: B010
os.makedirs(args.log_folder, exist_ok=True)
# Convert timeout value to secs
args.timeout *= 60
return args
def process_log_file(device_id, log_file, base_results):
entries = []
batch_size, sequence_length, step = None, None, None
latency_s, latency_ms, throughput, memory = None, None, None, None
batch_pattern = "Batch Size: "
sequence_pattern = "Sequence Length: "
prompt_step_pattern = "to get past_key_values"
per_token_step_pattern = "with past_key_values"
latency_pattern = "Latency: "
throughput_pattern = "Throughput: "
memory_pattern = "peak="
with open(log_file) as f:
for input_line in f:
line = input_line.replace("\n", "")
if batch_pattern in line:
batch_size = int(line[len(batch_pattern) :])
elif sequence_pattern in line:
sequence_length = int(line[len(sequence_pattern) :])
elif prompt_step_pattern in line:
step = "prompt"
elif per_token_step_pattern in line:
step = "per-token"
elif latency_pattern in line:
latency_s = float(line[len(latency_pattern) : line.rfind(" ")])
if step == "prompt":
latency_s /= sequence_length
latency_ms = latency_s * 1000
elif throughput_pattern in line:
throughput = float(line[len(throughput_pattern) : line.rfind(" ")])
elif memory_pattern in line:
if "CPU" in line:
# Example format for log entry:
# CPU memory usage: before=1000.0 MB, peak=2000.0 MB
memory = float(line[line.rfind("=") + 1 : line.rfind(" MB")]) / 1000
else:
# Example format for log entry:
# GPU memory usage: before=[{'device_id': 0, 'name': 'NVIDIA A100-SXM4-80GB', 'max_used_MB': 69637.25}, {'device_id': 1, 'name': 'NVIDIA A100-SXM4-80GB', 'max_used_MB': 890.625}] peak=[{'device_id': 0, 'name': 'NVIDIA A100-SXM4-80GB', 'max_used_MB': 73861.25}, {'device_id': 1, 'name': 'NVIDIA A100-SXM4-80GB', 'max_used_MB': 890.625}]
peak = line[line.find(memory_pattern) + len(memory_pattern) :].replace("'", '"')
usage = json.loads(peak)[device_id]["max_used_MB"]
memory = float(usage) / 1000
# Append log entry to list of entries
entry = base_results + [ # noqa: RUF005
batch_size,
sequence_length,
step,
latency_s,
latency_ms,
throughput,
memory,
]
entries.append(entry)
return entries
def save_results(results, filename):
import pandas as pd
df = pd.DataFrame(
results,
columns=[
"Engine",
"Precision",
"Device",
"Batch Size",
"Sequence Length",
"Step",
"Latency (s)",
"Latency (ms)",
"Throughput (qps)",
"Memory (GB)",
],
)
# Set column types
df["Batch Size"] = df["Batch Size"].astype("int")
df["Sequence Length"] = df["Sequence Length"].astype("int")
df["Latency (s)"] = df["Latency (s)"].astype("float")
df["Latency (ms)"] = df["Latency (ms)"].astype("float")
df["Throughput (qps)"] = df["Throughput (qps)"].astype("float")
df["Memory (GB)"] = df["Memory (GB)"].astype("float")
df.to_csv(filename, index=False)
logger.info(f"Results saved in {filename}!")
def benchmark(args, benchmark_cmd, engine):
log_filename = f"{engine}_{datetime.datetime.now():%Y-%m-%d_%H:%M:%S}.log"
log_path = os.path.join(args.log_folder, log_filename)
with open(log_path, "w") as log_file:
process = subprocess.Popen(benchmark_cmd, stdout=log_file, stderr=log_file)
try:
process.wait(args.timeout)
except subprocess.TimeoutExpired:
process.kill()
# Create entries for csv
logger.info("Gathering data from log files...")
base_results = [engine, args.precision, args.device]
results = process_log_file(args.device_id, log_path, base_results)
return results
def main():
args = get_args()
setup_logger(args.verbose)
logger.info(args.__dict__)
torch.backends.cudnn.benchmark = True
all_results = []
# Benchmark PyTorch without torch.compile
benchmark_cmd = [
"python3",
"benchmark.py",
"--benchmark-type",
"hf-pt",
"--model-name",
args.model_name,
"--precision",
args.precision,
"--batch-sizes",
args.batch_sizes,
"--sequence-lengths",
args.sequence_lengths,
"--device",
args.device,
"--device-id",
str(args.device_id),
"--warmup-runs",
str(args.warmup_runs),
"--num-runs",
str(args.num_runs),
"--log-folder",
args.log_folder,
"--auth",
]
logger.info("Benchmark PyTorch without torch.compile")
results = benchmark(args, benchmark_cmd, "pytorch")
all_results.extend(results)
# Benchmark PyTorch with torch.compile
benchmark_cmd = [
"python3",
"benchmark.py",
"--benchmark-type",
"hf-pt2",
"--model-name",
args.model_name,
"--precision",
args.precision,
"--batch-sizes",
args.batch_sizes,
"--sequence-lengths",
args.sequence_lengths,
"--device",
args.device,
"--device-id",
str(args.device_id),
"--warmup-runs",
str(args.warmup_runs),
"--num-runs",
str(args.num_runs),
"--log-folder",
args.log_folder,
"--auth",
]
logger.info("Benchmark PyTorch with torch.compile")
results = benchmark(args, benchmark_cmd, "pytorch-2")
all_results.extend(results)
# Benchmark Optimum + ONNX Runtime
if args.hf_ort_model_path:
benchmark_cmd = [
"python3",
"benchmark.py",
"--benchmark-type",
"hf-ort",
"--hf-ort-model-path",
args.hf_ort_model_path,
"--model-name",
args.model_name,
"--precision",
args.precision,
"--batch-sizes",
args.batch_sizes,
"--sequence-lengths",
args.sequence_lengths,
"--device",
args.device,
"--device-id",
str(args.device_id),
"--warmup-runs",
str(args.warmup_runs),
"--num-runs",
str(args.num_runs),
"--log-folder",
args.log_folder,
"--auth",
]
logger.info("Benchmark Optimum + ONNX Runtime")
results = benchmark(args, benchmark_cmd, "pytorch-ort")
all_results.extend(results)
# Benchmark ONNX Runtime
if args.ort_model_path:
benchmark_cmd = [
"python3",
"benchmark.py",
"--benchmark-type",
"ort",
"--ort-model-path",
args.ort_model_path,
"--model-name",
args.model_name,
"--precision",
args.precision,
"--batch-sizes",
args.batch_sizes,
"--sequence-lengths",
args.sequence_lengths,
"--device",
args.device,
"--device-id",
str(args.device_id),
"--warmup-runs",
str(args.warmup_runs),
"--num-runs",
str(args.num_runs),
"--log-folder",
args.log_folder,
]
logger.info("Benchmark ONNX Runtime")
results = benchmark(args, benchmark_cmd, "onnxruntime")
all_results.extend(results)
csv_file = f"{args.model_size}_{args.precision}_{datetime.datetime.now():%Y-%m-%d_%H:%M:%S}.csv"
save_results(all_results, os.path.join(args.log_folder, csv_file))
if __name__ == "__main__":
main()


@@ -0,0 +1,576 @@
import argparse
import logging
import os
import tempfile
from itertools import chain
from typing import List
import onnx
import torch
from benchmark_helper import Precision, prepare_environment, setup_logger
from llama_inputs import get_sample_inputs, get_sample_with_past_kv_inputs
from llama_parity import main as parity_check
from onnx_model import OnnxModel
from transformers import LlamaConfig, LlamaForCausalLM
from onnxruntime import quantization as ort_quantization
logger = logging.getLogger("")
def get_model_dynamic_axes(input_names: List[str], output_names: List[str]):
dynamic_axes = {}
for name in input_names + output_names:
if name in input_names:
# shape is (batch_size, sequence_length)
dynamic_axes[name] = {0: "batch_size", 1: "sequence_length"}
elif name == "logits":
# shape is (batch_size, sequence_length, vocab_size)
dynamic_axes[name] = {0: "batch_size", 1: "sequence_length"}
elif "present" in name:
# shape is (batch_size, num_heads, sequence_length, head_size)
dynamic_axes[name] = {0: "batch_size", 2: "sequence_length"}
else:
raise Exception("Unknown input or output name found")
return dynamic_axes
def get_model_with_past_kv_dynamic_axes(input_names: List[str], output_names: List[str]):
dynamic_axes = {}
for name in input_names + output_names:
if name in {"input_ids", "position_ids"}:
# shape is (batch_size, 1)
dynamic_axes[name] = {0: "batch_size"}
elif name == "attention_mask":
# shape is (batch_size, past_sequence_length + 1)
dynamic_axes[name] = {0: "batch_size", 1: "past_sequence_length + 1"}
elif "past" in name:
# shape is (batch_size, num_heads, past_sequence_length, head_size)
dynamic_axes[name] = {0: "batch_size", 2: "past_sequence_length"}
elif name == "logits":
# shape is (batch_size, 1, vocab_size)
dynamic_axes[name] = {0: "batch_size"}
elif "present" in name:
# shape is (batch_size, num_heads, past_sequence_length + 1, head_size)
dynamic_axes[name] = {0: "batch_size", 2: "past_sequence_length + 1"}
else:
raise Exception("Unknown input or output name found")
return dynamic_axes
def save_onnx_model(onnx_model: onnx.ModelProto, output_path: str, data_path: str):
onnx.save(
onnx_model,
output_path,
save_as_external_data=True,
all_tensors_to_one_file=True,
location=data_path,
size_threshold=1024,
convert_attribute=False,
)
# Notes:
# 1) Dynamo export will not work automatically until this issue is resolved: https://github.com/microsoft/onnxscript/issues/493
#
# 2) Dynamo export will run manually if you set the ONNX file path to the same path that you use to save the model after export.
# In other words, the value of `temp_path` should be set as the ONNX file path. You can open the issue in your browser to find
# the location in ONNX Script where you have to make this change.
#
# Once the issue is resolved, we hope to modify the code below as follows for each export.
#
# Before:
# temp_dir = args.output
# temp_path = os.path.join(temp_dir, "temp.onnx")
# ...
# ...
# ...
# del onnx_model
# os.system(f"rm {os.path.join(temp_dir, 'model.*')} && rm {os.path.join(temp_dir, '*.weight')} && rm {temp_path}")
#
#
# After:
# temp_dir = tempfile.TemporaryDirectory()
# temp_path = os.path.join(temp_dir.name, "temp.onnx")
# ...
# ...
# ...
# del onnx_model
# temp_dir.cleanup()
#
def run_dynamo_export(args: argparse.Namespace, l_config: LlamaConfig, llama: LlamaForCausalLM):
from torch._dynamo import config
config.capture_scalar_outputs = True
# Dummy values for export
batch_size, sequence_length = 2, 8
device = torch.device("cpu")
# Export decoder_model.onnx
input_ids, attn_mask, pos_ids = get_sample_inputs(l_config, device, batch_size, sequence_length)
temp_dir = args.output # tempfile.TemporaryDirectory()
temp_path = os.path.join(temp_dir, "temp.onnx") # os.path.join(temp_dir.name, "temp.onnx")
torch.onnx.dynamo_export(
llama, input_ids, attn_mask, pos_ids, export_options=torch.onnx.ExportOptions(dynamic_shapes=True)
).save(temp_path)
# Check decoder_model.onnx and save all external data to one file
onnx.checker.check_model(temp_path)
onnx.shape_inference.infer_shapes_path(temp_path)
output_path = os.path.join(args.output, f"{args.model_name}_decoder_model_fp32.onnx")
onnx_model = onnx.load_model(temp_path, load_external_data=True)
save_onnx_model(onnx_model, output_path, f"{args.model_name}_decoder_model_fp32.onnx.data")
del onnx_model
os.system(
f"rm {os.path.join(temp_dir, 'model.*')} && rm {os.path.join(temp_dir, '*.weight')} && rm {temp_path}"
) # temp_dir.cleanup()
# Export decoder_with_past_model.onnx
input_ids, attn_mask, pos_ids, past_kv = get_sample_with_past_kv_inputs(
l_config, device, batch_size, sequence_length
)
temp_dir = args.output # tempfile.TemporaryDirectory()
temp_path = os.path.join(temp_dir, "temp.onnx") # os.path.join(temp_dir.name, "temp.onnx")
torch.onnx.dynamo_export(
llama, input_ids, attn_mask, pos_ids, past_kv, export_options=torch.onnx.ExportOptions(dynamic_shapes=True)
).save(temp_path)
# Check decoder_with_past_model.onnx and save all external data to one file
onnx.checker.check_model(temp_path)
onnx.shape_inference.infer_shapes_path(temp_path)
output_path = os.path.join(args.output, f"{args.model_name}_decoder_with_past_model_fp32.onnx")
onnx_model = onnx.load_model(temp_path, load_external_data=True)
save_onnx_model(onnx_model, output_path, f"{args.model_name}_decoder_with_past_model_fp32.onnx.data")
del onnx_model
os.system(
f"rm {os.path.join(temp_dir, 'model.*')} && rm {os.path.join(temp_dir, '*.weight')} && rm {temp_path}"
) # temp_dir.cleanup()
logger.info(f"The {args.model_name} ONNX model has been successfully created with the Dynamo exporter!")
def run_torchscript_export(args: argparse.Namespace, l_config: LlamaConfig, llama: LlamaForCausalLM):
# Dummy values for export
batch_size, sequence_length = 2, 8
device = torch.device("cpu")
# Export decoder_model.onnx
decoder_inputs = get_sample_inputs(l_config, device, batch_size, sequence_length)
input_names = ["input_ids", "attention_mask", "position_ids"]
output_names = [
"logits",
*list(
chain.from_iterable((f"present.{i}.key", f"present.{i}.value") for i in range(l_config.num_hidden_layers))
),
]
dynamic_axes = get_model_dynamic_axes(input_names, output_names)
temp_dir = tempfile.TemporaryDirectory()
temp_path = os.path.join(temp_dir.name, "temp.onnx")
torch.onnx.export(
llama,
args=decoder_inputs,
f=temp_path,
export_params=True,
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
opset_version=13,
do_constant_folding=True,
verbose=args.verbose,
)
# Check decoder_model.onnx and save all external data to one file
onnx.checker.check_model(temp_path)
onnx.shape_inference.infer_shapes_path(temp_path)
output_path = os.path.join(args.output, f"{args.model_name}_decoder_model_fp32.onnx")
onnx_model = onnx.load_model(temp_path, load_external_data=True)
save_onnx_model(
onnx_model,
output_path,
f"{args.model_name}_decoder_model_fp32.onnx.data",
)
del onnx_model
temp_dir.cleanup()
# Export decoder_with_past_model.onnx
decoder_with_past_inputs = get_sample_with_past_kv_inputs(l_config, device, batch_size, sequence_length)
input_names = [
"input_ids",
"attention_mask",
"position_ids",
*list(
chain.from_iterable(
(f"past_key_values.{i}.key", f"past_key_values.{i}.value") for i in range(l_config.num_hidden_layers)
)
),
]
output_names = [
"logits",
*list(
chain.from_iterable((f"present.{i}.key", f"present.{i}.value") for i in range(l_config.num_hidden_layers))
),
]
dynamic_axes = get_model_with_past_kv_dynamic_axes(input_names, output_names)
temp_dir = tempfile.TemporaryDirectory()
temp_path = os.path.join(temp_dir.name, "temp.onnx")
torch.onnx.export(
llama,
args=decoder_with_past_inputs,
f=temp_path,
export_params=True,
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
opset_version=13,
do_constant_folding=True,
verbose=args.verbose,
)
# Check decoder_with_past_model.onnx and save all external data to one file
onnx.checker.check_model(temp_path)
onnx.shape_inference.infer_shapes_path(temp_path)
output_path = os.path.join(args.output, f"{args.model_name}_decoder_with_past_model_fp32.onnx")
onnx_model = onnx.load_model(temp_path, load_external_data=True)
save_onnx_model(
onnx_model,
output_path,
f"{args.model_name}_decoder_with_past_model_fp32.onnx.data",
)
del onnx_model
temp_dir.cleanup()
logger.info(f"The {args.model_name} ONNX model has been successfully created with the TorchScript exporter!")
def remove_existing_files(output_path: str):
for filename in os.listdir(output_path):
filepath = os.path.join(output_path, filename)
if ".onnx" in filename or ".onnx.data" in filename:
os.remove(filepath)
logger.warning(f"Removing {filepath}")
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"-m",
"--model_name",
required=True,
help="Model name in Hugging Face",
)
parser.add_argument(
"-i",
"--input",
required=False,
default=os.path.join("."),
help="Directory path to PyTorch model and associated files if saved on disk",
)
parser.add_argument(
"-o",
"--output",
required=False,
default=os.path.join(".", "llama_onnx_models"),
help="Directory path to save exported model files in",
)
parser.add_argument(
"-p",
"--precision",
required=False,
type=Precision,
default=Precision.FLOAT32,
choices=[Precision.FLOAT32, Precision.FLOAT16, Precision.INT8],
help="Precision to export model in",
)
parser.add_argument(
"-e",
"--execution_provider",
required=False,
default="cpu",
choices=["cpu", "cuda", "rocm"],
help="Execution provider to verify parity with",
)
parser.add_argument(
"-q",
"--quantization_method",
default="",
choices=["smooth_quant", "quantize_dynamic"],
help="Run a specific quantization algorithm. Need to install extra packages in `requirements-quant.txt` for SmoothQuant.",
)
smooth_quant_group = parser.add_argument_group("smooth_quant")
smooth_quant_group.add_argument(
"--smooth_quant_alpha",
required=False,
default=0.8,
type=float,
help="Strength to control migration difficulty from activation to weights. Default is 0.8 to match value \
used in original paper for LLaMA. Paper recommends using values in [0.4, 0.6] range. \
Link to paper: https://arxiv.org/pdf/2211.10438.pdf",
)
smooth_quant_group.add_argument(
"--smooth_quant_dataset",
required=False,
default="NeelNanda/pile-10k",
help="Path to dataset for calibration during quantization",
)
smooth_quant_group.add_argument(
"--pad_max",
required=False,
default=196,
type=int,
help="Max padding size",
)
smooth_quant_group.add_argument(
"--calibration_sampling_size",
required=False,
type=int,
default=8,
help="Calibration sampling size for quantization config",
)
smooth_quant_group.add_argument(
"--nc_workspace",
required=False,
type=str,
default=os.path.join(".", "nc_workspace"),
help="Workspace to save intermediate files generated by Intel's Neural Compressor package.",
)
quantize_dynamic_group = parser.add_argument_group("quantize_dynamic")
quantize_dynamic_group.add_argument(
"--quantize_embedding_layer",
required=False,
action="store_true",
help="Quantize MatMul, GEMM, and Gather.",
)
quantize_dynamic_group.set_defaults(quantize_embedding_layer=False)
quantize_dynamic_group.add_argument(
"--quantize_per_channel",
required=False,
action="store_true",
help="Quantize weights per each channel.",
)
quantize_dynamic_group.set_defaults(quantize_per_channel=False)
quantize_dynamic_group.add_argument(
"--quantize_reduce_range",
required=False,
action="store_true",
help="Quantize weights with 7 bits.",
)
quantize_dynamic_group.set_defaults(quantize_reduce_range=False)
parser.add_argument(
"-v",
"--verbose",
action="store_true",
help="Print verbose logs",
)
parser.set_defaults(verbose=False)
parser.add_argument(
"-d",
"--use_dynamo_export",
action="store_true",
help="Use the new Dynamo exporter instead of the old TorchScript exporter",
)
parser.set_defaults(use_dynamo_export=False)
args = parser.parse_args()
return args
def main():
args = get_args()
setup_logger(args.verbose)
prepare_environment(args.input, args.output, args.execution_provider != "cpu")
remove_existing_files(args.output)
logger.info(f"Arguments: {args}")
# Load model and config
use_auth_token = args.input == os.path.join(".")
setattr(args, "use_auth_token", use_auth_token) # noqa: B010
l_config = LlamaConfig.from_pretrained(
args.model_name if use_auth_token else args.input, use_auth_token=use_auth_token
)
llama = LlamaForCausalLM.from_pretrained(
args.model_name if use_auth_token else args.input, use_auth_token=use_auth_token, use_cache=True
)
original_model_name = args.model_name
setattr(args, "original_model_name", original_model_name) # noqa: B010
args.model_name = args.model_name.split("/")[-1]
# Export to ONNX
if args.use_dynamo_export:
logger.warning("Please ensure you have installed PyTorch, ONNX, and ONNX Script as follows.")
logger.warning("Step 1 - PyTorch nightly: https://pytorch.org/get-started/locally/")
logger.warning("Step 2 - ONNX weekly: https://pypi.org/project/onnx-weekly/")
logger.warning(
"Step 3 - ONNX Script from source: https://github.com/microsoft/onnxscript#installing-onnx-script"
)
logger.warning(
"Note: After you install ONNX weekly, omit `onnx` when running the first line for installing ONNX Script. This is because you already installed `onnx-weekly` in the previous step."
)
run_dynamo_export(args, l_config, llama)
else:
run_torchscript_export(args, l_config, llama)
# Change precision of exported models if not FP32
decoder_model_fp32_path = os.path.join(args.output, f"{args.model_name}_decoder_model_fp32.onnx")
decoder_with_past_model_fp32_path = os.path.join(
args.output, f"{args.model_name}_decoder_with_past_model_fp32.onnx"
)
if args.precision == Precision.FLOAT16:
# Convert decoder_model.onnx to FP16
decoder_model_fp16_path = os.path.join(args.output, f"{args.model_name}_decoder_model_fp16.onnx")
model = OnnxModel(onnx.load_model(decoder_model_fp32_path, load_external_data=True))
model.convert_float_to_float16(keep_io_types=False, op_block_list=["If"])
model.save_model_to_file(decoder_model_fp16_path, use_external_data_format=True, all_tensors_to_one_file=True)
del model
# Convert decoder_with_past_model.onnx to FP16
decoder_with_past_model_fp16_path = os.path.join(
args.output, f"{args.model_name}_decoder_with_past_model_fp16.onnx"
)
model = OnnxModel(onnx.load_model(decoder_with_past_model_fp32_path, load_external_data=True))
model.convert_float_to_float16(keep_io_types=False, op_block_list=["If"])
model.save_model_to_file(
decoder_with_past_model_fp16_path, use_external_data_format=True, all_tensors_to_one_file=True
)
del model
elif args.precision == Precision.INT8:
decoder_model_int8_path = os.path.join(args.output, f"{args.model_name}_decoder_model_int8.onnx")
decoder_with_past_model_int8_path = os.path.join(
args.output, f"{args.model_name}_decoder_with_past_model_int8.onnx"
)
if args.quantization_method == "smooth_quant":
from neural_compressor import PostTrainingQuantConfig
from neural_compressor import quantization as intel_quantization
from neural_compressor import set_workspace
from onnx.external_data_helper import load_external_data_for_model
from quant_kv_dataloader import QuantKVDataLoader
set_workspace(args.nc_workspace)
quantization_config = PostTrainingQuantConfig(
calibration_sampling_size=[args.calibration_sampling_size],
recipes={
"optypes_to_exclude_output_quant": ["MatMul"],
"smooth_quant": args.smooth_quant,
"smooth_quant_args": {"alpha": args.smooth_quant_alpha},
},
op_type_dict={
"^((?!(MatMul|Gather|Conv)).)*$": {
"weight": {"dtype": ["fp32"]},
"activation": {"dtype": ["fp32"]},
}
},
)
# Convert decoder_model.onnx to INT8
decoder_model_int8 = intel_quantization.fit(
decoder_model_fp32_path,
quantization_config,
calib_dataloader=QuantKVDataLoader(args),
)
load_external_data_for_model(
decoder_model_int8._model,
os.path.split(decoder_model_int8._model_path)[0],
)
save_onnx_model(
decoder_model_int8._model,
decoder_model_int8_path,
f"{args.model_name}_decoder_model_int8.onnx.data",
)
del decoder_model_int8
# Convert decoder_with_past_model.onnx to INT8
decoder_with_past_model_int8 = intel_quantization.fit(
decoder_with_past_model_fp32_path,
quantization_config,
calib_dataloader=QuantKVDataLoader(args, onnx_model_path=decoder_model_fp32_path),
)
load_external_data_for_model(
decoder_with_past_model_int8._model,
os.path.split(decoder_with_past_model_int8._model_path)[0],
)
save_onnx_model(
decoder_with_past_model_int8._model,
decoder_with_past_model_int8_path,
f"{args.model_name}_decoder_with_past_model_int8.onnx.data",
)
del decoder_with_past_model_int8
logger.info(f"Removing {args.nc_workspace}")
os.system(f"rm -R {args.nc_workspace}")
elif args.quantization_method == "quantize_dynamic":
logger.warning(
"The `quantize_dynamic` method is deprecated in favor of `smooth_quant` instead. Precision loss may be high with `quantize_dynamic`."
)
# Convert decoder_model.onnx to INT8
ort_quantization.quantize_dynamic(
decoder_model_fp32_path,
decoder_model_int8_path,
op_types_to_quantize=["MatMul", "Gemm", "Gather"]
if args.quantize_embedding_layer
else ["MatMul", "Gemm"],
per_channel=args.quantize_per_channel,
reduce_range=args.quantize_reduce_range,
use_external_data_format=True,
extra_options={"MatMulConstBOnly": True},
)
# Convert decoder_with_past_model.onnx to INT8
ort_quantization.quantize_dynamic(
decoder_with_past_model_fp32_path,
decoder_with_past_model_int8_path,
op_types_to_quantize=["MatMul", "Gemm", "Gather"]
if args.quantize_embedding_layer
else ["MatMul", "Gemm"],
per_channel=args.quantize_per_channel,
reduce_range=args.quantize_reduce_range,
use_external_data_format=True,
extra_options={"MatMulConstBOnly": True},
)
else:
raise Exception(f"Could not recognize {args.quantization_method} as a quantization method")
# Verify parity on all saved ONNX models
del llama # Delete LLaMA model from memory since it will be loaded again during parity check
logger.info("Verifying parity on all ONNX models created")
for filename in os.listdir(args.output):
if ".data" in filename or ".onnx" not in filename:
continue
precision = filename[filename.rfind("_") + 1 : filename.find(".onnx")]
parity_cmd = ["-m", f"{original_model_name}", "-o", f"{os.path.join(args.output, filename)}", "-fp", precision]
if "with_past" in filename:
parity_cmd.append("--use_past_kv")
parity_check(parity_cmd)
if __name__ == "__main__":
main()


@@ -0,0 +1,119 @@
from typing import List, Tuple
import numpy as np
import torch
from transformers import LlamaConfig
# Get position_ids from attention_mask
def get_position_ids(attention_mask: torch.Tensor, use_past_kv: bool):
position_ids = attention_mask.long().cumsum(-1) - 1
if use_past_kv:
position_ids = position_ids[:, -1].unsqueeze(-1)
return position_ids
# Inputs for first pass to get initial past_key_values
def get_sample_inputs(
config: LlamaConfig, device: torch.device, batch_size: int, seq_len: int, return_dict: bool = False
):
input_ids = torch.randint(
low=0, high=config.vocab_size, size=(batch_size, seq_len), device=device, dtype=torch.int64
)
attention_mask = torch.ones(batch_size, seq_len, device=device, dtype=torch.int64)
# position_ids is of shape (batch_size, seq_len)
position_ids = get_position_ids(attention_mask, use_past_kv=False)
if not return_dict:
return (input_ids, attention_mask, position_ids)
inputs = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"position_ids": position_ids,
}
return inputs
# Inputs for subsequent passes with past_key_values
def get_sample_with_past_kv_inputs(
config: LlamaConfig,
device: torch.device,
batch_size: int,
past_seq_len: int,
use_fp16: bool = False,
return_dict: bool = False,
):
input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, 1), device=device, dtype=torch.int64)
attention_mask = torch.ones(batch_size, past_seq_len + 1, device=device, dtype=torch.int64)
# position_ids is of shape (batch_size, 1)
position_ids = get_position_ids(attention_mask, use_past_kv=True)
past_kv = get_sample_past_kv_inputs(config, device, batch_size, past_seq_len, use_fp16)
if not return_dict:
return (input_ids, attention_mask, position_ids, past_kv)
inputs = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"position_ids": position_ids,
"past_key_values": past_kv,
}
return inputs
# Create past_key_values
def get_sample_past_kv_inputs(
config: LlamaConfig, device: torch.device, batch_size: int, past_seq_len: int, use_fp16: bool
):
num_heads, head_size = config.num_attention_heads, config.hidden_size // config.num_attention_heads
torch_dtype = torch.float16 if use_fp16 else torch.float32
past_kv = [
(
torch.rand(batch_size, num_heads, past_seq_len, head_size, device=device, dtype=torch_dtype),
torch.rand(batch_size, num_heads, past_seq_len, head_size, device=device, dtype=torch_dtype),
)
for _ in range(config.num_hidden_layers)
]
return past_kv
# Convert list of past_kv to dict of past_key and past_value
def flatten_past_kv_inputs(past_key_values: List[Tuple[torch.Tensor, torch.Tensor]], use_fp16: bool):
past_kv = {}
np_dtype = np.float16 if use_fp16 else np.float32
for i, (past_k, past_v) in enumerate(past_key_values):
past_kv[f"past_key_values.{i}.key"] = past_k.detach().cpu().numpy().astype(np_dtype)
past_kv[f"past_key_values.{i}.value"] = past_v.detach().cpu().numpy().astype(np_dtype)
return past_kv
# Format PyTorch inputs to ONNX Runtime inputs
def convert_inputs_for_ort(pt_inputs: dict, use_fp16: bool):
ort_inputs = {}
for k, v in pt_inputs.items():
if k == "past_key_values":
ort_inputs.update(flatten_past_kv_inputs(v, use_fp16))
else:
ort_inputs[k] = v.detach().cpu().numpy()
return ort_inputs
# Inputs for Microsoft export from https://github.com/microsoft/Llama-2-Onnx
def get_msft_sample_inputs(config: LlamaConfig, batch_size: int, past_seq_len: int, seq_len: int, use_fp16: bool):
np_dtype = np.float16 if use_fp16 else np.float32
head_size = config.hidden_size // config.num_attention_heads
max_seq_len = 2048
ort_inputs = {
"x": np.random.rand(batch_size, seq_len, config.hidden_size).astype(np_dtype),
"attn_mask": (-10000.0 * np.triu(np.ones((batch_size, max_seq_len, max_seq_len)), k=1)).astype(np_dtype),
"k_cache": np.random.rand(
batch_size, config.num_hidden_layers, past_seq_len, config.num_attention_heads, head_size
).astype(np_dtype),
"v_cache": np.random.rand(
batch_size, config.num_hidden_layers, past_seq_len, config.num_attention_heads, head_size
).astype(np_dtype),
"pos": np.array(past_seq_len, dtype=np.int64),
}
return ort_inputs


@@ -0,0 +1,132 @@
import argparse
import logging
import os
from typing import List
import numpy as np
import torch
from benchmark_helper import create_onnxruntime_session, setup_logger
from llama_inputs import convert_inputs_for_ort, get_sample_inputs, get_sample_with_past_kv_inputs
from transformers import LlamaConfig, LlamaForCausalLM
logger = logging.getLogger("")
def verify_parity(args: argparse.Namespace, config: LlamaConfig, pt_model: LlamaForCausalLM):
# Dummy values for parity
batch_size, sequence_length = 2, 8
device = torch.device("cpu")
# Run inference with PyTorch
inputs = (
get_sample_inputs(config, device, batch_size, sequence_length, return_dict=True)
if not args.use_past_kv
else get_sample_with_past_kv_inputs(
config, device, batch_size, sequence_length, use_fp16=(args.precision == "fp16"), return_dict=True
)
)
pt_outputs = pt_model(**inputs).logits.detach().cpu().numpy()
# Run inference with ORT
inputs = convert_inputs_for_ort(inputs, use_fp16=(args.precision == "fp16"))
ort_model = create_onnxruntime_session(
args.onnx_model_path,
args.execution_provider != "cpu", # use_gpu
provider=args.execution_provider,
verbose=args.verbose,
)
ort_outputs = ort_model.run(None, inputs)[0]
# Compare PyTorch and ONNX Runtime accuracy
tol = 1e-3 if args.precision == "fp32" else 1e-2 if args.precision == "fp16" else 1e2
parity = np.allclose(pt_outputs, ort_outputs, rtol=tol, atol=tol)
logger.warning(f"Are PyTorch and ONNX Runtime results close? {parity}")
if not parity:
logger.warning(f"Max diff: {np.max(pt_outputs - ort_outputs)}")
def get_args(argv: List[str]):
parser = argparse.ArgumentParser()
parser.add_argument(
"-m",
"--model_name",
required=True,
help="Model name in Hugging Face",
)
parser.add_argument(
"-t",
"--torch_model_directory",
required=False,
default=os.path.join("."),
help="Path to folder containing PyTorch model and associated files if saved on disk",
)
parser.add_argument(
"-o",
"--onnx_model_path",
required=True,
default=os.path.join("."),
help="Path to ONNX model (with external data files saved in the same folder as the model)",
)
parser.add_argument(
"-ep",
"--execution_provider",
required=False,
default="cpu",
choices=["cpu", "cuda", "rocm"],
help="Execution provider to verify parity with",
)
parser.add_argument(
"-v",
"--verbose",
action="store_true",
help="Print verbose logs",
)
parser.set_defaults(verbose=False)
parser.add_argument(
"-p",
"--use_past_kv",
action="store_true",
help="Use past key and past value as inputs to the model. Necessary for decoder_with_past_model.onnx models.",
)
parser.set_defaults(use_past_kv=False)
parser.add_argument(
"-fp",
"--precision",
required=True,
choices=["int8", "fp16", "fp32"],
help="Precision of model",
)
args = parser.parse_args() if argv == [] else parser.parse_args(argv)
return args
def main(argv: List[str] = []): # noqa: B006
args = get_args(argv)
setup_logger(args.verbose)
logger.info(f"Arguments: {args}")
# Load model and config
use_auth_token = args.torch_model_directory == os.path.join(".")
location = args.model_name if use_auth_token else args.torch_model_directory
config = LlamaConfig.from_pretrained(location, use_auth_token=use_auth_token)
llama = LlamaForCausalLM.from_pretrained(
location,
torch_dtype=(torch.float16 if args.precision == "fp16" else torch.float32),
use_auth_token=use_auth_token,
use_cache=True,
)
verify_parity(args, config, llama)
if __name__ == "__main__":
main()


@@ -0,0 +1,103 @@
import argparse
import numpy as np
import torch
from benchmark_helper import create_onnxruntime_session
from datasets import load_dataset
from llama_inputs import get_position_ids
from torch.nn.functional import pad
from torch.utils.data import DataLoader
from transformers import LlamaTokenizer
class QuantKVDataLoader:
def __init__(self, args: argparse.Namespace, onnx_model_path: str = ""):
self.batch_size = 1
self.pad_max = args.pad_max
tokenizer = LlamaTokenizer.from_pretrained(args.original_model_name, use_auth_token=args.use_auth_token)
dataset = load_dataset(args.smooth_quant_dataset, split="train")
dataset = dataset.map(lambda examples: tokenizer(examples["text"]), batched=True)
dataset.set_format(type="torch", columns=["input_ids", "attention_mask"])
self.dataloader = DataLoader(
dataset,
batch_size=self.batch_size,
shuffle=False,
collate_fn=self.collate_batch,
)
self.decoder_model = (
create_onnxruntime_session(
onnx_model_path,
args.execution_provider != "cpu", # use_gpu
provider=args.execution_provider,
verbose=args.verbose,
)
if onnx_model_path
else None
)
def collate_batch(self, batch):
input_ids_batched = []
attention_mask_batched = []
position_ids_batched = []
labels = []
for text in batch:
# Set inputs for model
input_ids = text["input_ids"]
attention_mask = torch.ones(len(input_ids))
position_ids = get_position_ids(attention_mask, use_past_kv=False)
label = len(input_ids) - 1
# Pad input data because all model inputs must have same shape
pad_len = self.pad_max - input_ids.shape[0]
input_ids = pad(input_ids, (0, pad_len), value=1)
attention_mask = pad(attention_mask, (0, pad_len), value=0)
position_ids = pad(position_ids, (0, pad_len), value=0)
input_ids_batched.append(input_ids)
attention_mask_batched.append(attention_mask)
position_ids_batched.append(position_ids)
labels.append(label)
input_ids_batched = torch.vstack(input_ids_batched)
attention_mask_batched = torch.vstack(attention_mask_batched)
position_ids_batched = torch.vstack(position_ids_batched)
labels = torch.tensor(labels)
return (input_ids_batched, attention_mask_batched, position_ids_batched), labels
def __iter__(self):
try:
for (input_ids, attention_mask, position_ids), labels in self.dataloader:
# Inputs for decoder_model.onnx
inputs = {
"input_ids": input_ids[:, :-1].detach().cpu().numpy().astype(np.int64),
"attention_mask": attention_mask[:, :-1].detach().cpu().numpy().astype(np.int64),
"position_ids": position_ids[:, :-1].detach().cpu().numpy().astype(np.int64),
}
label = labels.detach().cpu().numpy()
if self.decoder_model is not None:
# Run decoder_model.onnx to get inputs for decoder_with_past_model.onnx
outputs = self.decoder_model.run(None, inputs)
for i in range(int((len(outputs) - 1) / 2)):
inputs[f"past_key_values.{i}.key"] = outputs[i * 2 + 1]
inputs[f"past_key_values.{i}.value"] = outputs[i * 2 + 2]
past_sequence_length = inputs["past_key_values.0.key"].shape[2]
inputs["input_ids"] = input_ids[:, -1].unsqueeze(0).detach().cpu().numpy().astype(np.int64)
attn_mask_torch = torch.ones((self.batch_size, past_sequence_length + 1), dtype=torch.int64)
inputs["attention_mask"] = attn_mask_torch.detach().cpu().numpy().astype(np.int64)
inputs["position_ids"] = (
get_position_ids(attn_mask_torch, use_past_kv=True).detach().cpu().numpy().astype(np.int64)
)
# Yield (inputs, label) tuple for Intel's Neural Compressor:
# https://github.com/intel/neural-compressor/blob/d4baed9ea11614e1f0dc8a1f4f55b73ed3ed585c/neural_compressor/quantization.py#L55-L62
yield (inputs, label)
except StopIteration:
return


@@ -0,0 +1,3 @@
-r requirements.txt
torch>=2.0.1
onnxruntime>=1.16.0


@@ -0,0 +1,4 @@
-r requirements.txt
# Please manually install torch>=2.0.1 with CUDA enabled for the CUDA version installed in your system.
# Instructions can be found here: https://pytorch.org/get-started/locally/
onnxruntime-gpu>=1.16.0


@@ -0,0 +1,2 @@
-r requirements-cpu.txt
neural-compressor>=2.2.1


@@ -0,0 +1,5 @@
git+https://github.com/kunal-vaishnavi/optimum.git@kvaishnavi/llama-add-position-ids
transformers>=4.28.1
onnx>=1.14.0
datasets>=2.8.0
protobuf==3.20.2


@@ -470,6 +470,7 @@ packages = [
"onnxruntime.transformers.models.bart",
"onnxruntime.transformers.models.bert",
"onnxruntime.transformers.models.gpt2",
"onnxruntime.transformers.models.llama",
"onnxruntime.transformers.models.longformer",
"onnxruntime.transformers.models.t5",
"onnxruntime.transformers.models.stable_diffusion",