onnxruntime/onnxruntime/python/tools/transformers/models/whisper/benchmark.py
kunal-vaishnavi 4b3477f171
Add Whisper scripts (#17043)
### Description
This PR adds benchmark scripts for Whisper. It is a follow-up to [this
PR](https://github.com/microsoft/onnxruntime/pull/17020) that adds the
LLaMA scripts.



### Motivation and Context
This PR enables benchmarking Whisper across various configurations.
2023-08-22 18:14:44 -07:00

550 lines
20 KiB
Python

import argparse
import ast
import datetime
import gc
import logging
import os
import sys
import time
import numpy as np
import psutil
import torch
import whisper
from benchmark_helper import setup_logger
from onnxruntime_extensions import get_library_path
from optimum.onnxruntime import ORTModelForSpeechSeq2Seq
from torch.profiler import ProfilerActivity, profile, record_function
from tqdm import trange
from transformers import AutoModelForSpeechSeq2Seq, WhisperConfig, WhisperProcessor
import onnxruntime as ort
from onnxruntime.transformers.benchmark_helper import measure_memory
# Module-level logger shared by the helpers in this script; main() calls setup_logger(args.verbose) before logging
logger = logging.getLogger(__name__)
def get_inputs(args: argparse.Namespace):
    """Build the generation inputs for the selected benchmark type.

    Loads the audio file at ``args.audio_path`` (timing the load with ``time_fn``).
    Unless the ONNX model consumes the raw audio stream directly, it then runs the
    Whisper feature extractor (also timed with ``time_fn``) to produce input features.

    Args:
        args: Parsed CLI arguments. ``main`` augments them with ``processor``,
            ``target_device``, ``use_fp16`` and ``has_audio_stream``. For ``ort`` it also
            adds ``has_decoder_input_ids`` and ``has_logits_processor``.

    Returns:
        dict: For ``ort``, a mapping of ONNX input names to NumPy arrays. For the
        Hugging Face paths, keyword arguments for ``model.generate``.

    Raises:
        Exception: If ``args.benchmark_type`` is not a supported type.
    """
    if args.benchmark_type not in {"hf-pt", "hf-pt2", "hf-ort", "ort"}:
        raise Exception("Unable to auto-detect inputs for provided model")

    def load_via_ffmpeg():
        # Decode and pad/trim the audio with the `whisper` package helpers
        audio = whisper.load_audio(args.audio_path)
        audio = whisper.pad_or_trim(audio)
        return audio

    def load_via_numpy():
        # Feed the raw file bytes as a (1, num_bytes) uint8 array.
        # np.frombuffer avoids building a Python list with one int per byte.
        # np.array([...]) then makes a writable, batched copy.
        with open(args.audio_path, "rb") as f:
            raw_bytes = np.frombuffer(f.read(), dtype=np.uint8)
        return np.array([raw_bytes])

    inputs = {
        "max_length": args.max_length,
        "min_length": args.min_length,
        "num_beams": args.num_beams,
        "num_return_sequences": args.num_return_sequences,
        "length_penalty": args.length_penalty,
        "repetition_penalty": args.repetition_penalty,
    }
    if args.benchmark_type == "ort":
        # convert_to_onnx export or ONNX E2E solution created by Olive.
        # Scalars become shape-(1,) arrays: float32 for penalties, int32 for everything else.
        inputs = {k: np.array([v], dtype=np.float32 if "penalty" in k else np.int32) for k, v in inputs.items()}
        if args.has_decoder_input_ids:
            inputs["decoder_input_ids"] = np.array([args.decoder_input_ids], dtype=np.int32)
        if args.has_logits_processor:
            inputs["logits_processor"] = np.array([args.logits_processor], dtype=np.int32)

    # Measure time taken to load audio file
    logger.info(f"Load audio: {args.audio_path}")

    def load_audio_fn(onnx_e2e):
        return load_via_numpy() if onnx_e2e else load_via_ffmpeg()

    time_fn(args, load_audio_fn, args.has_audio_stream)
    audio_data = load_audio_fn(args.has_audio_stream)

    if args.has_audio_stream:
        # ONNX E2E solution created by Olive takes the raw audio bytes directly
        inputs["audio_stream"] = audio_data
        return inputs

    # Measure time taken to get input features
    logger.info("Feature extraction: ")
    return_type = "np" if args.benchmark_type == "ort" else "pt"

    def processor_fn(audio):
        return args.processor.feature_extractor(
            [audio], return_tensors=return_type, sampling_rate=args.sampling_rate
        ).input_features

    time_fn(args, processor_fn, audio_data)
    input_features = processor_fn(audio_data)

    if args.benchmark_type == "ort":
        # convert_to_onnx export
        inputs["input_features"] = input_features
        return inputs

    # Hugging Face `generate` keyword arguments
    inputs["inputs"] = input_features.to(
        dtype=torch.float16 if args.use_fp16 else torch.float32, device=args.target_device
    )
    inputs["no_repeat_ngram_size"] = args.no_repeat_ngram_size
    inputs["early_stopping"] = True
    inputs["use_cache"] = True

    if args.decoder_input_ids:
        inputs["forced_decoder_ids"] = args.decoder_input_ids

    return inputs
def get_model(args: argparse.Namespace):
    """Load the model selected by ``args.benchmark_type`` and log how long loading took.

    There are multiple sources that the model could come from:
      1) ``hf-pt`` / ``hf-pt2``: Whisper from Hugging Face in PyTorch.
         ``hf-pt2`` additionally wraps the model with ``torch.compile``.
      2) ``hf-ort``: Whisper ONNX model from an Optimum export (without pre/post processing).
      3) ``ort``: an ONNX model loaded directly into an ``ort.InferenceSession``.
         This is a convert_to_onnx.py export or an Olive E2E model (with pre/post processing).

    Returns:
        The loaded model object (PyTorch module, Optimum ORT model, or InferenceSession).

    Raises:
        Exception: If ``args.benchmark_type`` is not recognized.
    """
    model, sess_options = None, None
    start_time, end_time = None, None

    if args.benchmark_type in {"hf-pt", "hf-pt2"}:
        source = args.hf_pt_model_path if args.hf_pt_model_path else args.model_name
        start_time = time.perf_counter()
        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            source,
            torch_dtype=torch.float16 if args.use_fp16 else torch.float32,
            use_cache=True,
        ).to(args.target_device)
        end_time = time.perf_counter()

        if args.benchmark_type == "hf-pt2":
            # Compilation happens after the timed load, so it is not counted in the load time
            model = torch.compile(model)

    elif args.benchmark_type in {"hf-ort", "ort"}:
        # Session options shared by both ONNX Runtime paths
        sess_options = ort.SessionOptions()
        sess_options.enable_profiling = args.profile
        # Register onnxruntime-extensions custom ops (presumably needed by models with
        # pre/post-processing ops, e.g. the Olive E2E export)
        sess_options.register_custom_ops_library(get_library_path())
        if args.verbose:
            sess_options.log_verbosity_level = 1
            sess_options.log_severity_level = 1

    else:
        raise Exception(f"Cannot recognize {args.benchmark_type}")

    if args.benchmark_type == "hf-ort":
        # Optimum export. `execution_provider` is either a provider name or a
        # (provider name, provider options) tuple.
        if isinstance(args.execution_provider, tuple):
            provider, provider_options = args.execution_provider
        else:
            provider, provider_options = args.execution_provider, None
        start_time = time.perf_counter()
        model = ORTModelForSpeechSeq2Seq.from_pretrained(
            args.hf_ort_model_path,
            use_io_binding=(args.device != "cpu"),
            provider=provider,
            provider_options=provider_options,
            session_options=sess_options,
        )
        end_time = time.perf_counter()

    elif args.benchmark_type == "ort":
        # convert_to_onnx.py export
        logger.info(f"Loading model from {args.ort_model_path}")
        start_time = time.perf_counter()
        model = ort.InferenceSession(
            args.ort_model_path,
            sess_options,
            providers=[args.execution_provider],
        )
        end_time = time.perf_counter()

    logger.info(f"Loaded model in {end_time - start_time} s")
    return model
def time_fn(args, fn, inputs):
    """Benchmark ``fn(inputs)`` and log its average latency and throughput.

    The function first makes ``args.warmup_runs`` untimed warm-up calls, then times
    ``args.num_runs`` calls. On non-CPU devices, CUDA is synchronized before the clock
    starts and before it stops, so queued GPU work is included. When ``args.verbose``
    is set, one extra call is made first and its output is logged. Progress bars are
    shown for every benchmark type except ``ort``.

    Throughput assumes a batch size of 1. If ``args.num_runs`` is not positive,
    nothing is timed and a warning is logged.
    """

    def _iterations(count, desc):
        # Plain range for ORT, tqdm progress bar for the Hugging Face paths
        if args.benchmark_type == "ort":
            return range(count)
        return trange(count, file=sys.stdout, desc=desc)

    # Warm up
    warmup_range = _iterations(args.warmup_runs, "Warm up")

    if args.verbose:
        outputs = fn(inputs)
        logger.info(outputs)

    for _ in warmup_range:
        fn(inputs)

    if args.num_runs <= 0:
        logger.warning("Skipping timing: number of benchmark runs must be positive")
        return

    # Benchmark. perf_counter is monotonic and high resolution; time.time() can jump
    # when the system clock is adjusted.
    if args.device != "cpu":
        torch.cuda.synchronize()
    start_time = time.perf_counter()

    for _ in _iterations(args.num_runs, "Benchmark"):
        fn(inputs)

    if args.device != "cpu":
        torch.cuda.synchronize()
    end_time = time.perf_counter()

    # Newline print after trange in order to print metrics on new lines without progress bar on same line
    if args.benchmark_type != "ort":
        logger.info("")

    batch_size = 1
    latency = (end_time - start_time) / args.num_runs
    throughput = batch_size / latency if latency > 0 else float("inf")

    logger.info(f"Latency: {latency} s")
    logger.info(f"Throughput: {throughput} qps")
    return
def profile_fn(args, fn, inputs, inputs_type):
    """Run ``fn(inputs)`` once under a profiler and return the profile log's filename.

    Filename prefix format:
    "<benchmark-type>-<precision>-<device>_<inference-step>_<inputs-type>_<current-time>"

    The timestamp uses '-' rather than ':' between hours, minutes and seconds, so
    the name is also a valid filename on Windows.

    For ``hf-pt`` / ``hf-pt2``, the PyTorch profiler table is written to
    ``<log_folder>/<prefix>.log`` and that path is returned. For the ORT paths, ``fn``
    is run once and the bare name ``<prefix>.json`` is returned. The caller renames
    ORT's generated profile log to that name.
    """
    timestamp = f"{datetime.datetime.now():%Y-%m-%d_%H-%M-%S}"
    step_name = fn.__name__.replace("_", "-")
    prefix = f"{args.benchmark_type.lower()}-{args.precision}-{args.device}_{step_name}_{inputs_type}_{timestamp}"

    if args.benchmark_type in {"hf-pt", "hf-pt2"}:
        # Profile PyTorch kernels; only request CUDA activity when running on a GPU device
        activities = [ProfilerActivity.CPU]
        if args.device != "cpu":
            activities.append(ProfilerActivity.CUDA)
        with profile(activities=activities, record_shapes=True, profile_memory=True) as prof:  # noqa: SIM117
            with record_function("model_inference"):
                fn(inputs)
        prof_data = prof.key_averages(group_by_stack_n=5).table(sort_by=args.pt_filter_by, row_limit=args.pt_num_rows)

        filename = os.path.join(args.log_folder, f"{prefix}.log")
        with open(filename, "w") as f:
            f.write(prof_data)
        return filename

    # Profile ORT kernels (profiling is enabled on the session options)
    fn(inputs)

    # Set new log name for ORT profile log generated
    return f"{prefix}.json"
def measure_fn(args, fn, inputs):
    """Log CPU utilization and memory usage while running ``fn(inputs)``.

    CPU usage is sampled with psutil around a single call to ``fn``. Memory is then
    measured by ``measure_memory``, which is given a closure that calls
    ``fn(inputs)``. It runs in GPU mode when ``args.device`` is not "cpu".
    """
    current_process = psutil.Process(os.getpid())

    # Prime psutil's CPU counter; the later interval=None call reports usage since this point
    current_process.cpu_percent(interval=0.1)
    fn(inputs)
    cpu_usage = current_process.cpu_percent(interval=None)
    logger.info(f"CPU usage: {cpu_usage}%")

    # Release collectable objects and cached allocator blocks before measuring memory
    gc.collect()
    torch.cuda.empty_cache()

    def run_once():
        return fn(inputs)

    on_gpu = args.device != "cpu"
    measure_memory(is_gpu=on_gpu, func=run_once)

    # Flush output so memory usage is printed
    sys.stdout.flush()
def run_hf_inference(args, inputs, model):
    """Benchmark or profile generation with a Hugging Face-style model.

    Used for ``hf-pt``, ``hf-pt2`` and ``hf-ort``. The measured step generates token
    ids and decodes them into text.

    With ``--profile``, the step runs once under ``profile_fn``. For ``hf-ort``, the
    encoder / decoder / decoder-with-past ORT profile logs are then renamed into
    ``args.log_folder``. Otherwise the step is timed with ``time_fn``, its output is
    logged, and its CPU/memory usage is measured with ``measure_fn``.
    """

    # Inference steps to measure
    def get_pred_ids(inputs):
        # Inference pass with predicted token ids generation
        predicted_ids = model.generate(**inputs)
        return predicted_ids, [""]

    def gen_and_dec(inputs):
        # Inference pass with generation and decoding.
        # get_pred_ids returns an (ids, placeholder) pair; unpack it so that only the
        # ids reach batch_decode and the caller.
        predicted_ids, _ = get_pred_ids(inputs)
        # Decode once: batch_decode returns one string per generated sequence
        transcription = args.processor.batch_decode(predicted_ids, skip_special_tokens=True)
        return predicted_ids, transcription

    # Examples of other inference steps that can be measured:
    # To use, uncomment the function and assign it to `generate_fn`

    # def get_logits(inputs):
    #     # Inference pass without decoding
    #     outputs = model(**inputs)
    #     return outputs

    generate_fn = gen_and_dec

    if args.benchmark_type == "hf-pt2":
        # Run forward pass once with each set of inputs to process through Dynamo
        generate_fn(inputs)

    if args.profile:
        new_logname = profile_fn(args, generate_fn, inputs, "gen-and-dec")
        if args.benchmark_type == "hf-ort":
            # Rename log files per model component and turn profiling off to stop appending to log
            new_prefix = new_logname[: -len(".json")]
            for component in ("encoder", "decoder", "decoder_with_past"):
                old_logname = getattr(model, component).session.end_profiling()
                new_logname = f"{new_prefix}-{component.replace('_', '-')}.json"
                if os.path.isfile(old_logname):
                    logger.warning(f"Renaming {old_logname} to {new_logname}")
                    os.rename(old_logname, os.path.join(args.log_folder, new_logname))
        return

    # PyTorch evaluations
    logger.info("\nEvaluating PyTorch...")
    time_fn(args, generate_fn, inputs)
    predicted_ids, transcription = generate_fn(inputs)
    logger.info(f"Generated token length: {len(predicted_ids[0])} tokens")
    logger.info(f"Transcription: {transcription[0]}")
    measure_fn(args, generate_fn, inputs)
def run_ort_inference(args, inputs, model):
    """Benchmark or profile a model loaded as an ORT ``InferenceSession``.

    The user inputs are first checked against the session's declared inputs. Missing
    inputs raise an exception; extra inputs are dropped. The caller's dict is not
    modified. On non-CPU devices, the inputs are bound through an IO binding and the
    run uses ``run_with_iobinding``.

    With ``--profile``, one profiled run is made and the ORT profile log is renamed
    into ``args.log_folder``. Otherwise the run is timed with ``time_fn``, a summary
    of the first output is logged, and CPU/memory usage is measured with
    ``measure_fn``.

    Raises:
        Exception: If an input declared by the model is not provided.
    """

    def prepare_ort_inputs(inputs):
        # Check that all model inputs will be provided
        model_inputs = {model_input.name for model_input in model.get_inputs()}
        user_inputs = set(inputs)
        missing_inputs = model_inputs - user_inputs
        if missing_inputs:
            logger.error(f"The following model inputs are missing: {missing_inputs}")
            raise Exception("There are missing inputs to the model. Please add them and try again.")

        # Remove unnecessary inputs from model inputs (into a new dict, leaving the caller's intact)
        for unnecessary_input in user_inputs - model_inputs:
            logger.info(f"Removing unnecessary input '{unnecessary_input}' from user provided inputs")
        feeds = {name: value for name, value in inputs.items() if name in model_inputs}

        # Add IO bindings for non-CPU execution providers
        if args.device != "cpu":
            io_binding = model.io_binding()
            for name, value in feeds.items():
                io_binding.bind_cpu_input(name, value)
            for output in model.get_outputs():
                io_binding.bind_output(output.name)
            return io_binding

        return feeds

    def with_io_binding(io_binding):
        # Inference pass with IO binding
        model.run_with_iobinding(io_binding)
        return io_binding

    def without_io_binding(inputs):
        # Inference pass without IO binding
        outputs = model.run(None, inputs)
        return outputs

    generate_fn = with_io_binding if args.device != "cpu" else without_io_binding
    ort_inputs = prepare_ort_inputs(inputs)

    if args.profile:
        new_logname = profile_fn(args, generate_fn, ort_inputs, "e2e")

        # Turn profiling off to stop appending to log file
        old_logname = model.end_profiling()
        logger.warning(f"Renaming {old_logname} to {new_logname}")
        os.rename(old_logname, os.path.join(args.log_folder, new_logname))

        return

    # ORT evaluation
    logger.info("\nEvaluating ONNX Runtime...")
    time_fn(args, generate_fn, ort_inputs)
    ort_outputs = generate_fn(ort_inputs)
    if args.device != "cpu":
        ort_outputs = ort_outputs.copy_outputs_to_cpu()
    ort_outputs = ort_outputs[0]

    if args.has_audio_stream:
        # ONNX E2E model from Olive produces transcribed output
        logger.info(f"Transcription: {ort_outputs[0][0]}")
    else:
        # convert_to_onnx model produces generated ids
        logger.info(f"Generated token length: {len(ort_outputs[0][0])} tokens")

    measure_fn(args, generate_fn, ort_inputs)
def run_inference(args, inputs, model):
    """Dispatch to the inference runner that matches ``args.benchmark_type``.

    ``ort`` goes to the ONNX Runtime session runner. ``hf-pt``, ``hf-pt2`` and
    ``hf-ort`` go to the Hugging Face runner.

    Raises:
        Exception: For an unrecognized benchmark type.
    """
    benchmark_type = args.benchmark_type
    if benchmark_type == "ort":
        run_ort_inference(args, inputs, model)
        return
    if benchmark_type in {"hf-pt", "hf-pt2", "hf-ort"}:
        run_hf_inference(args, inputs, model)
        return
    raise Exception(f"Cannot recognize {benchmark_type}")
def parse_args():
    """Parse and post-process the command-line arguments for this benchmark.

    Besides parsing, this function:
      - seeds NumPy and PyTorch with ``--seed``;
      - for ``hf-ort`` / ``ort``, sets ``args.execution_provider`` to a provider name,
        or to a ``(name, {"device_id": ...})`` tuple for CUDA/ROCm (ROCm also switches
        ``args.device`` to "cuda");
      - rejects the run (argparse error, exit status 2) when the ONNX model path
        required by the benchmark type is missing;
      - parses ``--decoder-input-ids`` as a Python literal (e.g. a list of ids).

    Returns:
        argparse.Namespace: The parsed arguments.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "-bt", "--benchmark-type", type=str, required=True, choices=["hf-pt", "hf-pt2", "hf-ort", "ort"]
    )
    parser.add_argument(
        "-m",
        "--model-name",
        type=str,
        required=True,
        help="Hugging Face name of model (e.g. 'openai/whisper-large-v2')",
    )
    # Not `required=True`: that would make the fp32 default unreachable
    parser.add_argument(
        "-p",
        "--precision",
        type=str,
        default="fp32",
        choices=["int8", "fp16", "fp32"],
        help="Precision for model. For ONNX models, the model's precision should be set before running this script.",
    )

    parser.add_argument(
        "--hf-pt-model-path",
        type=str,
        default="",
        help="Path to directory containing all PyTorch files (e.g. tokenizer, PyTorch model)",
    )
    parser.add_argument(
        "--hf-ort-model-path",
        type=str,
        default="",
        help="Path to directory containing all ONNX files (e.g. tokenizer, encoder, decoder, decoder_with_past)",
    )
    parser.add_argument(
        "--ort-model-path",
        type=str,
        default="",
        help="Path to ONNX model",
    )

    # Args for running and evaluating the model
    parser.add_argument("-a", "--audio-path", type=str, required=True, help="Path to audio file for E2E evaluation")
    parser.add_argument(
        "-d",
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        choices=["cpu", "cuda", "rocm"],
    )
    parser.add_argument("-id", "--device-id", type=int, default=0)
    parser.add_argument("-w", "--warmup-runs", type=int, default=5)
    parser.add_argument("-n", "--num-runs", type=int, default=10)
    parser.add_argument("--seed", type=int, default=2)

    # Optional args:
    parser.add_argument("--sampling-rate", type=int, default=16000, help="Sampling rate for audio (in Hz)")

    # Args for decoding logic (generation parameters; each has a default)
    parser.add_argument("--max-length", type=int, default=448)
    parser.add_argument("--min-length", type=int, default=0)
    parser.add_argument("--num-beams", type=int, default=1)
    parser.add_argument("--num-return-sequences", type=int, default=1)
    parser.add_argument("--length-penalty", type=float, default=1.0)
    parser.add_argument("--repetition-penalty", type=float, default=1.0)
    parser.add_argument("--no-repeat-ngram-size", type=int, default=3)

    # Optional args for E2E solution:
    parser.add_argument(
        "--decoder-input-ids",
        type=str,
        default="[]",
        help="The forced decoder ids for generation. Format is [start token, timestamp token, language token, task token]. Default is [start token]. See `decoder_input_ids` in https://github.com/microsoft/Olive/tree/main/examples/whisper for details.",
    )
    parser.add_argument(
        "--logits-processor",
        type=int,
        default=1,
        help="Type of logits processor to use. See `BeamSearch` in https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/graph/contrib_ops/contrib_defs.cc for details.",
    )

    # Args for accessing detailed info
    parser.add_argument("--profile", default=False, action="store_true")
    parser.add_argument(
        "--pt-filter-by", type=str, default="self_cpu_time_total", help="What to filter PyTorch profiler by"
    )
    parser.add_argument("--pt-num-rows", type=int, default=1000, help="Number of rows for PyTorch profiler to display")
    parser.add_argument("--verbose", default=False, action="store_true")
    parser.add_argument("--log-folder", type=str, default=".", help="Folder to cache log files")

    args = parser.parse_args()

    # Set seed properties
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # Set runtime properties
    if args.benchmark_type in {"hf-ort", "ort"}:
        args.execution_provider = f"{args.device.upper()}ExecutionProvider"
        if args.device in {"cuda", "rocm"}:
            args.execution_provider = (args.execution_provider, {"device_id": args.device_id})
        if args.device == "rocm":
            # The rest of the script only distinguishes "cpu" from GPU and uses `cuda` device names
            args.device = "cuda"

    # Check that model paths have been specified for any benchmarking with ORT.
    # parser.error (rather than assert) keeps the check active under `python -O`.
    if args.benchmark_type == "hf-ort" and not args.hf_ort_model_path:
        parser.error("Please specify a path to `--hf-ort-model-path`")
    if args.benchmark_type == "ort" and not args.ort_model_path:
        parser.error("Please specify a path to `--ort-model-path`")

    # Convert decoder_input_ids string to list of ids
    # (e.g. "[1, 50257]" for Hugging Face or "[50257]" for ORT)
    try:
        args.decoder_input_ids = ast.literal_eval(args.decoder_input_ids)
    except (ValueError, SyntaxError) as e:
        parser.error(f"Could not parse `--decoder-input-ids` {args.decoder_input_ids!r}: {e}")

    return args
def main():
    """Run the Whisper benchmark from the command line.

    Parses arguments and configures logging. It then attaches the runtime state the
    helpers read from ``args``: processor, target device, fp16 flag and audio-stream
    flag. For ``ort`` it also records which optional inputs the ONNX model exposes.
    Finally it builds the inputs and runs the benchmark.
    """
    args = parse_args()
    setup_logger(args.verbose)
    logger.info(args.__dict__)
    torch.backends.cudnn.benchmark = True

    config = WhisperConfig.from_pretrained(args.model_name)

    # Extra runtime state shared with the helper functions through `args`
    args.processor = WhisperProcessor.from_pretrained(args.model_name)
    args.target_device = args.device if args.device == "cpu" else f"cuda:{args.device_id}"
    args.use_fp16 = args.precision == "fp16"
    args.has_audio_stream = False

    logger.info(f"Forced decoder prompt ids: {args.decoder_input_ids}")

    # Measure cost to transcribe audio
    model = get_model(args)

    if args.benchmark_type == "ort":
        # Check for optional inputs that could have been added during export
        input_names = {node.name for node in model.get_inputs()}
        args.has_audio_stream = "audio_stream" in input_names
        args.has_decoder_input_ids = "decoder_input_ids" in input_names
        args.has_logits_processor = "logits_processor" in input_names

        if args.decoder_input_ids == []:
            # Default the ORT decoder prompt to just the decoder start token
            args.decoder_input_ids = [config.decoder_start_token_id]

    run_inference(args, get_inputs(args), model)
if __name__ == "__main__":
    # Run the benchmark only when executed as a script, not when imported
    main()