Add LLaMA scripts (#17020)
### Description
This PR adds the following scripts for LLaMA:
- LLaMA conversion (support for TorchScript and Dynamo exporters)
- LLaMA parity
- LLaMA benchmark
- LLaMA quantization
- LLaMA integration with [Hugging Face Optimum](https://github.com/huggingface/optimum)

### Motivation and Context
This PR adds scripts for using LLaMA. There is a [follow-up PR](https://github.com/microsoft/onnxruntime/pull/17043) for adding scripts for Whisper.
This commit is contained in:
parent d3d3dde844
commit edac3ef150
15 changed files with 2035 additions and 1 deletion
@@ -466,6 +466,9 @@ file(GLOB onnxruntime_python_transformers_models_bert_src CONFIGURE_DEPENDS
 file(GLOB onnxruntime_python_transformers_models_gpt2_src CONFIGURE_DEPENDS
   "${ONNXRUNTIME_ROOT}/python/tools/transformers/models/gpt2/*.py"
 )
+file(GLOB onnxruntime_python_transformers_models_llama_src CONFIGURE_DEPENDS
+  "${ONNXRUNTIME_ROOT}/python/tools/transformers/models/llama/*.py"
+)
 file(GLOB onnxruntime_python_transformers_models_longformer_src CONFIGURE_DEPENDS
   "${ONNXRUNTIME_ROOT}/python/tools/transformers/models/longformer/*.py"
 )
@@ -537,6 +540,7 @@ add_custom_command(
   COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/bart
   COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/bert
   COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/gpt2
+  COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/llama
   COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/longformer
   COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/stable_diffusion
   COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/t5
@@ -628,6 +632,9 @@ add_custom_command(
   COMMAND ${CMAKE_COMMAND} -E copy
       ${onnxruntime_python_transformers_models_gpt2_src}
       $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/gpt2/
+  COMMAND ${CMAKE_COMMAND} -E copy
+      ${onnxruntime_python_transformers_models_llama_src}
+      $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/llama/
   COMMAND ${CMAKE_COMMAND} -E copy
       ${onnxruntime_python_transformers_models_longformer_src}
       $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/longformer/
@@ -170,7 +170,7 @@ def prepare_environment(cache_dir, output_dir, use_gpu, provider=None):

     logger.info(f"PyTorch Version:{torch.__version__}")
     logger.info(f"Transformers Version:{transformers.__version__}")
-    logger.info(f"Onnxruntime Version:{onnxruntime.__version__}")
+    logger.info(f"OnnxRuntime Version:{onnxruntime.__version__}")

     # Support three major versions of PyTorch and OnnxRuntime, and up to 9 months of transformers.
     assert version.parse(torch.__version__) >= version.parse("1.10.0")
187 onnxruntime/python/tools/transformers/models/llama/README.md Normal file
@@ -0,0 +1,187 @@
# LLaMA-2

## Exporting LLaMA-2

There are several ways to export LLaMA-2 models (using LLaMA-2 7B as an example).

### Option 1: from convert_to_onnx
```
# From source:
$ git clone https://github.com/microsoft/onnxruntime
$ cd onnxruntime/onnxruntime/python/tools/transformers/
$ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b

# From wheel:
$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b
```

To make this option compatible with [Hugging Face's Optimum](https://github.com/huggingface/optimum), you will need to create `config.json` and `generation_config.json` for your model and store them in the same directory as your ONNX models. For example, you can find those JSON files for LLaMA-2 7B on Hugging Face [here](https://huggingface.co/meta-llama/Llama-2-7b-hf).

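If your model came from the Hugging Face Hub, one way to obtain these JSON files is to download them next to the exported ONNX models with `huggingface_hub` (a minimal sketch, not part of the scripts in this PR; it assumes a recent `huggingface_hub`, that you have been granted access to the gated `meta-llama` repository, and that `llama2-7b` is the output directory used above):

```
from huggingface_hub import hf_hub_download

output_dir = "llama2-7b"  # hypothetical output directory from the export step above

# Place each JSON file in the same directory as the exported ONNX models
for filename in ["config.json", "generation_config.json"]:
    hf_hub_download(
        repo_id="meta-llama/Llama-2-7b-hf",
        filename=filename,
        local_dir=output_dir,
    )
```
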
### Option 2: from [Microsoft's custom export](https://github.com/microsoft/Llama-2-Onnx)

Please follow the [README instructions](https://github.com/microsoft/Llama-2-Onnx#before-you-start) in the custom export of LLaMA-2.

### Option 3: from [Hugging Face Optimum](https://github.com/huggingface/optimum)

First, log into the Hugging Face CLI in your terminal:

```
$ huggingface-cli login
```

Once authenticated, run the following Python code to export:

```
from optimum.onnxruntime import ORTModelForCausalLM

name = "meta-llama/Llama-2-7b-hf"
model = ORTModelForCausalLM.from_pretrained(
    name,
    export=True,
    use_auth_token=True,
)
model.save_pretrained(name.split("/")[-1] + "-onnx")
```

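As a quick sanity check (a minimal sketch, not part of this PR; it assumes the `transformers` tokenizer for the same checkpoint), you can reload the exported model and generate a few tokens:

```
from optimum.onnxruntime import ORTModelForCausalLM
from transformers import AutoTokenizer

path = "Llama-2-7b-hf-onnx"  # directory created by save_pretrained above
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", use_auth_token=True)
model = ORTModelForCausalLM.from_pretrained(path)

# Run a short generation to confirm the ONNX decoder works end to end
inputs = tokenizer("Hello, my name is", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
```
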
## Examples of Exporting LLaMA-2

Here are some additional examples for exporting LLaMA-2.

Export Saved Model on Disk
```
# From source:
$ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --input ./Llama-2-7b-hf --output ./llama2-7b

# From wheel:
$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --input ./Llama-2-7b-hf --output ./llama2-7b
```

Export for FP16
```
# From source:
$ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp16 --precision fp16

# From wheel:
$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp16 --precision fp16
```

Export for INT8
```
# From source:
$ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int8 --precision int8 --quantization_method smooth_quant

# From wheel:
$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int8 --precision int8 --quantization_method smooth_quant
```

Note: [Intel's Neural Compressor](https://github.com/intel/neural-compressor) takes time to run the SmoothQuant quantization algorithm on LLMs. On an [Azure Standard_NC24s_v3 VM](https://learn.microsoft.com/en-us/azure/virtual-machines/ncv3-series), it takes about 30-45 min for each of the exported ONNX models.

## Benchmark LLaMA-2

Here are some examples of how you can benchmark LLaMA-2.

Note: In the below examples, `PyTorch` refers to running in PyTorch without `torch.compile` and `PyTorch 2.0` refers to running in PyTorch with `torch.compile`.

### Variants

1. PyTorch (without `torch.compile`), FP32
```
python3 -m models.llama.benchmark \
    --benchmark-type hf-pt \
    --model-name meta-llama/Llama-2-7b-hf \
    --precision fp32 \
    --batch-sizes "1 2" \
    --sequence-lengths "8 16" \
    --device cpu \
    --auth
```

2. PyTorch 2.0 (with `torch.compile`), FP16
```
python3 -m models.llama.benchmark \
    --benchmark-type hf-pt2 \
    --model-name meta-llama/Llama-2-7b-hf \
    --precision fp16 \
    --batch-sizes "1 2" \
    --sequence-lengths "8 16" \
    --device cuda \
    --auth
```

3. Optimum + ONNX Runtime, FP32, export via Optimum or convert_to_onnx
```
python3 -m models.llama.benchmark \
    --benchmark-type hf-ort \
    --hf-ort-model-path ./Llama-2-7b-hf-onnx/ \
    --model-name meta-llama/Llama-2-7b-hf \
    --precision fp32 \
    --batch-sizes "1 2" \
    --sequence-lengths "8 16" \
    --device cpu \
    --auth
```

4. Optimum + ONNX Runtime, FP16, export via convert_to_onnx
```
python3 -m models.llama.benchmark \
    --benchmark-type hf-ort \
    --hf-ort-model-path ./llama2-7b-fp16/ \
    --model-name meta-llama/Llama-2-7b-hf \
    --precision fp16 \
    --batch-sizes "1 2" \
    --sequence-lengths "8 16" \
    --device cuda \
    --auth
```

5. Optimum + ONNX Runtime, INT8, export via convert_to_onnx
```
python3 -m models.llama.benchmark \
    --benchmark-type hf-ort \
    --hf-ort-model-path ./llama2-7b-int8/ \
    --model-name meta-llama/Llama-2-7b-hf \
    --precision int8 \
    --batch-sizes "1 2" \
    --sequence-lengths "8 16" \
    --device cpu \
    --auth
```

6. ONNX Runtime, FP32, Microsoft custom export
```
python3 -m models.llama.benchmark \
    --benchmark-type ort \
    --ort-model-path llama-2-onnx/7B_float32/ONNX/LlamaV2_7B_float32.onnx \
    --model-name meta-llama/Llama-2-7b-hf \
    --precision fp32 \
    --batch-sizes "1 2" \
    --sequence-lengths "8 16" \
    --device cpu
```

7. ONNX Runtime, FP16, Microsoft custom export
```
python3 -m models.llama.benchmark \
    --benchmark-type ort \
    --ort-model-path ./llama-2-onnx/7B_float16/ONNX/LlamaV2_7B_float16.onnx \
    --model-name meta-llama/Llama-2-7b-hf \
    --precision fp16 \
    --batch-sizes "1 2" \
    --sequence-lengths "8 16" \
    --device cuda
```

You can profile a variant by adding the `--profile` flag and providing one batch size and sequence length combination, as in the example below.

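For example, to profile variant 1 (the flags mirror the benchmark commands above; the single batch size and sequence length shown here are illustrative):
```
python3 -m models.llama.benchmark \
    --benchmark-type hf-pt \
    --model-name meta-llama/Llama-2-7b-hf \
    --precision fp32 \
    --batch-sizes "1" \
    --sequence-lengths "32" \
    --device cpu \
    --profile \
    --auth
```
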
### Benchmark All
You can use `benchmark_all.py` to benchmark across various platforms and automatically store the results in a CSV file. Here is an example.
```
python3 -m models.llama.benchmark_all \
    --hf-ort-model-path ./llama2-7b-fp16/ \
    --ort-model-path ./llama-2-onnx/7B_float16/ONNX/LlamaV2_7B_float16.onnx \
    --model-name meta-llama/Llama-2-7b-hf \
    --precision fp16 \
    --batch-sizes "1 2" \
    --sequence-lengths "8 16" \
    --device cuda
```

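Each engine's metrics are parsed out of its log file and collected into a single CSV in the log folder; per `save_results` in `benchmark_all.py`, the CSV header is:
```
Engine,Precision,Device,Batch Size,Sequence Length,Step,Latency (s),Latency (ms),Throughput (qps),Memory (GB)
```
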
12 onnxruntime/python/tools/transformers/models/llama/__init__.py Normal file

@@ -0,0 +1,12 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import os
import sys

sys.path.append(os.path.dirname(__file__))

transformers_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", ".."))
if transformers_dir not in sys.path:
    sys.path.append(transformers_dir)

524 onnxruntime/python/tools/transformers/models/llama/benchmark.py Normal file
@@ -0,0 +1,524 @@
import argparse
import datetime
import gc
import itertools
import logging
import os
import sys
import time

import numpy as np
import psutil
import torch
from benchmark_helper import setup_logger
from llama_inputs import get_msft_sample_inputs, get_sample_inputs, get_sample_with_past_kv_inputs
from optimum.onnxruntime import ORTModelForCausalLM
from torch.profiler import ProfilerActivity, profile, record_function
from tqdm import trange
from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer

import onnxruntime as ort
from onnxruntime.transformers.benchmark_helper import measure_memory

logger = logging.getLogger(__name__)


def get_inputs(args: argparse.Namespace):
    if args.benchmark_type in {"hf-pt", "hf-pt2", "hf-ort"}:
        init_inputs = get_sample_inputs(
            args.config,
            args.target_device,
            args.batch_size,
            args.sequence_length,
            return_dict=True,
        )
        iter_inputs = get_sample_with_past_kv_inputs(
            args.config,
            args.target_device,
            args.batch_size,
            args.sequence_length,
            use_fp16=args.use_fp16,
            return_dict=True,
        )

    elif args.benchmark_type == "ort":
        # Microsoft export from https://github.com/microsoft/Llama-2-Onnx
        init_inputs = get_msft_sample_inputs(
            args.config,
            args.batch_size,
            past_seq_len=0,
            seq_len=args.sequence_length,
            use_fp16=args.use_fp16,
        )
        iter_inputs = get_msft_sample_inputs(
            args.config,
            args.batch_size,
            past_seq_len=args.sequence_length,
            seq_len=1,
            use_fp16=args.use_fp16,
        )

    else:
        raise Exception("Unable to auto-detect inputs for provided model")

    return init_inputs, iter_inputs


def get_model(args: argparse.Namespace):
    model, sess_options = None, None
    start_time, end_time = None, None

    # There are multiple sources that the model could come from:
    # 1) Benchmark LLaMA from unofficial source on Hugging Face
    # 2) Benchmark LLaMA from official source on Hugging Face, which requires an authentication token
    # 3) Benchmark LLaMA from local download of model

    if args.benchmark_type in {"hf-pt", "hf-pt2"}:
        source = args.hf_pt_model_path if args.hf_pt_model_path else args.model_name
        start_time = time.time()
        model = LlamaForCausalLM.from_pretrained(
            source,
            torch_dtype=torch.float16 if args.use_fp16 else torch.float32,
            use_auth_token=args.auth,
            use_cache=True,
        ).to(args.target_device)
        end_time = time.time()

        if args.benchmark_type == "hf-pt2":
            model = torch.compile(model)

    elif args.benchmark_type in {"hf-ort", "ort"}:
        sess_options = ort.SessionOptions()
        sess_options.enable_profiling = args.profile
        if args.verbose:
            sess_options.log_verbosity_level = 1
            sess_options.log_severity_level = 1

    else:
        raise Exception(f"Cannot recognize {args.benchmark_type}")

    if args.benchmark_type == "hf-ort":
        # Optimum export or convert_to_onnx.py export
        provider = args.execution_provider[0] if type(args.execution_provider) is tuple else args.execution_provider
        provider_options = args.execution_provider[1] if type(args.execution_provider) is tuple else None

        decoder_file_name = None
        decoder_with_past_file_name = None
        for filename in os.listdir(args.hf_ort_model_path):
            if ".onnx" not in filename or ".onnx_data" in filename or ".onnx.data" in filename:
                continue
            if "decoder_model.onnx" in filename or f"decoder_model_{args.precision}.onnx" in filename:
                decoder_file_name = filename
            if (
                "decoder_with_past_model.onnx" in filename
                or f"decoder_with_past_model_{args.precision}.onnx" in filename
            ):
                decoder_with_past_file_name = filename

        start_time = time.time()
        model = ORTModelForCausalLM.from_pretrained(
            args.hf_ort_model_path,
            decoder_file_name=decoder_file_name,
            decoder_with_past_file_name=decoder_with_past_file_name,
            use_auth_token=args.auth,
            use_io_binding=(args.device != "cpu"),
            provider=provider,
            provider_options=provider_options,
            session_options=sess_options,
        )
        end_time = time.time()

    if args.benchmark_type == "ort":
        # Microsoft export from https://github.com/microsoft/Llama-2-Onnx
        logger.info(f"Loading model from {args.ort_model_path}")
        start_time = time.time()
        model = ort.InferenceSession(
            args.ort_model_path,
            sess_options,
            providers=[args.execution_provider],
        )
        end_time = time.time()

    logger.info(f"Loaded model in {end_time - start_time} s")

    return model


def time_fn(args, fn, inputs):
    # Warm up
    warmup_range = (
        range(args.warmup_runs)
        if args.benchmark_type == "ort"
        else trange(args.warmup_runs, file=sys.stdout, desc="Warm up")
    )

    if args.verbose:
        outputs = fn(inputs)
        logger.info(outputs)

    for _ in warmup_range:
        fn(inputs)

    # Benchmark
    if args.device != "cpu":
        torch.cuda.synchronize()
    start_time = time.time()

    bench_range = (
        range(args.num_runs)
        if args.benchmark_type == "ort"
        else trange(args.num_runs, file=sys.stdout, desc="Benchmark")
    )
    for _ in bench_range:
        fn(inputs)

    if args.device != "cpu":
        torch.cuda.synchronize()
    end_time = time.time()

    # Newline print after trange in order to print metrics on new lines without progress bar on same line
    if args.benchmark_type != "ort":
        logger.info("")

    latency = (end_time - start_time) / args.num_runs
    throughput = args.batch_size / latency

    logger.info(f"Batch Size: {args.batch_size}")
    logger.info(f"Sequence Length: {args.sequence_length}")
    logger.info(f"Latency: {latency} s")
    logger.info(f"Throughput: {throughput} qps")
    return


def profile_fn(args, fn, inputs, inputs_type):
    # Filename prefix format:
    # "b<batch-size>_s<sequence-length>_<benchmark-type>-<precision>-<device>_<inference-step>_<inputs-type>_<current-time>"
    prefix = f"b{args.batch_size}_s{args.sequence_length}_{args.benchmark_type.lower()}-{args.precision}-{args.device}_{fn.__name__.replace('_', '-')}_{inputs_type}_{datetime.datetime.now():%Y-%m-%d_%H:%M:%S}"
    filename = None

    if args.benchmark_type in {"hf-pt", "hf-pt2"}:
        # Profile PyTorch kernels
        with profile(  # noqa: SIM117
            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True, profile_memory=True
        ) as prof:
            with record_function("model_inference"):
                fn(inputs)
        prof_data = prof.key_averages(group_by_stack_n=5).table(sort_by=args.pt_filter_by, row_limit=args.pt_num_rows)

        filename = os.path.join(args.log_folder, f"{prefix}.log")
        with open(filename, "w") as f:
            f.write(prof_data)

    else:
        # Profile ORT kernels
        fn(inputs)

        # Set new log name for ORT profile log generated
        filename = f"{prefix}.json"

    return filename


def measure_fn(args, fn, inputs):
    # Measure CPU usage
    pid = os.getpid()
    process = psutil.Process(pid)
    process.cpu_percent(interval=0.1)

    fn(inputs)
    logger.info(f"CPU usage: {process.cpu_percent(interval=None)}%")

    # Measure memory usage
    gc.collect()
    torch.cuda.empty_cache()
    measure_memory(is_gpu=(args.device != "cpu"), func=lambda: fn(inputs))

    # Flush output so memory usage is printed
    sys.stdout.flush()


def run_hf_inference(args, init_inputs, iter_inputs, model):
    # Inference steps to measure
    def get_logits(inputs):
        # Inference pass without decoding
        outputs = model(**inputs)
        return outputs

    # Examples of other inference steps that can be measured:
    # To use, uncomment the function and assign it to `generate_fn`

    # def get_pred_ids(inputs):
    #     # Inference pass with predicted token ids generation
    #     predicted_ids = model.generate(**inputs)
    #     return predicted_ids

    # def gen_and_dec(inputs):
    #     # Inference pass with generation and decoding
    #     predicted_ids = get_pred_ids(inputs)
    #     transcription = []
    #     for bs in range(args.batch_size):
    #         for rs in range(args.num_return_sequences):
    #             transcription.append(
    #                 args.tokenizer.batch_decode(
    #                     predicted_ids[bs * args.num_return_sequences + rs], skip_special_tokens=True
    #                 )[0]
    #             )
    #     return transcription

    generate_fn = get_logits

    if args.benchmark_type == "hf-pt2":
        # Run forward pass once with each set of inputs to process through Dynamo
        generate_fn(init_inputs)
        generate_fn(iter_inputs)

    if args.profile:
        new_logname = profile_fn(args, generate_fn, init_inputs, "prompt")
        if args.benchmark_type == "hf-ort":
            # Turn profiling off to stop appending to log
            old_logname = model.decoder.session.end_profiling()
            logger.warning(f"Renaming {old_logname} to {new_logname}")
            os.rename(old_logname, os.path.join(args.log_folder, new_logname))

        new_logname = profile_fn(args, generate_fn, iter_inputs, "per-token")
        if args.benchmark_type == "hf-ort":
            # Turn profiling off to stop appending to log
            old_logname = model.decoder_with_past.session.end_profiling()
            logger.warning(f"Renaming {old_logname} to {new_logname}")
            os.rename(old_logname, os.path.join(args.log_folder, new_logname))

        return

    # PyTorch evaluations
    logger.info("\nEvaluating `model(inputs)` step to get past_key_values")
    time_fn(args, generate_fn, init_inputs)
    measure_fn(args, generate_fn, init_inputs)

    logger.info("\nEvaluating `model(inputs)` step with past_key_values")
    time_fn(args, generate_fn, iter_inputs)
    measure_fn(args, generate_fn, iter_inputs)


def run_ort_inference(args, init_inputs, iter_inputs, model):
    def prepare_ort_inputs(inputs):
        # Check that all model inputs will be provided
        model_inputs = set(map(lambda model_input: model_input.name, model.get_inputs()))
        user_inputs = set(inputs.keys())
        missing_inputs = model_inputs - user_inputs
        if len(missing_inputs):
            logger.error(f"The following model inputs are missing: {missing_inputs}")
            raise Exception("There are missing inputs to the model. Please add them and try again.")

        # Remove unnecessary inputs from model inputs
        unnecessary_inputs = user_inputs - model_inputs
        if len(unnecessary_inputs):
            for unnecessary_input in unnecessary_inputs:
                logger.info(f"Removing unnecessary input '{unnecessary_input}' from user provided inputs")
                del inputs[unnecessary_input]

        # Add IO bindings for non-CPU execution providers
        if args.device != "cpu":
            io_binding = model.io_binding()
            for k, v in inputs.items():
                io_binding.bind_cpu_input(k, v)
            for output in model.get_outputs():
                io_binding.bind_output(output.name)
            return io_binding

        return inputs

    def with_io_binding(io_binding):
        # Inference pass with IO binding
        model.run_with_iobinding(io_binding)

    def without_io_binding(inputs):
        # Inference pass without IO binding
        outputs = model.run(None, inputs)
        return outputs

    generate_fn = with_io_binding if args.device != "cpu" else without_io_binding

    if args.profile:
        ort_init_inputs = prepare_ort_inputs(init_inputs)
        new_logname = profile_fn(args, generate_fn, ort_init_inputs, "prompt")

        # Turn profiling off to stop appending to log file
        old_logname = model.end_profiling()
        logger.warning(f"Renaming {old_logname} to {new_logname}")
        os.rename(old_logname, os.path.join(args.log_folder, new_logname))

        # Re-initialize model for new log file instead of appending to old log file
        model = get_model(args)
        ort_iter_inputs = prepare_ort_inputs(iter_inputs)
        new_logname = profile_fn(args, generate_fn, ort_iter_inputs, "per-token")

        # Turn profiling off to stop appending to log
        old_logname = model.end_profiling()
        logger.warning(f"Renaming {old_logname} to {new_logname}")
        os.rename(old_logname, os.path.join(args.log_folder, new_logname))
        return

    # ORT evaluations
    logger.info("\nEvaluating `model(inputs)` step to get past_key_values")
    ort_init_inputs = prepare_ort_inputs(init_inputs)
    time_fn(args, generate_fn, ort_init_inputs)
    measure_fn(args, generate_fn, ort_init_inputs)

    logger.info("\nEvaluating `model(inputs)` step with past_key_values")
    ort_iter_inputs = prepare_ort_inputs(iter_inputs)
    time_fn(args, generate_fn, ort_iter_inputs)
    measure_fn(args, generate_fn, ort_iter_inputs)


def run_inference(args, init_inputs, iter_inputs, model):
    if args.benchmark_type in {"hf-pt", "hf-pt2", "hf-ort"}:
        run_hf_inference(args, init_inputs, iter_inputs, model)
    elif args.benchmark_type == "ort":
        run_ort_inference(args, init_inputs, iter_inputs, model)
    else:
        raise Exception(f"Cannot recognize {args.benchmark_type}")


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-bt", "--benchmark-type", type=str, required=True, choices=["hf-pt", "hf-pt2", "hf-ort", "ort"]
    )
    parser.add_argument(
        "-m",
        "--model-name",
        type=str,
        required=True,
        help="Hugging Face name of model (e.g. 'meta-llama/Llama-2-7b-hf')",
    )
    parser.add_argument(
        "-a", "--auth", default=False, action="store_true", help="Use Hugging Face authentication token to access model"
    )

    # Args for choosing the model
    parser.add_argument(
        "-p",
        "--precision",
        required=True,
        type=str,
        default="fp32",
        choices=["int8", "fp16", "fp32"],
        help="Precision for model. For ONNX models, the model's precision should be set before running this script.",
    )
    parser.add_argument(
        "--hf-pt-model-path",
        type=str,
        default="",
        help="Path to directory containing all PyTorch files (e.g. tokenizer, PyTorch model)",
    )
    parser.add_argument(
        "--hf-ort-model-path",
        type=str,
        default="",
        help="Path to directory containing all ONNX files (e.g. tokenizer, encoder, decoder, decoder_with_past)",
    )
    parser.add_argument(
        "--ort-model-path",
        type=str,
        default="",
        help="Path to ONNX model",
    )

    # Args for running and evaluating the model
    parser.add_argument(
        "-b",
        "--batch-sizes",
        default="1 2",
    )
    parser.add_argument(
        "-s",
        "--sequence-lengths",
        default="8 16 32 64 128 256 512",
    )
    parser.add_argument(
        "-d",
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        choices=["cpu", "cuda", "rocm"],
    )
    parser.add_argument("-id", "--device-id", type=int, default=0)
    parser.add_argument("-w", "--warmup-runs", type=int, default=5)
    parser.add_argument("-n", "--num-runs", type=int, default=10)
    parser.add_argument("--seed", type=int, default=2)

    # Args for decoding logic
    parser.add_argument("--max-length", type=int, default=32)
    parser.add_argument("--num-return-sequences", type=int, default=1)

    # Args for accessing detailed info
    parser.add_argument("--profile", default=False, action="store_true")
    parser.add_argument(
        "--pt-filter-by", type=str, default="self_cpu_time_total", help="What to filter PyTorch profiler by"
    )
    parser.add_argument("--pt-num-rows", type=int, default=1000, help="Number of rows for PyTorch profiler to display")
    parser.add_argument("--verbose", default=False, action="store_true")
    parser.add_argument("--log-folder", type=str, default=os.path.join("."), help="Folder to cache log files")

    args = parser.parse_args()

    # Set seed properties
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # Set runtime properties
    if "ort" in args.benchmark_type:
        setattr(args, "execution_provider", f"{args.device.upper()}ExecutionProvider")  # noqa: B010
        if args.execution_provider == "CUDAExecutionProvider":
            args.execution_provider = (args.execution_provider, {"device_id": args.device_id})
        elif args.execution_provider == "ROCMExecutionProvider":
            args.execution_provider = (args.execution_provider, {"device_id": args.device_id})
            args.device = "cuda"

    # Check that model paths have been specified for any benchmarking with ORT
    if args.benchmark_type == "hf-ort":
        assert args.hf_ort_model_path, "Please specify a path to `--hf-ort-model-path`"
    if args.benchmark_type == "ort":
        assert args.ort_model_path, "Please specify a path to `--ort-model-path`"

    args.batch_sizes = args.batch_sizes.split(" ")
    args.sequence_lengths = args.sequence_lengths.split(" ")

    # Check that only one (batch_size, sequence_length) combination is set for profiling
    if args.profile:
        assert (
            len(args.batch_sizes) == 1 and len(args.sequence_lengths) == 1
        ), "Please provide only one (batch_size, sequence_length) combination for profiling"

    return args


def main():
    args = get_args()
    setup_logger(args.verbose)
    logger.info(args.__dict__)
    torch.backends.cudnn.benchmark = True

    tokenizer = LlamaTokenizer.from_pretrained(args.model_name)
    config = LlamaConfig.from_pretrained(args.model_name)
    target_device = f"cuda:{args.device_id}" if args.device != "cpu" else args.device
    use_fp16 = args.precision == "fp16"

    setattr(args, "tokenizer", tokenizer)  # noqa: B010
    setattr(args, "config", config)  # noqa: B010
    setattr(args, "target_device", target_device)  # noqa: B010
    setattr(args, "use_fp16", use_fp16)  # noqa: B010

    # Measure prompt cost (init_inputs) and generated token cost (iter_inputs)
    model = get_model(args)
    for batch_size, sequence_length in itertools.product(args.batch_sizes, args.sequence_lengths):
        logger.info(f"\nBatch size = {batch_size} and sequence length = {sequence_length}...")
        setattr(args, "batch_size", int(batch_size))  # noqa: B010
        setattr(args, "sequence_length", int(sequence_length))  # noqa: B010

        init_inputs, iter_inputs = get_inputs(args)
        run_inference(args, init_inputs, iter_inputs, model)


if __name__ == "__main__":
    main()

359 onnxruntime/python/tools/transformers/models/llama/benchmark_all.py Normal file
@@ -0,0 +1,359 @@
import argparse
import datetime
import json
import logging
import os
import subprocess

import torch
from benchmark_helper import setup_logger

logger = logging.getLogger(__name__)


def get_args():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "-b",
        "--batch-sizes",
        type=str,
        default="1 2",
    )

    parser.add_argument(
        "-s",
        "--sequence-lengths",
        type=str,
        default="8 16 32 64 128 256 512",
    )

    parser.add_argument(
        "-w",
        "--warmup-runs",
        type=int,
        default=5,
    )

    parser.add_argument(
        "-n",
        "--num-runs",
        type=int,
        default=1000,
    )

    parser.add_argument(
        "--hf-ort-model-path",
        type=str,
        help="Path to folder containing ONNX models for Optimum + ORT benchmarking",
    )

    parser.add_argument(
        "--ort-model-path",
        type=str,
        help="Path to ONNX model for ORT benchmarking",
    )

    parser.add_argument(
        "--model-name",
        type=str,
        required=True,
        help="Model name in Hugging Face",
    )

    parser.add_argument(
        "--precision",
        type=str,
        required=True,
        choices=["int8", "fp16", "fp32"],
        help="Precision to run model",
    )

    parser.add_argument(
        "--device",
        type=str,
        required=True,
        choices=["cpu", "cuda", "rocm"],
        help="Device to benchmark models",
    )

    parser.add_argument(
        "--device-id",
        type=int,
        default=0,
        help="GPU device ID",
    )

    parser.add_argument(
        "--verbose",
        default=False,
        action="store_true",
        help="Print detailed logs",
    )

    parser.add_argument(
        "--timeout",
        type=int,
        default=10,
        help="Number of mins to attempt the benchmark before moving on",
    )

    args = parser.parse_args()

    setattr(args, "model_size", args.model_name.split("/")[-1].replace(".", "-"))  # noqa: B010
    log_folder_name = f"./{args.model_size}_{args.precision}"
    setattr(args, "log_folder", log_folder_name)  # noqa: B010
    os.makedirs(args.log_folder, exist_ok=True)

    # Convert timeout value to secs
    args.timeout *= 60

    return args


def process_log_file(device_id, log_file, base_results):
    entries = []
    batch_size, sequence_length, step = None, None, None
    latency_s, latency_ms, throughput, memory = None, None, None, None

    batch_pattern = "Batch Size: "
    sequence_pattern = "Sequence Length: "
    prompt_step_pattern = "to get past_key_values"
    per_token_step_pattern = "with past_key_values"
    latency_pattern = "Latency: "
    throughput_pattern = "Throughput: "
    memory_pattern = "peak="

    with open(log_file) as f:
        for input_line in f:
            line = input_line.replace("\n", "")

            if batch_pattern in line:
                batch_size = int(line[len(batch_pattern) :])
            elif sequence_pattern in line:
                sequence_length = int(line[len(sequence_pattern) :])
            elif prompt_step_pattern in line:
                step = "prompt"
            elif per_token_step_pattern in line:
                step = "per-token"
            elif latency_pattern in line:
                latency_s = float(line[len(latency_pattern) : line.rfind(" ")])
                if step == "prompt":
                    latency_s /= sequence_length
                latency_ms = latency_s * 1000
            elif throughput_pattern in line:
                throughput = float(line[len(throughput_pattern) : line.rfind(" ")])
            elif memory_pattern in line:
                if "CPU" in line:
                    # Example format for log entry:
                    # CPU memory usage: before=1000.0 MB, peak=2000.0 MB
                    memory = float(line[line.rfind("=") + 1 : line.rfind(" MB")]) / 1000
                else:
                    # Example format for log entry:
                    # GPU memory usage: before=[{'device_id': 0, 'name': 'NVIDIA A100-SXM4-80GB', 'max_used_MB': 69637.25}, {'device_id': 1, 'name': 'NVIDIA A100-SXM4-80GB', 'max_used_MB': 890.625}] peak=[{'device_id': 0, 'name': 'NVIDIA A100-SXM4-80GB', 'max_used_MB': 73861.25}, {'device_id': 1, 'name': 'NVIDIA A100-SXM4-80GB', 'max_used_MB': 890.625}]
                    peak = line[line.find(memory_pattern) + len(memory_pattern) :].replace("'", '"')
                    usage = json.loads(peak)[device_id]["max_used_MB"]
                    memory = float(usage) / 1000

                # Append log entry to list of entries
                entry = base_results + [  # noqa: RUF005
                    batch_size,
                    sequence_length,
                    step,
                    latency_s,
                    latency_ms,
                    throughput,
                    memory,
                ]
                entries.append(entry)

    return entries


def save_results(results, filename):
    import pandas as pd

    df = pd.DataFrame(
        results,
        columns=[
            "Engine",
            "Precision",
            "Device",
            "Batch Size",
            "Sequence Length",
            "Step",
            "Latency (s)",
            "Latency (ms)",
            "Throughput (qps)",
            "Memory (GB)",
        ],
    )

    # Set column types
    df["Batch Size"] = df["Batch Size"].astype("int")
    df["Sequence Length"] = df["Sequence Length"].astype("int")
    df["Latency (s)"] = df["Latency (s)"].astype("float")
    df["Latency (ms)"] = df["Latency (ms)"].astype("float")
    df["Throughput (qps)"] = df["Throughput (qps)"].astype("float")
    df["Memory (GB)"] = df["Memory (GB)"].astype("float")

    df.to_csv(filename, index=False)
    logger.info(f"Results saved in {filename}!")


def benchmark(args, benchmark_cmd, engine):
    log_filename = f"{engine}_{datetime.datetime.now():%Y-%m-%d_%H:%M:%S}.log"
    log_path = os.path.join(args.log_folder, log_filename)
    with open(log_path, "w") as log_file:
        process = subprocess.Popen(benchmark_cmd, stdout=log_file, stderr=log_file)
        try:
            process.wait(args.timeout)
        except subprocess.TimeoutExpired:
            process.kill()

    # Create entries for csv
    logger.info("Gathering data from log files...")
    base_results = [engine, args.precision, args.device]
    results = process_log_file(args.device_id, log_path, base_results)

    return results


def main():
    args = get_args()
    setup_logger(args.verbose)
    logger.info(args.__dict__)
    torch.backends.cudnn.benchmark = True

    all_results = []
    # Benchmark PyTorch without torch.compile
    benchmark_cmd = [
        "python3",
        "benchmark.py",
        "--benchmark-type",
        "hf-pt",
        "--model-name",
        args.model_name,
        "--precision",
        args.precision,
        "--batch-sizes",
        args.batch_sizes,
        "--sequence-lengths",
        args.sequence_lengths,
        "--device",
        args.device,
        "--device-id",
        str(args.device_id),
        "--warmup-runs",
        str(args.warmup_runs),
        "--num-runs",
        str(args.num_runs),
        "--log-folder",
        args.log_folder,
        "--auth",
    ]
    logger.info("Benchmark PyTorch without torch.compile")
    results = benchmark(args, benchmark_cmd, "pytorch")
    all_results.extend(results)

    # Benchmark PyTorch with torch.compile
    benchmark_cmd = [
        "python3",
        "benchmark.py",
        "--benchmark-type",
        "hf-pt2",
        "--model-name",
        args.model_name,
        "--precision",
        args.precision,
        "--batch-sizes",
        args.batch_sizes,
        "--sequence-lengths",
        args.sequence_lengths,
        "--device",
        args.device,
        "--device-id",
        str(args.device_id),
        "--warmup-runs",
        str(args.warmup_runs),
        "--num-runs",
        str(args.num_runs),
        "--log-folder",
        args.log_folder,
        "--auth",
    ]
    logger.info("Benchmark PyTorch with torch.compile")
    results = benchmark(args, benchmark_cmd, "pytorch-2")
    all_results.extend(results)

    # Benchmark Optimum + ONNX Runtime
    if args.hf_ort_model_path:
        benchmark_cmd = [
            "python3",
            "benchmark.py",
            "--benchmark-type",
            "hf-ort",
            "--hf-ort-model-path",
            args.hf_ort_model_path,
            "--model-name",
            args.model_name,
            "--precision",
            args.precision,
            "--batch-sizes",
            args.batch_sizes,
            "--sequence-lengths",
            args.sequence_lengths,
            "--device",
            args.device,
            "--device-id",
            str(args.device_id),
            "--warmup-runs",
            str(args.warmup_runs),
            "--num-runs",
            str(args.num_runs),
            "--log-folder",
            args.log_folder,
            "--auth",
        ]
        logger.info("Benchmark Optimum + ONNX Runtime")
        results = benchmark(args, benchmark_cmd, "pytorch-ort")
        all_results.extend(results)

    # Benchmark ONNX Runtime
    if args.ort_model_path:
        benchmark_cmd = [
            "python3",
            "benchmark.py",
            "--benchmark-type",
            "ort",
            "--ort-model-path",
            args.ort_model_path,
            "--model-name",
            args.model_name,
            "--precision",
            args.precision,
            "--batch-sizes",
            args.batch_sizes,
            "--sequence-lengths",
            args.sequence_lengths,
            "--device",
            args.device,
            "--device-id",
            str(args.device_id),
            "--warmup-runs",
            str(args.warmup_runs),
            "--num-runs",
            str(args.num_runs),
            "--log-folder",
            args.log_folder,
        ]
        logger.info("Benchmark ONNX Runtime")
        results = benchmark(args, benchmark_cmd, "onnxruntime")
        all_results.extend(results)

    csv_file = f"{args.model_size}_{args.precision}_{datetime.datetime.now():%Y-%m-%d_%H:%M:%S}.csv"
    save_results(all_results, os.path.join(args.log_folder, csv_file))


if __name__ == "__main__":
    main()

576 onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py Normal file
@ -0,0 +1,576 @@
|
|||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from itertools import chain
|
||||
from typing import List
|
||||
|
||||
import onnx
|
||||
import torch
|
||||
from benchmark_helper import Precision, prepare_environment, setup_logger
|
||||
from llama_inputs import get_sample_inputs, get_sample_with_past_kv_inputs
|
||||
from llama_parity import main as parity_check
|
||||
from onnx_model import OnnxModel
|
||||
from transformers import LlamaConfig, LlamaForCausalLM
|
||||
|
||||
from onnxruntime import quantization as ort_quantization
|
||||
|
||||
logger = logging.getLogger("")
|
||||
|
||||
|
||||
def get_model_dynamic_axes(input_names: List[str], output_names: List[str]):
|
||||
dynamic_axes = {}
|
||||
for name in input_names + output_names:
|
||||
if name in input_names:
|
||||
# shape is (batch_size, sequence_length)
|
||||
dynamic_axes[name] = {0: "batch_size", 1: "sequence_length"}
|
||||
elif name == "logits":
|
||||
# shape is (batch_size, sequence_length, vocab_size)
|
||||
dynamic_axes[name] = {0: "batch_size", 1: "sequence_length"}
|
||||
elif "present" in name:
|
||||
# shape is (batch_size, num_heads, sequence_length, head_size)
|
||||
dynamic_axes[name] = {0: "batch_size", 2: "sequence_length"}
|
||||
else:
|
||||
raise Exception("Unknown input or output name found")
|
||||
return dynamic_axes
|
||||
|
||||
|
||||
def get_model_with_past_kv_dynamic_axes(input_names: List[str], output_names: List[str]):
|
||||
dynamic_axes = {}
|
||||
for name in input_names + output_names:
|
||||
if name in {"input_ids", "position_ids"}:
|
||||
# shape is (batch_size, 1)
|
||||
dynamic_axes[name] = {0: "batch_size"}
|
||||
elif name == "attention_mask":
|
||||
# shape is (batch_size, past_sequence_length + 1)
|
||||
dynamic_axes[name] = {0: "batch_size", 1: "past_sequence_length + 1"}
|
||||
elif "past" in name:
|
||||
# shape is (batch_size, num_heads, past_sequence_length, head_size)
|
||||
dynamic_axes[name] = {0: "batch_size", 2: "past_sequence_length"}
|
||||
elif name == "logits":
|
||||
# shape is (batch_size, 1, vocab_size)
|
||||
dynamic_axes[name] = {0: "batch_size"}
|
||||
elif "present" in name:
|
||||
# shape is (batch_size, num_heads, past_sequence_length + 1, head_size)
|
||||
dynamic_axes[name] = {0: "batch_size", 2: "past_sequence_length + 1"}
|
||||
else:
|
||||
raise Exception("Unknown input or output name found")
|
||||
return dynamic_axes
|
||||
|
||||
|
||||
def save_onnx_model(onnx_model: onnx.ModelProto, output_path: str, data_path: str):
|
||||
onnx.save(
|
||||
onnx_model,
|
||||
output_path,
|
||||
save_as_external_data=True,
|
||||
all_tensors_to_one_file=True,
|
||||
location=data_path,
|
||||
size_threshold=1024,
|
||||
convert_attribute=False,
|
||||
)
|
||||
|
||||
|
||||
# Notes:
|
||||
# 1) Dynamo export will not work automatically until this issue is resolved: https://github.com/microsoft/onnxscript/issues/493
|
||||
#
|
||||
# 2) Dynamo export will run manually if you set the ONNX file path to the same path that you use to save the model after export.
|
||||
# In other words, the value of `temp_path` should be set as the ONNX file path. You can open the issue in your browser to find
|
||||
# the location in ONNX Script where you have to make this change.
|
||||
#
|
||||
# Once the issue is resolved, we hope to modify the code below as follows for each export.
|
||||
#
|
||||
# Before:
|
||||
# temp_dir = args.output
|
||||
# temp_path = os.path.join(temp_dir, "temp.onnx")
|
||||
# ...
|
||||
# ...
|
||||
# ...
|
||||
# del onnx_model
|
||||
# os.system(f"rm {os.path.join(temp_dir, 'model.*')} && rm {os.path.join(temp_dir, '*.weight')} && rm {temp_path}")
|
||||
#
|
||||
#
|
||||
# After:
|
||||
# temp_dir = tempfile.TemporaryDirectory()
|
||||
# temp_path = os.path.join(temp_dir.name, "temp.onnx")
|
||||
# ...
|
||||
# ...
|
||||
# ...
|
||||
# del onnx_model
|
||||
# temp_dir.cleanup()
|
||||
#
|
||||
def run_dynamo_export(args: argparse.Namespace, l_config: LlamaConfig, llama: LlamaForCausalLM):
|
||||
from torch._dynamo import config
|
||||
|
||||
config.capture_scalar_outputs = True
|
||||
|
||||
# Dummy values for export
|
||||
batch_size, sequence_length = 2, 8
|
||||
device = torch.device("cpu")
|
||||
|
||||
# Export decoder_model.onnx
|
||||
input_ids, attn_mask, pos_ids = get_sample_inputs(l_config, device, batch_size, sequence_length)
|
||||
temp_dir = args.output # tempfile.TemporaryDirectory()
|
||||
temp_path = os.path.join(temp_dir, "temp.onnx") # os.path.join(temp_dir.name, "temp.onnx")
|
||||
torch.onnx.dynamo_export(
|
||||
llama, input_ids, attn_mask, pos_ids, export_options=torch.onnx.ExportOptions(dynamic_shapes=True)
|
||||
).save(temp_path)
|
||||
|
||||
# Check decoder_model.onnx and save all external data to one file
|
||||
onnx.checker.check_model(temp_path)
|
||||
onnx.shape_inference.infer_shapes_path(temp_path)
|
||||
|
||||
output_path = os.path.join(args.output, f"{args.model_name}_decoder_model_fp32.onnx")
|
||||
onnx_model = onnx.load_model(temp_path, load_external_data=True)
|
||||
save_onnx_model(onnx_model, output_path, f"{args.model_name}_decoder_model_fp32.onnx.data")
|
||||
del onnx_model
|
||||
os.system(
|
||||
f"rm {os.path.join(temp_dir, 'model.*')} && rm {os.path.join(temp_dir, '*.weight')} && rm {temp_path}"
|
||||
) # temp_dir.cleanup()
|
||||
|
||||
# Export decoder_with_past_model.onnx
|
||||
input_ids, attn_mask, pos_ids, past_kv = get_sample_with_past_kv_inputs(
|
||||
l_config, device, batch_size, sequence_length
|
||||
)
|
||||
temp_dir = args.output # tempfile.TemporaryDirectory()
|
||||
temp_path = os.path.join(temp_dir, "temp.onnx") # os.path.join(temp_dir.name, "temp.onnx")
|
||||
torch.onnx.dynamo_export(
|
||||
llama, input_ids, attn_mask, pos_ids, past_kv, export_options=torch.onnx.ExportOptions(dynamic_shapes=True)
|
||||
).save(temp_path)
|
||||
|
||||
# Check decoder_with_past_model.onnx and save all external data to one file
|
||||
onnx.checker.check_model(temp_path)
|
||||
onnx.shape_inference.infer_shapes_path(temp_path)
|
||||
|
||||
output_path = os.path.join(args.output, f"{args.model_name}_decoder_with_past_model_fp32.onnx")
|
||||
onnx_model = onnx.load_model(temp_path, load_external_data=True)
|
||||
save_onnx_model(onnx_model, output_path, f"{args.model_name}_decoder_with_past_model_fp32.onnx.data")
|
||||
del onnx_model
|
||||
os.system(
|
||||
f"rm {os.path.join(temp_dir, 'model.*')} && rm {os.path.join(temp_dir, '*.weight')} && rm {temp_path}"
|
||||
) # temp_dir.cleanup()
|
||||
|
||||
logger.info(f"The {args.model_name} ONNX model has been successfully created with the Dynamo exporter!")
|
||||
|
||||
|
||||
def run_torchscript_export(args: argparse.Namespace, l_config: LlamaConfig, llama: LlamaForCausalLM):
|
||||
# Dummy values for export
|
||||
batch_size, sequence_length = 2, 8
|
||||
device = torch.device("cpu")
|
||||
|
||||
# Export decoder_model.onnx
|
||||
decoder_inputs = get_sample_inputs(l_config, device, batch_size, sequence_length)
|
||||
|
||||
input_names = ["input_ids", "attention_mask", "position_ids"]
|
||||
output_names = [
|
||||
"logits",
|
||||
*list(
|
||||
chain.from_iterable((f"present.{i}.key", f"present.{i}.value") for i in range(l_config.num_hidden_layers))
|
||||
),
|
||||
]
|
||||
dynamic_axes = get_model_dynamic_axes(input_names, output_names)
|
||||
temp_dir = tempfile.TemporaryDirectory()
|
||||
temp_path = os.path.join(temp_dir.name, "temp.onnx")
|
||||
torch.onnx.export(
|
||||
llama,
|
||||
args=decoder_inputs,
|
||||
f=temp_path,
|
||||
export_params=True,
|
||||
input_names=input_names,
|
||||
output_names=output_names,
|
||||
dynamic_axes=dynamic_axes,
|
||||
opset_version=13,
|
||||
do_constant_folding=True,
|
||||
verbose=args.verbose,
|
||||
)
|
||||
|
||||
# Check decoder_model.onnx and save all external data to one file
|
||||
onnx.checker.check_model(temp_path)
|
||||
onnx.shape_inference.infer_shapes_path(temp_path)
|
||||
|
||||
output_path = os.path.join(args.output, f"{args.model_name}_decoder_model_fp32.onnx")
|
||||
onnx_model = onnx.load_model(temp_path, load_external_data=True)
|
||||
save_onnx_model(
|
||||
onnx_model,
|
||||
output_path,
|
||||
f"{args.model_name}_decoder_model_fp32.onnx.data",
|
||||
)
|
||||
del onnx_model
|
||||
temp_dir.cleanup()
|
||||
|
||||
# Export decoder_with_past_model.onnx
|
||||
decoder_with_past_inputs = get_sample_with_past_kv_inputs(l_config, device, batch_size, sequence_length)
|
||||
input_names = [
|
||||
"input_ids",
|
||||
"attention_mask",
|
||||
"position_ids",
|
||||
*list(
|
||||
chain.from_iterable(
|
||||
(f"past_key_values.{i}.key", f"past_key_values.{i}.value") for i in range(l_config.num_hidden_layers)
|
||||
)
|
||||
),
|
||||
]
|
||||
output_names = [
|
||||
"logits",
|
||||
*list(
|
||||
chain.from_iterable((f"present.{i}.key", f"present.{i}.value") for i in range(l_config.num_hidden_layers))
|
||||
),
|
||||
]
|
||||
dynamic_axes = get_model_with_past_kv_dynamic_axes(input_names, output_names)
|
||||
temp_dir = tempfile.TemporaryDirectory()
|
||||
temp_path = os.path.join(temp_dir.name, "temp.onnx")
|
||||
torch.onnx.export(
|
||||
llama,
|
||||
args=decoder_with_past_inputs,
|
||||
f=temp_path,
|
||||
export_params=True,
|
||||
input_names=input_names,
|
||||
output_names=output_names,
|
||||
dynamic_axes=dynamic_axes,
|
||||
opset_version=13,
|
||||
do_constant_folding=True,
|
||||
verbose=args.verbose,
|
||||
)
|
||||
|
||||
# Check decoder_with_past_model.onnx and save all external data to one file
|
||||
onnx.checker.check_model(temp_path)
|
||||
onnx.shape_inference.infer_shapes_path(temp_path)
|
||||
|
||||
output_path = os.path.join(args.output, f"{args.model_name}_decoder_with_past_model_fp32.onnx")
|
||||
onnx_model = onnx.load_model(temp_path, load_external_data=True)
|
||||
save_onnx_model(
|
||||
onnx_model,
|
||||
output_path,
|
||||
f"{args.model_name}_decoder_with_past_model_fp32.onnx.data",
|
||||
)
|
||||
del onnx_model
|
||||
temp_dir.cleanup()
|
||||
|
||||
logger.info(f"The {args.model_name} ONNX model has been successfully created with the TorchScript exporter!")
|
||||
|
||||
|
||||
def remove_existing_files(output_path: str):
|
||||
for filename in os.listdir(output_path):
|
||||
filepath = os.path.join(output_path, filename)
|
||||
if ".onnx" in filename or ".onnx.data" in filename:
|
||||
os.remove(filepath)
|
||||
logger.warning(f"Removing {filepath}")
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"-m",
|
||||
"--model_name",
|
||||
required=True,
|
||||
help="Model name in Hugging Face",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-i",
|
||||
"--input",
|
||||
required=False,
|
||||
default=os.path.join("."),
|
||||
help="Directory path to PyTorch model and associated files if saved on disk",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-o",
|
||||
"--output",
|
||||
required=False,
|
||||
default=os.path.join(".", "llama_onnx_models"),
|
||||
help="Directory path to save exported model files in",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-p",
|
||||
"--precision",
|
||||
required=False,
|
||||
type=Precision,
|
||||
default=Precision.FLOAT32,
|
||||
choices=[Precision.FLOAT32, Precision.FLOAT16, Precision.INT8],
|
||||
help="Precision to export model in",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-e",
|
||||
"--execution_provider",
|
||||
required=False,
|
||||
default="cpu",
|
||||
choices=["cpu", "cuda", "rocm"],
|
||||
help="Execution provider to verify parity with",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-q",
|
||||
"--quantization_method",
|
||||
default="",
|
||||
choices=["smooth_quant", "quantize_dynamic"],
|
||||
help="Run a specific quantization algorithm. Need to install extra packages in `requirements-quant.txt` for SmoothQuant.",
|
||||
)
|
||||
|
||||
smooth_quant_group = parser.add_argument_group("smooth_quant")
|
||||
|
||||
smooth_quant_group.add_argument(
|
||||
"--smooth_quant_alpha",
|
||||
required=False,
|
||||
default=0.8,
|
||||
type=float,
|
||||
help="Strength to control migration difficulty from activation to weights. Default is 0.8 to match value \
|
||||
used in original paper for LLaMA. Paper recommends using values in [0.4, 0.6] range. \
|
||||
Link to paper: https://arxiv.org/pdf/2211.10438.pdf",
|
||||
)

    smooth_quant_group.add_argument(
        "--smooth_quant_dataset",
        required=False,
        default="NeelNanda/pile-10k",
        help="Path to dataset for calibration during quantization",
    )

    smooth_quant_group.add_argument(
        "--pad_max",
        required=False,
        default=196,
        type=int,
        help="Max padding size",
    )

    smooth_quant_group.add_argument(
        "--calibration_sampling_size",
        required=False,
        type=int,
        default=8,
        help="Calibration sampling size for quantization config",
    )

    smooth_quant_group.add_argument(
        "--nc_workspace",
        required=False,
        type=str,
        default=os.path.join(".", "nc_workspace"),
        help="Workspace to save intermediate files generated by Intel's Neural Compressor package.",
    )

    quantize_dynamic_group = parser.add_argument_group("quantize_dynamic")

    quantize_dynamic_group.add_argument(
        "--quantize_embedding_layer",
        required=False,
        action="store_true",
        help="Quantize MatMul, GEMM, and Gather.",
    )
    quantize_dynamic_group.set_defaults(quantize_embedding_layer=False)

    quantize_dynamic_group.add_argument(
        "--quantize_per_channel",
        required=False,
        action="store_true",
        help="Quantize weights per each channel.",
    )
    quantize_dynamic_group.set_defaults(quantize_per_channel=False)

    quantize_dynamic_group.add_argument(
        "--quantize_reduce_range",
        required=False,
        action="store_true",
        help="Quantize weights with 7 bits.",
    )
    quantize_dynamic_group.set_defaults(quantize_reduce_range=False)

    parser.add_argument(
        "-v",
        "--verbose",
        action="store_true",
        help="Print verbose logs",
    )
    parser.set_defaults(verbose=False)

    parser.add_argument(
        "-d",
        "--use_dynamo_export",
        action="store_true",
        help="Use the new Dynamo exporter instead of the old TorchScript exporter",
    )
    parser.set_defaults(use_dynamo_export=False)

    args = parser.parse_args()
    return args


def main():
    args = get_args()
    setup_logger(args.verbose)
    prepare_environment(args.input, args.output, args.execution_provider != "cpu")
    remove_existing_files(args.output)
    logger.info(f"Arguments: {args}")

    # Load model and config
    use_auth_token = args.input == os.path.join(".")
    setattr(args, "use_auth_token", use_auth_token)  # noqa: B010
    l_config = LlamaConfig.from_pretrained(
        args.model_name if use_auth_token else args.input, use_auth_token=use_auth_token
    )
    llama = LlamaForCausalLM.from_pretrained(
        args.model_name if use_auth_token else args.input, use_auth_token=use_auth_token, use_cache=True
    )
    original_model_name = args.model_name
    setattr(args, "original_model_name", original_model_name)  # noqa: B010
    args.model_name = args.model_name.split("/")[-1]

    # Export to ONNX
    if args.use_dynamo_export:
        logger.warning("Please ensure you have installed PyTorch, ONNX, and ONNX Script as follows.")
        logger.warning("Step 1 - PyTorch nightly: https://pytorch.org/get-started/locally/")
        logger.warning("Step 2 - ONNX weekly: https://pypi.org/project/onnx-weekly/")
        logger.warning(
            "Step 3 - ONNX Script from source: https://github.com/microsoft/onnxscript#installing-onnx-script"
        )
        logger.warning(
            "Note: After you install ONNX weekly, omit `onnx` when running the first line for installing ONNX Script. This is because you already installed `onnx-weekly` in the previous step."
        )
        run_dynamo_export(args, l_config, llama)
    else:
        run_torchscript_export(args, l_config, llama)

    # Change precision of exported models if not FP32
    decoder_model_fp32_path = os.path.join(args.output, f"{args.model_name}_decoder_model_fp32.onnx")
    decoder_with_past_model_fp32_path = os.path.join(
        args.output, f"{args.model_name}_decoder_with_past_model_fp32.onnx"
    )

    if args.precision == Precision.FLOAT16:
        # Convert decoder_model.onnx to FP16
        decoder_model_fp16_path = os.path.join(args.output, f"{args.model_name}_decoder_model_fp16.onnx")
        model = OnnxModel(onnx.load_model(decoder_model_fp32_path, load_external_data=True))
        model.convert_float_to_float16(keep_io_types=False, op_block_list=["If"])
        model.save_model_to_file(decoder_model_fp16_path, use_external_data_format=True, all_tensors_to_one_file=True)
        del model

        # Convert decoder_with_past_model.onnx to FP16
        decoder_with_past_model_fp16_path = os.path.join(
            args.output, f"{args.model_name}_decoder_with_past_model_fp16.onnx"
        )
        model = OnnxModel(onnx.load_model(decoder_with_past_model_fp32_path, load_external_data=True))
        model.convert_float_to_float16(keep_io_types=False, op_block_list=["If"])
        model.save_model_to_file(
            decoder_with_past_model_fp16_path, use_external_data_format=True, all_tensors_to_one_file=True
        )
        del model

    elif args.precision == Precision.INT8:
        decoder_model_int8_path = os.path.join(args.output, f"{args.model_name}_decoder_model_int8.onnx")
        decoder_with_past_model_int8_path = os.path.join(
            args.output, f"{args.model_name}_decoder_with_past_model_int8.onnx"
        )

        if args.quantization_method == "smooth_quant":
            from neural_compressor import PostTrainingQuantConfig
            from neural_compressor import quantization as intel_quantization
            from neural_compressor import set_workspace
            from onnx.external_data_helper import load_external_data_for_model
            from quant_kv_dataloader import QuantKVDataLoader

            set_workspace(args.nc_workspace)
            quantization_config = PostTrainingQuantConfig(
                calibration_sampling_size=[args.calibration_sampling_size],
                recipes={
                    "optypes_to_exclude_output_quant": ["MatMul"],
                    "smooth_quant": args.smooth_quant,
                    "smooth_quant_args": {"alpha": args.smooth_quant_alpha},
                },
                op_type_dict={
                    "^((?!(MatMul|Gather|Conv)).)*$": {
                        "weight": {"dtype": ["fp32"]},
                        "activation": {"dtype": ["fp32"]},
                    }
                },
            )

            # Convert decoder_model.onnx to INT8
            decoder_model_int8 = intel_quantization.fit(
                decoder_model_fp32_path,
                quantization_config,
                calib_dataloader=QuantKVDataLoader(args),
            )
            load_external_data_for_model(
                decoder_model_int8._model,
                os.path.split(decoder_model_int8._model_path)[0],
            )
            save_onnx_model(
                decoder_model_int8._model,
                decoder_model_int8_path,
                f"{args.model_name}_decoder_model_int8.onnx.data",
            )
            del decoder_model_int8

            # Convert decoder_with_past_model.onnx to INT8
            decoder_with_past_model_int8 = intel_quantization.fit(
                decoder_with_past_model_fp32_path,
                quantization_config,
                calib_dataloader=QuantKVDataLoader(args, onnx_model_path=decoder_model_fp32_path),
            )
            load_external_data_for_model(
                decoder_with_past_model_int8._model,
                os.path.split(decoder_with_past_model_int8._model_path)[0],
            )
            save_onnx_model(
                decoder_with_past_model_int8._model,
                decoder_with_past_model_int8_path,
                f"{args.model_name}_decoder_with_past_model_int8.onnx.data",
            )
            del decoder_with_past_model_int8

            logger.info(f"Removing {args.nc_workspace}")
            os.system(f"rm -R {args.nc_workspace}")

        elif args.quantization_method == "quantize_dynamic":
            logger.warning(
                "The `quantize_dynamic` method is deprecated in favor of `smooth_quant`. Precision loss may be high with `quantize_dynamic`."
            )

            # Convert decoder_model.onnx to INT8
            ort_quantization.quantize_dynamic(
                decoder_model_fp32_path,
                decoder_model_int8_path,
                op_types_to_quantize=["MatMul", "Gemm", "Gather"]
                if args.quantize_embedding_layer
                else ["MatMul", "Gemm"],
                per_channel=args.quantize_per_channel,
                reduce_range=args.quantize_reduce_range,
                use_external_data_format=True,
                extra_options={"MatMulConstBOnly": True},
            )

            # Convert decoder_with_past_model.onnx to INT8
            ort_quantization.quantize_dynamic(
                decoder_with_past_model_fp32_path,
                decoder_with_past_model_int8_path,
                op_types_to_quantize=["MatMul", "Gemm", "Gather"]
                if args.quantize_embedding_layer
                else ["MatMul", "Gemm"],
                per_channel=args.quantize_per_channel,
                reduce_range=args.quantize_reduce_range,
                use_external_data_format=True,
                extra_options={"MatMulConstBOnly": True},
            )

        else:
            raise Exception(f"Could not recognize {args.quantization_method} as a quantization method")

    # Verify parity on all saved ONNX models
    del llama  # Delete LLaMA model from memory since it will be loaded again during parity check
    logger.info("Verifying parity on all ONNX models created")
    for filename in os.listdir(args.output):
        if ".data" in filename or ".onnx" not in filename:
            continue

        precision = filename[filename.rfind("_") + 1 : filename.find(".onnx")]
        parity_cmd = ["-m", f"{original_model_name}", "-o", f"{os.path.join(args.output, filename)}", "-fp", precision]
        if "with_past" in filename:
            parity_cmd.append("--use_past_kv")
        parity_check(parity_cmd)


if __name__ == "__main__":
    main()
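# Hypothetical invocation (the flags for model name, output directory, precision, and execution
# provider are defined earlier in this file; their exact spellings are assumed here):
#   python convert_to_onnx.py --model_name meta-llama/Llama-2-7b-hf --output ./llama2-7b \
#       --precision int8 --quantization_method smooth_quant --execution_provider cpu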
119 onnxruntime/python/tools/transformers/models/llama/llama_inputs.py Normal file
@@ -0,0 +1,119 @@
from typing import List, Tuple

import numpy as np
import torch
from transformers import LlamaConfig


# Get position_ids from attention_mask
def get_position_ids(attention_mask: torch.Tensor, use_past_kv: bool):
    position_ids = attention_mask.long().cumsum(-1) - 1
    if use_past_kv:
        position_ids = position_ids[:, -1].unsqueeze(-1)
    return position_ids
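
# Worked example: for attention_mask = [[0, 1, 1, 1]] (one left-padded token) and
# use_past_kv=False, cumsum(-1) - 1 gives position_ids = [[-1, 0, 1, 2]], so the real
# tokens get positions 0, 1, 2; with use_past_kv=True only the last column [[2]] is kept.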


# Inputs for first pass to get initial past_key_values
def get_sample_inputs(
    config: LlamaConfig, device: torch.device, batch_size: int, seq_len: int, return_dict: bool = False
):
    input_ids = torch.randint(
        low=0, high=config.vocab_size, size=(batch_size, seq_len), device=device, dtype=torch.int64
    )
    attention_mask = torch.ones(batch_size, seq_len, device=device, dtype=torch.int64)
    # position_ids is of shape (batch_size, seq_len)
    position_ids = get_position_ids(attention_mask, use_past_kv=False)

    if not return_dict:
        return (input_ids, attention_mask, position_ids)

    inputs = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "position_ids": position_ids,
    }
    return inputs


# Inputs for subsequent passes with past_key_values
def get_sample_with_past_kv_inputs(
    config: LlamaConfig,
    device: torch.device,
    batch_size: int,
    past_seq_len: int,
    use_fp16: bool = False,
    return_dict: bool = False,
):
    input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, 1), device=device, dtype=torch.int64)
    attention_mask = torch.ones(batch_size, past_seq_len + 1, device=device, dtype=torch.int64)
    # position_ids is of shape (batch_size, 1)
    position_ids = get_position_ids(attention_mask, use_past_kv=True)
    past_kv = get_sample_past_kv_inputs(config, device, batch_size, past_seq_len, use_fp16)

    if not return_dict:
        return (input_ids, attention_mask, position_ids, past_kv)

    inputs = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "position_ids": position_ids,
        "past_key_values": past_kv,
    }
    return inputs


# Create past_key_values
def get_sample_past_kv_inputs(
    config: LlamaConfig, device: torch.device, batch_size: int, past_seq_len: int, use_fp16: bool
):
    num_heads, head_size = config.num_attention_heads, config.hidden_size // config.num_attention_heads
    torch_dtype = torch.float16 if use_fp16 else torch.float32
    past_kv = [
        (
            torch.rand(batch_size, num_heads, past_seq_len, head_size, device=device, dtype=torch_dtype),
            torch.rand(batch_size, num_heads, past_seq_len, head_size, device=device, dtype=torch_dtype),
        )
        for _ in range(config.num_hidden_layers)
    ]
    return past_kv


# Convert list of past_kv to dict of past_key and past_value
def flatten_past_kv_inputs(past_key_values: List[Tuple[torch.Tensor, torch.Tensor]], use_fp16: bool):
    past_kv = {}
    np_dtype = np.float16 if use_fp16 else np.float32
    for i, (past_k, past_v) in enumerate(past_key_values):
        past_kv[f"past_key_values.{i}.key"] = past_k.detach().cpu().numpy().astype(np_dtype)
        past_kv[f"past_key_values.{i}.value"] = past_v.detach().cpu().numpy().astype(np_dtype)
    return past_kv


# Format PyTorch inputs to ONNX Runtime inputs
def convert_inputs_for_ort(pt_inputs: dict, use_fp16: bool):
    ort_inputs = {}
    for k, v in pt_inputs.items():
        if k == "past_key_values":
            ort_inputs.update(flatten_past_kv_inputs(v, use_fp16))
        else:
            ort_inputs[k] = v.detach().cpu().numpy()
    return ort_inputs


# Inputs for Microsoft export from https://github.com/microsoft/Llama-2-Onnx
def get_msft_sample_inputs(config: LlamaConfig, batch_size: int, past_seq_len: int, seq_len: int, use_fp16: bool):
    np_dtype = np.float16 if use_fp16 else np.float32
    head_size = config.hidden_size // config.num_attention_heads
    max_seq_len = 2048

    ort_inputs = {
        "x": np.random.rand(batch_size, seq_len, config.hidden_size).astype(np_dtype),
        "attn_mask": (-10000.0 * np.triu(np.ones((batch_size, max_seq_len, max_seq_len)), k=1)).astype(np_dtype),
        "k_cache": np.random.rand(
            batch_size, config.num_hidden_layers, past_seq_len, config.num_attention_heads, head_size
        ).astype(np_dtype),
        "v_cache": np.random.rand(
            batch_size, config.num_hidden_layers, past_seq_len, config.num_attention_heads, head_size
        ).astype(np_dtype),
        "pos": np.array(past_seq_len, dtype=np.int64),
    }
    return ort_inputs
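A minimal sketch (assuming the helpers above are importable as llama_inputs; the config values are illustrative) of how they chain together to feed an ONNX Runtime session:

import torch
from transformers import LlamaConfig
from llama_inputs import convert_inputs_for_ort, get_sample_with_past_kv_inputs

config = LlamaConfig()  # illustrative default config; real runs load it from the checkpoint
pt_inputs = get_sample_with_past_kv_inputs(config, torch.device("cpu"), batch_size=2, past_seq_len=8, return_dict=True)
ort_inputs = convert_inputs_for_ort(pt_inputs, use_fp16=False)
# ort_inputs now maps "input_ids", "attention_mask", "position_ids", and
# "past_key_values.{i}.key"/".value" names to NumPy arrays, ready for session.run(None, ort_inputs)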
132 onnxruntime/python/tools/transformers/models/llama/llama_parity.py Normal file
@@ -0,0 +1,132 @@
import argparse
import logging
import os
from typing import List

import numpy as np
import torch
from benchmark_helper import create_onnxruntime_session, setup_logger
from llama_inputs import convert_inputs_for_ort, get_sample_inputs, get_sample_with_past_kv_inputs
from transformers import LlamaConfig, LlamaForCausalLM

logger = logging.getLogger("")


def verify_parity(args: argparse.Namespace, config: LlamaConfig, pt_model: LlamaForCausalLM):
    # Dummy values for parity
    batch_size, sequence_length = 2, 8
    device = torch.device("cpu")

    # Run inference with PyTorch
    inputs = (
        get_sample_inputs(config, device, batch_size, sequence_length, return_dict=True)
        if not args.use_past_kv
        else get_sample_with_past_kv_inputs(
            config, device, batch_size, sequence_length, use_fp16=(args.precision == "fp16"), return_dict=True
        )
    )
    pt_outputs = pt_model(**inputs).logits.detach().cpu().numpy()

    # Run inference with ORT
    inputs = convert_inputs_for_ort(inputs, use_fp16=(args.precision == "fp16"))
    ort_model = create_onnxruntime_session(
        args.onnx_model_path,
        args.execution_provider != "cpu",  # use_gpu
        provider=args.execution_provider,
        verbose=args.verbose,
    )
    ort_outputs = ort_model.run(None, inputs)[0]

    # Compare PyTorch and ONNX Runtime accuracy
    tol = 1e-3 if args.precision == "fp32" else 1e-2 if args.precision == "fp16" else 1e2
    parity = np.allclose(pt_outputs, ort_outputs, rtol=tol, atol=tol)
    logger.warning(f"Are PyTorch and ONNX Runtime results close? {parity}")
    if not parity:
logger.warning(f"Max diff: {np.max(pt_outputs - ort_outputs)}")


def get_args(argv: List[str]):
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "-m",
        "--model_name",
        required=True,
        help="Model name in Hugging Face",
    )

    parser.add_argument(
        "-t",
        "--torch_model_directory",
        required=False,
        default=os.path.join("."),
        help="Path to folder containing PyTorch model and associated files if saved on disk",
    )

    parser.add_argument(
        "-o",
        "--onnx_model_path",
        required=True,
        default=os.path.join("."),
        help="Path to ONNX model (with external data files saved in the same folder as the model)",
    )

    parser.add_argument(
        "-ep",
        "--execution_provider",
        required=False,
        default="cpu",
        choices=["cpu", "cuda", "rocm"],
        help="Execution provider to verify parity with",
    )

    parser.add_argument(
        "-v",
        "--verbose",
        action="store_true",
        help="Print verbose logs",
    )
    parser.set_defaults(verbose=False)

    parser.add_argument(
        "-p",
        "--use_past_kv",
        action="store_true",
        help="Use past key and past value as inputs to the model. Necessary for decoder_with_past_model.onnx models.",
    )
    parser.set_defaults(use_past_kv=False)

    parser.add_argument(
        "-fp",
        "--precision",
        required=True,
        choices=["int8", "fp16", "fp32"],
        help="Precision of model",
    )

    args = parser.parse_args() if argv == [] else parser.parse_args(argv)
    return args


def main(argv: List[str] = []):  # noqa: B006
    args = get_args(argv)
    setup_logger(args.verbose)
    logger.info(f"Arguments: {args}")

    # Load model and config
    use_auth_token = args.torch_model_directory == os.path.join(".")
    location = args.model_name if use_auth_token else args.torch_model_directory

    config = LlamaConfig.from_pretrained(location, use_auth_token=use_auth_token)
    llama = LlamaForCausalLM.from_pretrained(
        location,
        torch_dtype=(torch.float16 if args.precision == "fp16" else torch.float32),
        use_auth_token=use_auth_token,
        use_cache=True,
    )

    verify_parity(args, config, llama)


if __name__ == "__main__":
    main()
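# Example invocation (the script filename, model name, and ONNX path are illustrative):
#   python llama_parity.py -m meta-llama/Llama-2-7b-hf \
#       -o ./llama2-7b/Llama-2-7b-hf_decoder_with_past_model_fp32.onnx -fp fp32 --use_past_kv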
103 onnxruntime/python/tools/transformers/models/llama/quant_kv_dataloader.py Normal file
@@ -0,0 +1,103 @@
import argparse

import numpy as np
import torch
from benchmark_helper import create_onnxruntime_session
from datasets import load_dataset
from llama_inputs import get_position_ids
from torch.nn.functional import pad
from torch.utils.data import DataLoader
from transformers import LlamaTokenizer


class QuantKVDataLoader:
    def __init__(self, args: argparse.Namespace, onnx_model_path: str = ""):
        self.batch_size = 1
        self.pad_max = args.pad_max

        tokenizer = LlamaTokenizer.from_pretrained(args.original_model_name, use_auth_token=args.use_auth_token)
        dataset = load_dataset(args.smooth_quant_dataset, split="train")
        dataset = dataset.map(lambda examples: tokenizer(examples["text"]), batched=True)
        dataset.set_format(type="torch", columns=["input_ids", "attention_mask"])

        self.dataloader = DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=False,
            collate_fn=self.collate_batch,
        )
        self.decoder_model = (
            create_onnxruntime_session(
                onnx_model_path,
                args.execution_provider != "cpu",  # use_gpu
                provider=args.execution_provider,
                verbose=args.verbose,
            )
            if onnx_model_path
            else None
        )

    def collate_batch(self, batch):
        input_ids_batched = []
        attention_mask_batched = []
        position_ids_batched = []
        labels = []

        for text in batch:
            # Set inputs for model
            input_ids = text["input_ids"]
            attention_mask = torch.ones(len(input_ids))
            position_ids = get_position_ids(attention_mask, use_past_kv=False)
            label = len(input_ids) - 1

            # Pad input data because all model inputs must have same shape
            pad_len = self.pad_max - input_ids.shape[0]
            input_ids = pad(input_ids, (0, pad_len), value=1)
            attention_mask = pad(attention_mask, (0, pad_len), value=0)
            position_ids = pad(position_ids, (0, pad_len), value=0)

            input_ids_batched.append(input_ids)
            attention_mask_batched.append(attention_mask)
            position_ids_batched.append(position_ids)
            labels.append(label)

        input_ids_batched = torch.vstack(input_ids_batched)
        attention_mask_batched = torch.vstack(attention_mask_batched)
        position_ids_batched = torch.vstack(position_ids_batched)
        labels = torch.tensor(labels)

        return (input_ids_batched, attention_mask_batched, position_ids_batched), labels

    def __iter__(self):
        try:
            for (input_ids, attention_mask, position_ids), labels in self.dataloader:
                # Inputs for decoder_model.onnx
                inputs = {
                    "input_ids": input_ids[:, :-1].detach().cpu().numpy().astype(np.int64),
                    "attention_mask": attention_mask[:, :-1].detach().cpu().numpy().astype(np.int64),
                    "position_ids": position_ids[:, :-1].detach().cpu().numpy().astype(np.int64),
                }
                label = labels.detach().cpu().numpy()

                if self.decoder_model is not None:
                    # Run decoder_model.onnx to get inputs for decoder_with_past_model.onnx
                    outputs = self.decoder_model.run(None, inputs)

                    for i in range((len(outputs) - 1) // 2):
                        inputs[f"past_key_values.{i}.key"] = outputs[i * 2 + 1]
                        inputs[f"past_key_values.{i}.value"] = outputs[i * 2 + 2]
                    past_sequence_length = inputs["past_key_values.0.key"].shape[2]

                    inputs["input_ids"] = input_ids[:, -1].unsqueeze(0).detach().cpu().numpy().astype(np.int64)
                    attn_mask_torch = torch.ones((self.batch_size, past_sequence_length + 1), dtype=torch.int64)
                    inputs["attention_mask"] = attn_mask_torch.detach().cpu().numpy().astype(np.int64)
                    inputs["position_ids"] = (
                        get_position_ids(attn_mask_torch, use_past_kv=True).detach().cpu().numpy().astype(np.int64)
                    )

                # Yield (inputs, label) tuple for Intel's Neural Compressor:
                # https://github.com/intel/neural-compressor/blob/d4baed9ea11614e1f0dc8a1f4f55b73ed3ed585c/neural_compressor/quantization.py#L55-L62
                yield (inputs, label)

        except StopIteration:
            return
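# A minimal usage sketch (args as parsed by convert_to_onnx.py; shapes depend on --pad_max):
#
#   loader = QuantKVDataLoader(args)
#   for ort_inputs, label in loader:
#       # each ort_inputs dict feeds decoder_model.onnx; when onnx_model_path is given,
#       # past_key_values.* entries are added for decoder_with_past_model.onnx
#       break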
3 onnxruntime/python/tools/transformers/models/llama/requirements-cpu.txt Normal file
@@ -0,0 +1,3 @@
-r requirements.txt
torch>=2.0.1
onnxruntime>=1.16.0
4 onnxruntime/python/tools/transformers/models/llama/requirements-cuda.txt Normal file
@@ -0,0 +1,4 @@
-r requirements.txt
# Please manually install torch>=2.0.1 with CUDA enabled for the CUDA version installed in your system.
# Instructions can be found here: https://pytorch.org/get-started/locally/
onnxruntime-gpu>=1.16.0
2 onnxruntime/python/tools/transformers/models/llama/requirements-quant.txt Normal file
@@ -0,0 +1,2 @@
-r requirements-cpu.txt
neural-compressor>=2.2.1
5 onnxruntime/python/tools/transformers/models/llama/requirements.txt Normal file
@@ -0,0 +1,5 @@
git+https://github.com/kunal-vaishnavi/optimum.git@kvaishnavi/llama-add-position-ids
transformers>=4.28.1
onnx>=1.14.0
datasets>=2.8.0
protobuf==3.20.2
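For example, a CPU-only setup would run `pip install -r requirements-cpu.txt`, while SmoothQuant additionally requires `pip install -r requirements-quant.txt` (which pulls in Neural Compressor on top of the CPU requirements).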
1 setup.py
@@ -470,6 +470,7 @@ packages = [
"onnxruntime.transformers.models.bart",
|
||||
"onnxruntime.transformers.models.bert",
|
||||
"onnxruntime.transformers.models.gpt2",
|
||||
"onnxruntime.transformers.models.llama",
|
||||
"onnxruntime.transformers.models.longformer",
|
||||
"onnxruntime.transformers.models.t5",
|
||||
"onnxruntime.transformers.models.stable_diffusion",
|
||||