Add Whisper scripts (#17043)
### Description
This PR adds benchmark scripts for Whisper. It is a follow-up to [this PR](https://github.com/microsoft/onnxruntime/pull/17020) that adds the LLaMA scripts.

### Motivation and Context
This PR enables benchmarking Whisper across various configurations.
parent 5842144d98
commit 4b3477f171
5 changed files with 1138 additions and 14 deletions

@@ -2,18 +2,37 @@

## Exporting Whisper with Beam Search

There are several ways to export Whisper with beam search (using Whisper tiny as an example).

### Option 1: from convert_to_onnx

```
# From source
$ git clone https://github.com/microsoft/onnxruntime
$ cd onnxruntime/onnxruntime/python/tools/transformers/
$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format

# From wheel
$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format
```
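
To sanity-check the export, the resulting beam-search model can be loaded back into an ONNX Runtime session. This is a minimal sketch: the output file name under `whispertiny/` is an assumption (check what convert_to_onnx actually wrote), and registering the onnxruntime-extensions custom-ops library mirrors what `benchmark.py` below does.

```
import onnxruntime as ort
from onnxruntime_extensions import get_library_path

# Hypothetical output name; inspect the whispertiny/ folder for the real one.
model_path = "whispertiny/whisper-tiny_beamsearch.onnx"

sess_options = ort.SessionOptions()
# Custom ops used by some Whisper exports live in onnxruntime-extensions.
sess_options.register_custom_ops_library(get_library_path())

sess = ort.InferenceSession(model_path, sess_options, providers=["CPUExecutionProvider"])
print([model_input.name for model_input in sess.get_inputs()])
```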

### Option 2: end-to-end model from [Olive](https://github.com/microsoft/Olive/tree/main/examples/whisper)

Please follow the [README instructions](https://github.com/microsoft/Olive/tree/main/examples/whisper#prerequisites) in Olive.

### Option 3: from [Hugging Face Optimum](https://github.com/huggingface/optimum)

Run the following Python code to export:

```
from optimum.onnxruntime import ORTModelForSpeechSeq2Seq

model_name = "openai/whisper-large-v2"
model = ORTModelForSpeechSeq2Seq.from_pretrained(
    model_name,
    export=True,
)
model.save_pretrained(model_name.split("/")[-1] + "-onnx")
```
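
Once exported, the model can be run through Optimum's ONNX Runtime frontend. A minimal sketch, assuming a 16 kHz mono audio file such as the `1272-141231-0002.mp3` sample used in the benchmarks below:

```
import librosa
from optimum.onnxruntime import ORTModelForSpeechSeq2Seq
from transformers import WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-large-v2")
model = ORTModelForSpeechSeq2Seq.from_pretrained("whisper-large-v2-onnx")

# Whisper expects 16 kHz mono audio.
audio, _ = librosa.load("1272-141231-0002.mp3", sr=16000)
inputs = processor(audio, sampling_rate=16000, return_tensors="pt")

predicted_ids = model.generate(inputs.input_features)
print(processor.batch_decode(predicted_ids, skip_special_tokens=True)[0])
```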

## Exporting + Optimizing + Quantizing Whisper with Beam Search

@@ -23,7 +42,7 @@ Here are some additional examples for exporting Whisper with beam search.

Export with Forced Decoder Input Ids

```
# From source:
$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --use_forced_decoder_ids

# From wheel:
$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --use_forced_decoder_ids
```

@@ -32,7 +51,7 @@ $ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/w

Export + Optimize for FP32

```
# From source:
$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --optimize_onnx --precision fp32

# From wheel:
$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --optimize_onnx --precision fp32
```

@@ -41,7 +60,7 @@ $ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/w

Export + Optimize for FP16 and GPU

```
# From source:
$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --optimize_onnx --precision fp16 --use_gpu --provider cuda

# From wheel:
$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --optimize_onnx --precision fp16 --use_gpu --provider cuda
```

@@ -50,8 +69,128 @@ $ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/w

Export + Quantize for INT8

```
# From source:
$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --precision int8 --quantize_embedding_layer

# From wheel:
$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --precision int8 --quantize_embedding_layer
```

## Benchmark Whisper

Here are some examples of how you can benchmark Whisper across various end-to-end (E2E) implementations.

Note: In the examples below, `PyTorch` refers to running in PyTorch without `torch.compile` and `PyTorch 2.0` refers to running in PyTorch with `torch.compile`.
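
Concretely, the only difference between the two PyTorch variants is a one-time wrap of the model before benchmarking (see `get_model` in `benchmark.py` below); a minimal sketch:

```
import torch
from transformers import AutoModelForSpeechSeq2Seq

model = AutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-tiny")
# The hf-pt2 benchmark type additionally compiles the model; hf-pt skips this step.
model = torch.compile(model)
```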

### Variants

1. PyTorch (without `torch.compile`), FP32
```
python3 -m models.whisper.benchmark \
    --benchmark-type hf-pt \
    --audio-path 1272-141231-0002.mp3 \
    --model-name openai/whisper-large-v2 \
    --precision fp32 \
    --device cpu
```

2. PyTorch 2.0 (with `torch.compile`), FP16
```
python3 -m models.whisper.benchmark \
    --benchmark-type hf-pt2 \
    --audio-path 1272-141231-0002.mp3 \
    --model-name openai/whisper-large-v2 \
    --precision fp16 \
    --device cuda
```

3. Optimum + ONNX Runtime, FP32, export via Optimum
```
python3 -m models.whisper.benchmark \
    --benchmark-type hf-ort \
    --audio-path 1272-141231-0002.mp3 \
    --model-name openai/whisper-large-v2 \
    --hf-ort-model-path ./whisper-large-v2-onnx/ \
    --precision fp32 \
    --device cpu
```

4. ONNX Runtime, FP32, export via Olive or convert_to_onnx
```
python3 -m models.whisper.benchmark \
    --benchmark-type ort \
    --audio-path 1272-141231-0002.mp3 \
    --model-name openai/whisper-large-v2 \
    --ort-model-path ./wlarge-fp32/whisper-large-v2_beamsearch.onnx \
    --precision fp32 \
    --device cpu
```

5. ONNX Runtime, FP16, export via Olive or convert_to_onnx
```
python3 -m models.whisper.benchmark \
    --benchmark-type ort \
    --audio-path 1272-141231-0002.mp3 \
    --model-name openai/whisper-large-v2 \
    --ort-model-path ./wlarge-fp16/whisper-large-v2_all.onnx \
    --precision fp16 \
    --device cuda
```

6. ONNX Runtime, INT8, export via Olive or convert_to_onnx
```
python3 -m models.whisper.benchmark \
    --benchmark-type ort \
    --audio-path 1272-141231-0002.mp3 \
    --model-name openai/whisper-large-v2 \
    --ort-model-path ./wlarge-int8/whisper-large-v2_all.onnx \
    --precision int8 \
    --device cpu
```

You can profile a variant by adding the `--profile` flag.
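
For PyTorch variants, profiling writes a `torch.profiler` summary table to a log file; for ONNX Runtime variants, it enables session profiling, which produces a JSON trace in Chrome tracing format. Below is a sketch of summarizing such a trace by operator time. The file name is hypothetical (use the renamed profile that the script prints), and the "Node" category and microsecond durations are assumptions about the trace layout:

```
import json
from collections import Counter

# Hypothetical profile name; the benchmark script renames the ORT profile it generates.
with open("ort-fp32-cpu_without-io-binding_e2e_2023-08-07_00:00:00.json") as f:
    events = json.load(f)

totals = Counter()
for event in events:
    # Per-node kernel events; durations are in microseconds.
    if event.get("cat") == "Node":
        totals[event["name"]] += event.get("dur", 0)

for name, dur_us in totals.most_common(10):
    print(f"{name}: {dur_us / 1000:.3f} ms")
```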

### Benchmark All

You can use `benchmark_all.py` to benchmark across various platforms and automatically store the results in a CSV file. Here is an example:

```
python3 -m models.whisper.benchmark_all \
    --audio-path ./whisper-test-audios/ \
    --hf-ort-model-path ./whisper-large-v2-onnx/ \
    --ort-model-path ./wlarge-fp32/whisper-large-v2_all.onnx \
    --model-name openai/whisper-large-v2 \
    --precision fp32 \
    --device cpu
```
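
The CSV columns match `save_results` in `benchmark_all.py`, so results can be compared directly with pandas. A sketch (the CSV name is hypothetical; the script writes `<model-size>-<precision>_<timestamp>.csv` into the log folder):

```
import pandas as pd

df = pd.read_csv("whisper-large-v2-fp32_2023-08-07_00:00:00.csv")
# Average per-token latency per engine across all audio files.
print(df.groupby("Engine")["Per Token Latency (ms/token)"].mean())
```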

### Benchmarking on NVIDIA A100

Here is a benchmark for an MP3 file with 20.7s of audio.

#### FP16

| Engine        | Size     | Per-Token Latency | Real-Time Factor |
| ------------- | -------- | ----------------- | ---------------- |
| PyTorch       | Tiny     | 4.697 ms/token    | 0.004697         |
| PyTorch 2.0   | Tiny     | 3.406 ms/token    | 0.003406         |
| ONNX Runtime  | Tiny     | 0.746 ms/token    | 0.000746         |
| PyTorch       | Medium   | 17.837 ms/token   | 0.017837         |
| PyTorch 2.0   | Medium   | 18.124 ms/token   | 0.018124         |
| ONNX Runtime  | Medium   | 3.894 ms/token    | 0.003894         |
| PyTorch       | Large v2 | 23.470 ms/token   | 0.023470         |
| PyTorch 2.0   | Large v2 | 23.146 ms/token   | 0.023146         |
| ONNX Runtime  | Large v2 | 6.262 ms/token    | 0.006262         |

#### FP32

| Engine        | Size     | Per-Token Latency | Real-Time Factor |
| ------------- | -------- | ----------------- | ---------------- |
| PyTorch       | Tiny     | 6.220 ms/token    | 0.006220         |
| PyTorch 2.0   | Tiny     | 3.944 ms/token    | 0.003944         |
| ONNX Runtime  | Tiny     | 1.545 ms/token    | 0.001545         |
| PyTorch       | Medium   | 19.093 ms/token   | 0.019093         |
| PyTorch 2.0   | Medium   | 20.459 ms/token   | 0.020459         |
| ONNX Runtime  | Medium   | 9.440 ms/token    | 0.009440         |
| PyTorch       | Large v2 | 25.844 ms/token   | 0.025844         |
| PyTorch 2.0   | Large v2 | 26.397 ms/token   | 0.026397         |
| ONNX Runtime  | Large v2 | 7.492 ms/token    | 0.007492         |

@@ -2,7 +2,7 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import os
import sys

sys.path.append(os.path.dirname(__file__))

@@ -0,0 +1,550 @@
import argparse
import ast
import datetime
import gc
import logging
import os
import sys
import time

import numpy as np
import psutil
import torch
import whisper
from benchmark_helper import setup_logger
from onnxruntime_extensions import get_library_path
from optimum.onnxruntime import ORTModelForSpeechSeq2Seq
from torch.profiler import ProfilerActivity, profile, record_function
from tqdm import trange
from transformers import AutoModelForSpeechSeq2Seq, WhisperConfig, WhisperProcessor

import onnxruntime as ort
from onnxruntime.transformers.benchmark_helper import measure_memory

logger = logging.getLogger(__name__)


def get_inputs(args: argparse.Namespace):
    if args.benchmark_type not in {"hf-pt", "hf-pt2", "hf-ort", "ort"}:
        raise Exception("Unable to auto-detect inputs for provided model")

    def load_via_ffmpeg():
        audio = whisper.load_audio(args.audio_path)
        audio = whisper.pad_or_trim(audio)
        return audio

    def load_via_numpy():
        with open(args.audio_path, "rb") as f:
            audio = np.asarray(list(f.read()), dtype=np.uint8)
        audio = np.array([audio])
        return audio

    inputs = {
        "max_length": args.max_length,
        "min_length": args.min_length,
        "num_beams": args.num_beams,
        "num_return_sequences": args.num_return_sequences,
        "length_penalty": args.length_penalty,
        "repetition_penalty": args.repetition_penalty,
    }
    if args.benchmark_type == "ort":
        # convert_to_onnx export or ONNX E2E solution created by Olive
        for k, v in inputs.items():
            inputs[k] = np.array([v], dtype=np.float32 if "penalty" in k else np.int32)
        if args.has_decoder_input_ids:
            inputs["decoder_input_ids"] = np.array([args.decoder_input_ids], dtype=np.int32)
        if args.has_logits_processor:
            inputs["logits_processor"] = np.array([args.logits_processor], dtype=np.int32)

    # Measure time taken to load audio file
    logger.info(f"Load audio: {args.audio_path}")
    load_audio_fn = lambda onnx_e2e: load_via_numpy() if onnx_e2e else load_via_ffmpeg()  # noqa: E731
    time_fn(args, load_audio_fn, args.has_audio_stream)
    audio_data = load_audio_fn(args.has_audio_stream)

    if args.has_audio_stream:
        # ONNX E2E solution created by Olive
        inputs["audio_stream"] = audio_data
        return inputs

    # Measure time taken to get input features
    logger.info("Feature extraction: ")
    return_type = "np" if args.benchmark_type == "ort" else "pt"
    processor_fn = lambda audio: args.processor.feature_extractor(  # noqa: E731
        [audio], return_tensors=return_type, sampling_rate=args.sampling_rate
    ).input_features
    time_fn(args, processor_fn, audio_data)
    input_features = processor_fn(audio_data)

    if args.benchmark_type == "ort":
        # convert_to_onnx export
        inputs["input_features"] = input_features
        return inputs

    inputs["inputs"] = input_features.to(
        dtype=torch.float16 if args.use_fp16 else torch.float32, device=args.target_device
    )
    inputs["no_repeat_ngram_size"] = args.no_repeat_ngram_size
    inputs["early_stopping"] = True
    inputs["use_cache"] = True

    if args.decoder_input_ids:
        inputs["forced_decoder_ids"] = args.decoder_input_ids

    return inputs


def get_model(args: argparse.Namespace):
    model, sess_options = None, None
    start_time, end_time = None, None

    # There are multiple sources that the model could come from:
    # 1) Benchmark Whisper from Hugging Face
    # 2) Benchmark Whisper ONNX model from Optimum export (without pre/post processing)
    # 3) Benchmark Whisper ONNX E2E model from Olive (with pre/post processing)

    if args.benchmark_type in {"hf-pt", "hf-pt2"}:
        source = args.hf_pt_model_path if args.hf_pt_model_path else args.model_name
        start_time = time.time()
        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            source,
            torch_dtype=torch.float16 if args.use_fp16 else torch.float32,
            use_cache=True,
        ).to(args.target_device)
        end_time = time.time()

        if args.benchmark_type == "hf-pt2":
            model = torch.compile(model)

    elif args.benchmark_type in {"hf-ort", "ort"}:
        sess_options = ort.SessionOptions()
        sess_options.enable_profiling = args.profile
        sess_options.register_custom_ops_library(get_library_path())
        if args.verbose:
            sess_options.log_verbosity_level = 1
            sess_options.log_severity_level = 1

    else:
        raise Exception(f"Cannot recognize {args.benchmark_type}")

    if args.benchmark_type == "hf-ort":
        # Optimum export
        provider = args.execution_provider[0] if type(args.execution_provider) is tuple else args.execution_provider
        provider_options = args.execution_provider[1] if type(args.execution_provider) is tuple else None

        start_time = time.time()
        model = ORTModelForSpeechSeq2Seq.from_pretrained(
            args.hf_ort_model_path,
            use_io_binding=(args.device != "cpu"),
            provider=provider,
            provider_options=provider_options,
            session_options=sess_options,
        )
        end_time = time.time()

    if args.benchmark_type == "ort":
        # convert_to_onnx.py export
        logger.info(f"Loading model from {args.ort_model_path}")
        start_time = time.time()
        model = ort.InferenceSession(
            args.ort_model_path,
            sess_options,
            providers=[args.execution_provider],
        )
        end_time = time.time()

    logger.info(f"Loaded model in {end_time - start_time} s")

    return model


def time_fn(args, fn, inputs):
    # Warm up
    warmup_range = (
        range(args.warmup_runs)
        if args.benchmark_type == "ort"
        else trange(args.warmup_runs, file=sys.stdout, desc="Warm up")
    )

    if args.verbose:
        outputs = fn(inputs)
        logger.info(outputs)

    for _ in warmup_range:
        fn(inputs)

    # Benchmark
    if args.device != "cpu":
        torch.cuda.synchronize()
    start_time = time.time()

    bench_range = (
        range(args.num_runs)
        if args.benchmark_type == "ort"
        else trange(args.num_runs, file=sys.stdout, desc="Benchmark")
    )
    for _ in bench_range:
        fn(inputs)

    if args.device != "cpu":
        torch.cuda.synchronize()
    end_time = time.time()

    # Newline print after trange in order to print metrics on new lines without progress bar on same line
    if args.benchmark_type != "ort":
        logger.info("")

    batch_size = 1
    latency = (end_time - start_time) / args.num_runs
    throughput = batch_size / latency

    logger.info(f"Latency: {latency} s")
    logger.info(f"Throughput: {throughput} qps")
    return


def profile_fn(args, fn, inputs, inputs_type):
    # Filename prefix format:
    # "<benchmark-type>-<precision>-<device>_<inference-step>_<inputs-type>_<current-time>"
    prefix = f"{args.benchmark_type.lower()}-{args.precision}-{args.device}_{fn.__name__.replace('_', '-')}_{inputs_type}_{datetime.datetime.now():%Y-%m-%d_%H:%M:%S}"
    filename = None

    if args.benchmark_type in {"hf-pt", "hf-pt2"}:
        # Profile PyTorch kernels
        with profile(  # noqa: SIM117
            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True, profile_memory=True
        ) as prof:
            with record_function("model_inference"):
                fn(inputs)
        prof_data = prof.key_averages(group_by_stack_n=5).table(sort_by=args.pt_filter_by, row_limit=args.pt_num_rows)

        filename = os.path.join(args.log_folder, f"{prefix}.log")
        with open(filename, "w") as f:
            f.write(prof_data)

    else:
        # Profile ORT kernels
        fn(inputs)

        # Set new log name for ORT profile log generated
        filename = f"{prefix}.json"

    return filename


def measure_fn(args, fn, inputs):
    # Measure CPU usage
    pid = os.getpid()
    process = psutil.Process(pid)
    process.cpu_percent(interval=0.1)

    fn(inputs)
    logger.info(f"CPU usage: {process.cpu_percent(interval=None)}%")

    # Measure memory usage
    gc.collect()
    torch.cuda.empty_cache()
    measure_memory(is_gpu=(args.device != "cpu"), func=lambda: fn(inputs))

    # Flush output so memory usage is printed
    sys.stdout.flush()


def run_hf_inference(args, inputs, model):
    # Inference steps to measure
    def get_pred_ids(inputs):
        # Inference pass with predicted token ids generation
        predicted_ids = model.generate(**inputs)
        return predicted_ids, [""]

    def gen_and_dec(inputs):
        # Inference pass with generation and decoding
        predicted_ids, _ = get_pred_ids(inputs)
        transcription = []
        for _ in range(args.num_return_sequences):
            transcription.append(args.processor.batch_decode(predicted_ids, skip_special_tokens=True)[0])
        return predicted_ids, transcription

    # Examples of other inference steps that can be measured:
    # To use, uncomment the function and assign it to `generate_fn`

    # def get_logits(inputs):
    #     # Inference pass without decoding
    #     outputs = model(**inputs)
    #     return outputs

    generate_fn = gen_and_dec

    if args.benchmark_type == "hf-pt2":
        # Run forward pass once with each set of inputs to process through Dynamo
        generate_fn(inputs)

    if args.profile:
        new_logname = profile_fn(args, generate_fn, inputs, "gen-and-dec")
        if args.benchmark_type == "hf-ort":
            # Rename log files per model component and turn profiling off to stop appending to log
            new_prefix = new_logname[: -len(".json")]

            old_logname = model.encoder.session.end_profiling()
            new_logname = new_prefix + "-encoder.json"
            if os.path.isfile(old_logname):
                logger.warning(f"Renaming {old_logname} to {new_logname}")
                os.rename(old_logname, os.path.join(args.log_folder, new_logname))

            old_logname = model.decoder.session.end_profiling()
            new_logname = new_prefix + "-decoder.json"
            if os.path.isfile(old_logname):
                logger.warning(f"Renaming {old_logname} to {new_logname}")
                os.rename(old_logname, os.path.join(args.log_folder, new_logname))

            old_logname = model.decoder_with_past.session.end_profiling()
            new_logname = new_prefix + "-decoder-with-past.json"
            if os.path.isfile(old_logname):
                logger.warning(f"Renaming {old_logname} to {new_logname}")
                os.rename(old_logname, os.path.join(args.log_folder, new_logname))

        return

    # PyTorch evaluations
    logger.info("\nEvaluating PyTorch...")
    time_fn(args, generate_fn, inputs)
    predicted_ids, transcription = generate_fn(inputs)
    logger.info(f"Generated token length: {len(predicted_ids[0])} tokens")
    logger.info(f"Transcription: {transcription[0]}")
    measure_fn(args, generate_fn, inputs)


def run_ort_inference(args, inputs, model):
    def prepare_ort_inputs(inputs):
        # Check that all model inputs will be provided
        model_inputs = set(map(lambda model_input: model_input.name, model.get_inputs()))
        user_inputs = set(inputs.keys())
        missing_inputs = model_inputs - user_inputs
        if len(missing_inputs):
            logger.error(f"The following model inputs are missing: {missing_inputs}")
            raise Exception("There are missing inputs to the model. Please add them and try again.")

        # Remove unnecessary inputs from model inputs
        unnecessary_inputs = user_inputs - model_inputs
        if len(unnecessary_inputs):
            for unnecessary_input in unnecessary_inputs:
                logger.info(f"Removing unnecessary input '{unnecessary_input}' from user provided inputs")
                del inputs[unnecessary_input]

        # Add IO bindings for non-CPU execution providers
        if args.device != "cpu":
            io_binding = model.io_binding()
            for k, v in inputs.items():
                io_binding.bind_cpu_input(k, v)
            for output in model.get_outputs():
                io_binding.bind_output(output.name)
            return io_binding

        return inputs

    def with_io_binding(io_binding):
        # Inference pass with IO binding
        model.run_with_iobinding(io_binding)
        return io_binding

    def without_io_binding(inputs):
        # Inference pass without IO binding
        outputs = model.run(None, inputs)
        return outputs

    generate_fn = with_io_binding if args.device != "cpu" else without_io_binding
    ort_inputs = prepare_ort_inputs(inputs)

    if args.profile:
        new_logname = profile_fn(args, generate_fn, ort_inputs, "e2e")

        # Turn profiling off to stop appending to log file
        old_logname = model.end_profiling()
        logger.warning(f"Renaming {old_logname} to {new_logname}")
        os.rename(old_logname, os.path.join(args.log_folder, new_logname))

        return

    # ORT evaluation
    logger.info("\nEvaluating ONNX Runtime...")
    time_fn(args, generate_fn, ort_inputs)
    ort_outputs = generate_fn(ort_inputs)
    if args.device != "cpu":
        ort_outputs = ort_outputs.copy_outputs_to_cpu()
    ort_outputs = ort_outputs[0]

    if args.has_audio_stream:
        # ONNX E2E model from Olive produces transcribed output
        logger.info(f"Transcription: {ort_outputs[0][0]}")
    else:
        # convert_to_onnx model produces generated ids
        logger.info(f"Generated token length: {len(ort_outputs[0][0])} tokens")

    measure_fn(args, generate_fn, ort_inputs)


def run_inference(args, inputs, model):
    if args.benchmark_type in {"hf-pt", "hf-pt2", "hf-ort"}:
        run_hf_inference(args, inputs, model)
    elif args.benchmark_type == "ort":
        run_ort_inference(args, inputs, model)
    else:
        raise Exception(f"Cannot recognize {args.benchmark_type}")


def parse_args():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "-bt", "--benchmark-type", type=str, required=True, choices=["hf-pt", "hf-pt2", "hf-ort", "ort"]
    )
    parser.add_argument(
        "-m",
        "--model-name",
        type=str,
        required=True,
        help="Hugging Face name of model (e.g. 'openai/whisper-large-v2')",
    )
    parser.add_argument(
        "-p",
        "--precision",
        type=str,
        required=True,
        default="fp32",
        choices=["int8", "fp16", "fp32"],
        help="Precision for model. For ONNX models, the model's precision should be set before running this script.",
    )

    parser.add_argument(
        "--hf-pt-model-path",
        type=str,
        default="",
        help="Path to directory containing all PyTorch files (e.g. tokenizer, PyTorch model)",
    )
    parser.add_argument(
        "--hf-ort-model-path",
        type=str,
        default="",
        help="Path to directory containing all ONNX files (e.g. tokenizer, encoder, decoder, decoder_with_past)",
    )
    parser.add_argument(
        "--ort-model-path",
        type=str,
        default="",
        help="Path to ONNX model",
    )

    # Args for running and evaluating the model
    parser.add_argument("-a", "--audio-path", type=str, required=True, help="Path to audio file for E2E evaluation")
    parser.add_argument(
        "-d",
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        choices=["cpu", "cuda", "rocm"],
    )
    parser.add_argument("-id", "--device-id", type=int, default=0)
    parser.add_argument("-w", "--warmup-runs", type=int, default=5)
    parser.add_argument("-n", "--num-runs", type=int, default=10)
    parser.add_argument("--seed", type=int, default=2)

    # Optional args:
    parser.add_argument("--sampling-rate", type=int, default=16000, help="Sampling rate for audio (in Hz)")

    # Args for decoding logic
    # Required args:
    parser.add_argument("--max-length", type=int, default=448)
    parser.add_argument("--min-length", type=int, default=0)
    parser.add_argument("--num-beams", type=int, default=1)
    parser.add_argument("--num-return-sequences", type=int, default=1)
    parser.add_argument("--length-penalty", type=float, default=1.0)
    parser.add_argument("--repetition-penalty", type=float, default=1.0)
    parser.add_argument("--no-repeat-ngram-size", type=int, default=3)

    # Optional args for E2E solution:
    parser.add_argument(
        "--decoder-input-ids",
        type=str,
        default="[]",
        help="The forced decoder ids for generation. Format is [start token, timestamp token, language token, task token]. Default is [start token]. See `decoder_input_ids` in https://github.com/microsoft/Olive/tree/main/examples/whisper for details.",
    )
    parser.add_argument(
        "--logits-processor",
        type=int,
        default=1,
        help="Type of logits processor to use. See `BeamSearch` in https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/graph/contrib_ops/contrib_defs.cc for details.",
    )

    # Args for accessing detailed info
    parser.add_argument("--profile", default=False, action="store_true")
    parser.add_argument(
        "--pt-filter-by", type=str, default="self_cpu_time_total", help="What to filter PyTorch profiler by"
    )
    parser.add_argument("--pt-num-rows", type=int, default=1000, help="Number of rows for PyTorch profiler to display")
    parser.add_argument("--verbose", default=False, action="store_true")
    parser.add_argument("--log-folder", type=str, default=os.path.join("."), help="Folder to cache log files")

    args = parser.parse_args()

    # Set seed properties
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # Set runtime properties
    if "ort" in args.benchmark_type:
        args.execution_provider = f"{args.device.upper()}ExecutionProvider"
        if args.execution_provider == "CUDAExecutionProvider":
            args.execution_provider = (args.execution_provider, {"device_id": args.device_id})
        elif args.execution_provider == "ROCMExecutionProvider":
            args.execution_provider = (args.execution_provider, {"device_id": args.device_id})
            args.device = "cuda"

    # Check that model paths have been specified for any benchmarking with ORT
    if args.benchmark_type == "hf-ort":
        assert args.hf_ort_model_path, "Please specify a path to `--hf-ort-model-path`"
    if args.benchmark_type == "ort":
        assert args.ort_model_path, "Please specify a path to `--ort-model-path`"

    # Convert decoder_input_ids string to list of ids
    # (e.g. "[1, 50257]" for Hugging Face or "[50257]" for ORT)
    args.decoder_input_ids = ast.literal_eval(args.decoder_input_ids)

    return args


def main():
    args = parse_args()
    setup_logger(args.verbose)
    logger.info(args.__dict__)
    torch.backends.cudnn.benchmark = True

    config = WhisperConfig.from_pretrained(args.model_name)
    processor = WhisperProcessor.from_pretrained(args.model_name)
    target_device = f"cuda:{args.device_id}" if args.device != "cpu" else args.device
    use_fp16 = args.precision == "fp16"

    setattr(args, "processor", processor)  # noqa: B010
    setattr(args, "target_device", target_device)  # noqa: B010
    setattr(args, "use_fp16", use_fp16)  # noqa: B010
    setattr(args, "has_audio_stream", False)  # noqa: B010

    logger.info(f"Forced decoder prompt ids: {args.decoder_input_ids}")

    # Measure cost to transcribe audio
    model = get_model(args)
    if args.benchmark_type == "ort":
        # Check for optional inputs that could have been added during export
        ort_model_inputs = set(map(lambda model_input: model_input.name, model.get_inputs()))
        args.has_audio_stream = "audio_stream" in ort_model_inputs
        setattr(args, "has_decoder_input_ids", "decoder_input_ids" in ort_model_inputs)  # noqa: B010
        setattr(args, "has_logits_processor", "logits_processor" in ort_model_inputs)  # noqa: B010

    if args.decoder_input_ids == []:
        args.decoder_input_ids = [config.decoder_start_token_id]

    inputs = get_inputs(args)
    run_inference(args, inputs, model)


if __name__ == "__main__":
    main()
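
For reference, this is roughly the feed that `get_inputs` assembles for the `ort` benchmark type with the default decoding arguments (a sketch: the zero tensor stands in for real log-mel features, and the 1x80x3000 shape assumes Whisper's standard 30-second, 80-mel input):

```
import numpy as np

ort_inputs = {
    # Normally produced by WhisperProcessor's feature extractor.
    "input_features": np.zeros((1, 80, 3000), dtype=np.float32),
    # Decoding arguments: "penalty" entries are float32, the rest int32.
    "max_length": np.array([448], dtype=np.int32),
    "min_length": np.array([0], dtype=np.int32),
    "num_beams": np.array([1], dtype=np.int32),
    "num_return_sequences": np.array([1], dtype=np.int32),
    "length_penalty": np.array([1.0], dtype=np.float32),
    "repetition_penalty": np.array([1.0], dtype=np.float32),
}
# Models exported with --use_forced_decoder_ids also expect "decoder_input_ids".
```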

@@ -0,0 +1,432 @@
import argparse
import datetime
import json
import logging
import os
import subprocess

import librosa
import torch
from benchmark_helper import setup_logger
from transformers import WhisperConfig, WhisperProcessor

logger = logging.getLogger(__name__)


def get_args():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "-a",
        "--audio-path",
        type=str,
        required=True,
        help="Path to folder of audio files for E2E evaluation",
    )

    parser.add_argument(
        "-l",
        "--language",
        default=None,
        help="Language of audio file",
    )

    parser.add_argument(
        "-t",
        "--task",
        default=None,
        choices=["transcribe", "translate"],
        help="Task to complete",
    )

    parser.add_argument(
        "-w",
        "--warmup-runs",
        type=int,
        default=5,
    )

    parser.add_argument(
        "-n",
        "--num-runs",
        type=int,
        default=10,
    )

    parser.add_argument(
        "--hf-ort-model-path",
        type=str,
        help="Path to folder containing ONNX models for Optimum + ORT benchmarking",
    )

    parser.add_argument(
        "--ort-model-path",
        type=str,
        help="Path to ONNX model for ORT benchmarking",
    )

    parser.add_argument(
        "--model-name",
        type=str,
        required=True,
        help="Model name in Hugging Face (e.g. openai/whisper-large-v2)",
    )

    parser.add_argument(
        "--precision",
        type=str,
        required=True,
        choices=["int8", "fp16", "fp32"],
        help="Precision to run model",
    )

    parser.add_argument(
        "--device",
        type=str,
        required=True,
        choices=["cpu", "cuda", "rocm"],
        help="Device to benchmark models",
    )

    parser.add_argument(
        "--device-id",
        type=int,
        default=0,
        help="GPU device ID",
    )

    parser.add_argument(
        "--verbose",
        default=False,
        action="store_true",
        help="Print detailed logs",
    )

    parser.add_argument(
        "--timeout",
        type=int,
        default=5,
        help="Number of mins to attempt the benchmark before moving on",
    )

    args = parser.parse_args()

    setattr(args, "model_size", args.model_name.split("/")[-1].replace(".", "-"))  # noqa: B010
    log_folder_name = f"./{args.model_size}-{args.precision}"
    setattr(args, "log_folder", log_folder_name)  # noqa: B010
    os.makedirs(args.log_folder, exist_ok=True)

    # Convert timeout value to secs
    args.timeout *= 60

    return args


def process_log_file(device_id, log_file, base_results):
    entries = []

    # Detect steps in speech pipeline
    step = None
    load_audio_pattern = "Load audio: "
    feat_ext_pattern = "Feature extraction: "
    pytorch_pattern = "Evaluating PyTorch..."
    onnxruntime_pattern = "Evaluating ONNX Runtime..."

    load_audio_latency_s, load_audio_throughput_s = None, None
    feat_ext_latency_s, feat_ext_throughput_s = None, None
    token_length = None
    latency_s, per_token_latency_s, per_token_latency_ms = None, None, None
    throughput, memory = None, None

    # Detect metrics
    latency_pattern = "Latency: "
    throughput_pattern = "Throughput: "
    token_length_pattern = "Generated token length: "
    memory_pattern = "peak="

    with open(log_file) as f:
        for input_line in f:
            line = input_line.replace("\n", "")

            # Get step in speech recognition pipeline
            if load_audio_pattern in line:
                step = "load-audio"
            elif feat_ext_pattern in line:
                step = "feature-extraction"
            elif pytorch_pattern in line or onnxruntime_pattern in line:
                step = "process"

            # Check metrics
            if latency_pattern in line:
                latency_s = float(line[len(latency_pattern) : line.rfind(" ")])
            elif throughput_pattern in line:
                throughput = float(line[len(throughput_pattern) : line.rfind(" ")])
                if step == "load-audio":
                    load_audio_latency_s, load_audio_throughput_s = latency_s, throughput
                    step = None
                if step == "feature-extraction":
                    feat_ext_latency_s, feat_ext_throughput_s = latency_s, throughput
                    step = None
            elif token_length_pattern in line:
                token_length = int(line[len(token_length_pattern) : line.rfind(" ")])
                per_token_latency_s = latency_s / token_length
                per_token_latency_ms = per_token_latency_s * 1000
            elif memory_pattern in line:
                if "CPU" in line:
                    # Example format for log entry:
                    # CPU memory usage: before=1000.0 MB, peak=2000.0 MB
                    memory = float(line[line.rfind("=") + 1 : line.rfind(" MB")]) / 1000
                else:
                    # Example format for log entry:
                    # GPU memory usage: before=[{'device_id': 0, 'name': 'Tesla V100-PCIE-16GB', 'max_used_MB': 1638.875}, {'device_id': 1, 'name': 'Tesla V100-PCIE-16GB', 'max_used_MB': 236.875}, peak=[{'device_id': 0, 'name': 'Tesla V100-PCIE-16GB', 'max_used_MB': 1780.875}, {'device_id': 1, 'name': 'Tesla V100-PCIE-16GB', 'max_used_MB': 236.875}]
                    peak = line[line.find(memory_pattern) + len(memory_pattern) :].replace("'", '"')
                    usage = json.loads(peak)[device_id]["max_used_MB"]
                    memory = float(usage) / 1000

    # Calculate real-time factor (RTF):
    # RTF = total latency / audio duration
    total_latency = (
        (load_audio_latency_s if load_audio_latency_s else 0)
        + (feat_ext_latency_s if feat_ext_latency_s else 0)
        + (latency_s if latency_s else 0)
    )
    audio_duration = base_results[-1]
    rtf = (total_latency / audio_duration) if audio_duration else -1
    logger.info(f"Total latency: {total_latency} s")
    logger.info(f"Audio duration: {audio_duration} s")
    logger.info(f"Real-time factor: {rtf}")

    # Append log entry to list of entries
    entry = base_results + [  # noqa: RUF005
        token_length,
        load_audio_latency_s,
        load_audio_throughput_s,
        feat_ext_latency_s if feat_ext_latency_s else -1,
        feat_ext_throughput_s if feat_ext_throughput_s else -1,
        latency_s,
        per_token_latency_ms,
        throughput,
        memory,
        rtf,
    ]
    entries.append(entry)

    return entries


def save_results(results, filename):
    import pandas as pd

    df = pd.DataFrame(
        results,
        columns=[
            "Engine",
            "Precision",
            "Device",
            "Audio File",
            "Duration (s)",
            "Token Length",
            "Load Audio Latency (s)",
            "Load Audio Throughput (qps)",
            "Feature Extractor Latency (s)",
            "Feature Extractor Throughput (qps)",
            "Latency (s)",
            "Per Token Latency (ms/token)",
            "Throughput (qps)",
            "Memory (GB)",
            "Real Time Factor (RTF)",
        ],
    )

    # Set column types
    df["Duration (s)"] = df["Duration (s)"].astype("float")
    df["Token Length"] = df["Token Length"].astype("int")
    df["Load Audio Latency (s)"] = df["Load Audio Latency (s)"].astype("float")
    df["Load Audio Throughput (qps)"] = df["Load Audio Throughput (qps)"].astype("float")
    df["Feature Extractor Latency (s)"] = df["Feature Extractor Latency (s)"].astype("float")
    df["Feature Extractor Throughput (qps)"] = df["Feature Extractor Throughput (qps)"].astype("float")
    df["Latency (s)"] = df["Latency (s)"].astype("float")
    df["Per Token Latency (ms/token)"] = df["Per Token Latency (ms/token)"].astype("float")
    df["Throughput (qps)"] = df["Throughput (qps)"].astype("float")
    df["Memory (GB)"] = df["Memory (GB)"].astype("float")
    df["Real Time Factor (RTF)"] = df["Real Time Factor (RTF)"].astype("float")

    df.to_csv(filename, index=False)
    logger.info(f"Results saved in {filename}!")


def benchmark(args, benchmark_cmd, engine, audio_file, duration):
    log_filename = f"{engine}_{datetime.datetime.now():%Y-%m-%d_%H:%M:%S}.log"
    log_path = os.path.join(args.log_folder, log_filename)
    with open(log_path, "w") as log_file:
        process = subprocess.Popen(benchmark_cmd, stdout=log_file, stderr=log_file)
        try:
            process.wait(args.timeout)
        except subprocess.TimeoutExpired:
            process.kill()

    # Create entries for csv
    logger.info("Gathering data from log files...")
    base_results = [engine, args.precision, args.device, audio_file, duration]
    results = process_log_file(args.device_id, log_path, base_results)

    return results


def main():
    args = get_args()
    setup_logger(args.verbose)
    logger.info(args.__dict__)
    torch.backends.cudnn.benchmark = True

    config = WhisperConfig.from_pretrained(args.model_name)
    processor = WhisperProcessor.from_pretrained(args.model_name)

    # Calculate forced decoder input ids
    hf_forced_decoder_ids = processor.get_decoder_prompt_ids(language=args.language, task=args.task)
    ort_forced_decoder_ids = [config.decoder_start_token_id] + list(  # noqa: RUF005
        map(lambda token_id: token_id[1], hf_forced_decoder_ids)
    )
    hf_decoder_input_ids_cmd = (
        ["--decoder-input-ids", str(hf_forced_decoder_ids)] if args.language and args.task else []
    )
    ort_decoder_input_ids_cmd = (
        ["--decoder-input-ids", str(ort_forced_decoder_ids)] if args.language and args.task else []
    )

    all_results = []
    for audio_file in os.listdir(args.audio_path):
        audio_path = os.path.join(args.audio_path, audio_file)
        try:
            duration = librosa.get_duration(path=audio_path)
        except Exception as e:
            duration = -1
            logger.warning(f"An error occurred while trying to calculate the audio duration: {e}", exc_info=True)
            logger.warning(
                f"If you get an error that says:\n\tsoundfile.LibsndfileError: Error opening '{audio_file}': File contains data in an unknown format.\nyou may not have installed `ffmpeg` in addition to installing `librosa`."
            )
        logger.info(f"Testing {audio_path}...")

        # Benchmark PyTorch without torch.compile
        benchmark_cmd = [  # noqa: RUF005
            "python3",
            "-m",
            "models.whisper.benchmark",
            "--audio-path",
            audio_path,
            "--benchmark-type",
            "hf-pt",
            "--model-name",
            args.model_name,
            "--precision",
            args.precision,
            "--device",
            args.device,
            "--device-id",
            str(args.device_id),
            "--warmup-runs",
            str(args.warmup_runs),
            "--num-runs",
            str(args.num_runs),
            "--log-folder",
            args.log_folder,
        ] + hf_decoder_input_ids_cmd
        logger.info("Benchmark PyTorch without torch.compile")
        results = benchmark(args, benchmark_cmd, "pytorch", audio_file, duration)
        all_results.extend(results)

        # Benchmark PyTorch with torch.compile
        benchmark_cmd = [  # noqa: RUF005
            "python3",
            "-m",
            "models.whisper.benchmark",
            "--audio-path",
            audio_path,
            "--benchmark-type",
            "hf-pt2",
            "--model-name",
            args.model_name,
            "--precision",
            args.precision,
            "--device",
            args.device,
            "--device-id",
            str(args.device_id),
            "--warmup-runs",
            str(args.warmup_runs),
            "--num-runs",
            str(args.num_runs),
            "--log-folder",
            args.log_folder,
        ] + hf_decoder_input_ids_cmd
        logger.info("Benchmark PyTorch with torch.compile")
        results = benchmark(args, benchmark_cmd, "pytorch-2", audio_file, duration)
        all_results.extend(results)

        # Benchmark Optimum + ONNX Runtime
        if args.hf_ort_model_path:
            benchmark_cmd = [  # noqa: RUF005
                "python3",
                "-m",
                "models.whisper.benchmark",
                "--audio-path",
                audio_path,
                "--benchmark-type",
                "hf-ort",
                "--hf-ort-model-path",
                args.hf_ort_model_path,
                "--model-name",
                args.model_name,
                "--precision",
                args.precision,
                "--device",
                args.device,
                "--device-id",
                str(args.device_id),
                "--warmup-runs",
                str(args.warmup_runs),
                "--num-runs",
                str(args.num_runs),
                "--log-folder",
                args.log_folder,
            ] + hf_decoder_input_ids_cmd
            logger.info("Benchmark Optimum + ONNX Runtime")
            results = benchmark(args, benchmark_cmd, "pytorch-ort", audio_file, duration)
            all_results.extend(results)

        # Benchmark ONNX Runtime
        if args.ort_model_path:
            benchmark_cmd = [  # noqa: RUF005
                "python3",
                "-m",
                "models.whisper.benchmark",
                "--audio-path",
                audio_path,
                "--benchmark-type",
                "ort",
                "--ort-model-path",
                args.ort_model_path,
                "--model-name",
                args.model_name,
                "--precision",
                args.precision,
                "--device",
                args.device,
                "--device-id",
                str(args.device_id),
                "--warmup-runs",
                str(args.warmup_runs),
                "--num-runs",
                str(args.num_runs),
                "--log-folder",
                args.log_folder,
            ] + ort_decoder_input_ids_cmd
            logger.info("Benchmark ONNX Runtime")
            results = benchmark(args, benchmark_cmd, "onnxruntime", audio_file, duration)
            all_results.extend(results)

    csv_file = f"{args.model_size}-{args.precision}_{datetime.datetime.now():%Y-%m-%d_%H:%M:%S}.csv"
    save_results(all_results, os.path.join(args.log_folder, csv_file))


if __name__ == "__main__":
    main()
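
To illustrate the two `--decoder-input-ids` formats computed above (the concrete token ids in the comments are assumptions for the multilingual Whisper vocabulary):

```
from transformers import WhisperConfig, WhisperProcessor

model_name = "openai/whisper-large-v2"
config = WhisperConfig.from_pretrained(model_name)
processor = WhisperProcessor.from_pretrained(model_name)

# Hugging Face format: (position, token_id) pairs,
# e.g. [(1, 50259), (2, 50359), (3, 50363)] for English transcription.
hf_forced_decoder_ids = processor.get_decoder_prompt_ids(language="english", task="transcribe")

# ORT format: a flat list that starts with the decoder start token,
# e.g. [50258, 50259, 50359, 50363].
ort_forced_decoder_ids = [config.decoder_start_token_id] + [
    token_id for _, token_id in hf_forced_decoder_ids
]
print(hf_forced_decoder_ids, ort_forced_decoder_ids)
```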

@@ -180,22 +180,25 @@ def parse_arguments(argv=None):
        "--quantize_embedding_layer",
        required=False,
        action="store_true",
        help="Quantize MatMul, GEMM, and Gather.",
    )
    parser.set_defaults(quantize_embedding_layer=False)

    parser.add_argument(
        "--quantize_per_channel",
        required=False,
        action="store_true",
        help="Quantize weights per each channel.",
    )
    parser.set_defaults(quantize_per_channel=False)

    parser.add_argument(
        "--quantize_reduce_range",
        required=False,
        action="store_true",
        help="Quantize weights with 7 bits.",
    )
    parser.set_defaults(quantize_reduce_range=False)

    parser.add_argument("--no_repeat_ngram_size", type=int, default=0, help="default to 0")