mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-25 22:26:24 +00:00
[ROCm] Update whisper benchmark script (#17391)
- update whisper benchmark for ROCm EP.
This commit is contained in:
parent
c969237321
commit
af14ae8050
3 changed files with 196 additions and 89 deletions
|
|
@ -8,7 +8,10 @@ import csv
|
|||
import logging
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
import time
|
||||
import timeit
|
||||
from abc import ABC, abstractmethod
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
|
@ -439,68 +442,127 @@ def get_gpu_info() -> Optional[List[Dict[str, Any]]]:
|
|||
return None
|
||||
|
||||
|
||||
def measure_memory(is_gpu, func):
|
||||
class MemoryMonitor:
|
||||
def __init__(self, keep_measuring=True):
|
||||
self.keep_measuring = keep_measuring
|
||||
class MemoryMonitor(ABC):
|
||||
def __init__(self, keep_measuring=True):
|
||||
self.keep_measuring = keep_measuring
|
||||
|
||||
def measure_cpu_usage(self):
|
||||
import psutil
|
||||
def measure_cpu_usage(self):
|
||||
import psutil
|
||||
|
||||
max_usage = 0
|
||||
max_usage = 0
|
||||
while True:
|
||||
max_usage = max(max_usage, psutil.Process(os.getpid()).memory_info().rss / 1024**2)
|
||||
sleep(0.005) # 5ms
|
||||
if not self.keep_measuring:
|
||||
break
|
||||
return max_usage
|
||||
|
||||
@abstractmethod
|
||||
def measure_gpu_usage(self) -> Optional[List[Dict[str, Any]]]:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class CudaMemoryMonitor(MemoryMonitor):
|
||||
def __init__(self, keep_measuring=True):
|
||||
super().__init__(keep_measuring)
|
||||
|
||||
def measure_gpu_usage(self) -> Optional[List[Dict[str, Any]]]:
|
||||
from py3nvml.py3nvml import (
|
||||
NVMLError,
|
||||
nvmlDeviceGetCount,
|
||||
nvmlDeviceGetHandleByIndex,
|
||||
nvmlDeviceGetMemoryInfo,
|
||||
nvmlDeviceGetName,
|
||||
nvmlInit,
|
||||
nvmlShutdown,
|
||||
)
|
||||
|
||||
max_gpu_usage = []
|
||||
gpu_name = []
|
||||
try:
|
||||
nvmlInit()
|
||||
device_count = nvmlDeviceGetCount()
|
||||
if not isinstance(device_count, int):
|
||||
logger.error(f"nvmlDeviceGetCount result is not integer: {device_count}")
|
||||
return None
|
||||
|
||||
max_gpu_usage = [0 for i in range(device_count)]
|
||||
gpu_name = [nvmlDeviceGetName(nvmlDeviceGetHandleByIndex(i)) for i in range(device_count)]
|
||||
while True:
|
||||
max_usage = max(max_usage, psutil.Process(os.getpid()).memory_info().rss / 1024**2)
|
||||
for i in range(device_count):
|
||||
info = nvmlDeviceGetMemoryInfo(nvmlDeviceGetHandleByIndex(i))
|
||||
if isinstance(info, str):
|
||||
logger.error(f"nvmlDeviceGetMemoryInfo returns str: {info}")
|
||||
return None
|
||||
max_gpu_usage[i] = max(max_gpu_usage[i], info.used / 1024**2)
|
||||
sleep(0.005) # 5ms
|
||||
if not self.keep_measuring:
|
||||
break
|
||||
return max_usage
|
||||
nvmlShutdown()
|
||||
return [
|
||||
{
|
||||
"device_id": i,
|
||||
"name": gpu_name[i],
|
||||
"max_used_MB": max_gpu_usage[i],
|
||||
}
|
||||
for i in range(device_count)
|
||||
]
|
||||
except NVMLError as error:
|
||||
logger.error("Error fetching GPU information using nvml: %s", error)
|
||||
return None
|
||||
|
||||
def measure_gpu_usage(self) -> Optional[List[Dict[str, Any]]]:
|
||||
from py3nvml.py3nvml import (
|
||||
NVMLError,
|
||||
nvmlDeviceGetCount,
|
||||
nvmlDeviceGetHandleByIndex,
|
||||
nvmlDeviceGetMemoryInfo,
|
||||
nvmlDeviceGetName,
|
||||
nvmlInit,
|
||||
nvmlShutdown,
|
||||
)
|
||||
|
||||
max_gpu_usage = []
|
||||
gpu_name = []
|
||||
try:
|
||||
nvmlInit()
|
||||
device_count = nvmlDeviceGetCount()
|
||||
if not isinstance(device_count, int):
|
||||
logger.error(f"nvmlDeviceGetCount result is not integer: {device_count}")
|
||||
return None
|
||||
class RocmMemoryMonitor(MemoryMonitor):
|
||||
def __init__(self, keep_measuring=True):
|
||||
super().__init__(keep_measuring)
|
||||
rocm_smi_path = "/opt/rocm/libexec/rocm_smi"
|
||||
if os.path.exists(rocm_smi_path):
|
||||
if rocm_smi_path not in sys.path:
|
||||
sys.path.append(rocm_smi_path)
|
||||
try:
|
||||
import rocm_smi
|
||||
|
||||
max_gpu_usage = [0 for i in range(device_count)]
|
||||
gpu_name = [nvmlDeviceGetName(nvmlDeviceGetHandleByIndex(i)) for i in range(device_count)]
|
||||
while True:
|
||||
for i in range(device_count):
|
||||
info = nvmlDeviceGetMemoryInfo(nvmlDeviceGetHandleByIndex(i))
|
||||
if isinstance(info, str):
|
||||
logger.error(f"nvmlDeviceGetMemoryInfo returns str: {info}")
|
||||
return None
|
||||
max_gpu_usage[i] = max(max_gpu_usage[i], info.used / 1024**2)
|
||||
sleep(0.005) # 5ms
|
||||
if not self.keep_measuring:
|
||||
break
|
||||
nvmlShutdown()
|
||||
return [
|
||||
{
|
||||
"device_id": i,
|
||||
"name": gpu_name[i],
|
||||
"max_used_MB": max_gpu_usage[i],
|
||||
}
|
||||
for i in range(device_count)
|
||||
]
|
||||
except NVMLError as error:
|
||||
logger.error("Error fetching GPU information using nvml: %s", error)
|
||||
return None
|
||||
self.rocm_smi = rocm_smi
|
||||
self.rocm_smi.initializeRsmi()
|
||||
except ImportError:
|
||||
self.rocm_smi = None
|
||||
|
||||
monitor = MemoryMonitor(False)
|
||||
def get_used_memory(self, dev):
|
||||
if self.rocm_smi is None:
|
||||
return -1
|
||||
return self.rocm_smi.getMemInfo(dev, "VRAM")[0] / 1024 / 1024
|
||||
|
||||
def measure_gpu_usage(self):
|
||||
if self.rocm_smi is None:
|
||||
return None
|
||||
|
||||
device_count = len(self.rocm_smi.listDevices()) if self.rocm_smi is not None else 0
|
||||
max_gpu_usage = [0 for i in range(device_count)]
|
||||
gpu_name = [f"GPU{i}" for i in range(device_count)]
|
||||
while True:
|
||||
for i in range(device_count):
|
||||
max_gpu_usage[i] = max(max_gpu_usage[i], self.get_used_memory(i))
|
||||
time.sleep(0.005) # 2ms
|
||||
if not self.keep_measuring:
|
||||
break
|
||||
return [
|
||||
{
|
||||
"device_id": i,
|
||||
"name": gpu_name[i],
|
||||
"max_used_MB": max_gpu_usage[i],
|
||||
}
|
||||
for i in range(device_count)
|
||||
]
|
||||
|
||||
|
||||
def measure_memory(is_gpu, func, monitor_type="cuda"):
|
||||
memory_monitor_type = None
|
||||
if monitor_type == "rocm":
|
||||
memory_monitor_type = RocmMemoryMonitor
|
||||
else:
|
||||
memory_monitor_type = CudaMemoryMonitor
|
||||
|
||||
monitor = memory_monitor_type(False)
|
||||
|
||||
if is_gpu:
|
||||
memory_before_test = monitor.measure_gpu_usage()
|
||||
|
|
@ -508,7 +570,7 @@ def measure_memory(is_gpu, func):
|
|||
return None
|
||||
|
||||
with ThreadPoolExecutor() as executor:
|
||||
monitor = MemoryMonitor()
|
||||
monitor = memory_monitor_type()
|
||||
mem_thread = executor.submit(monitor.measure_gpu_usage)
|
||||
try:
|
||||
fn_thread = executor.submit(func)
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ import numpy as np
|
|||
import psutil
|
||||
import torch
|
||||
import whisper
|
||||
from benchmark_helper import setup_logger
|
||||
from benchmark_helper import measure_memory, setup_logger
|
||||
from onnxruntime_extensions import get_library_path
|
||||
from optimum.onnxruntime import ORTModelForSpeechSeq2Seq
|
||||
from torch.profiler import ProfilerActivity, profile, record_function
|
||||
|
|
@ -19,7 +19,6 @@ from tqdm import trange
|
|||
from transformers import AutoModelForSpeechSeq2Seq, WhisperConfig, WhisperProcessor
|
||||
|
||||
import onnxruntime as ort
|
||||
from onnxruntime.transformers.benchmark_helper import measure_memory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -123,6 +122,9 @@ def get_model(args: argparse.Namespace):
|
|||
if args.verbose:
|
||||
sess_options.log_verbosity_level = 1
|
||||
sess_options.log_severity_level = 1
|
||||
if args.tune:
|
||||
ort.set_default_logger_severity(0)
|
||||
ort.set_default_logger_verbosity(0)
|
||||
|
||||
else:
|
||||
raise Exception(f"Cannot recognize {args.benchmark_type}")
|
||||
|
|
@ -159,6 +161,9 @@ def get_model(args: argparse.Namespace):
|
|||
|
||||
|
||||
def time_fn(args, fn, inputs):
|
||||
warmup_inputs = inputs[0] if type(inputs) is tuple else inputs
|
||||
benchmark_inputs = inputs[1] if type(inputs) is tuple else inputs
|
||||
|
||||
# Warm up
|
||||
warmup_range = (
|
||||
range(args.warmup_runs)
|
||||
|
|
@ -167,11 +172,11 @@ def time_fn(args, fn, inputs):
|
|||
)
|
||||
|
||||
if args.verbose:
|
||||
outputs = fn(inputs)
|
||||
outputs = fn(warmup_inputs)
|
||||
logger.info(outputs)
|
||||
|
||||
for _ in warmup_range:
|
||||
fn(inputs)
|
||||
fn(warmup_inputs)
|
||||
|
||||
# Benchmark
|
||||
if args.device != "cpu":
|
||||
|
|
@ -184,7 +189,7 @@ def time_fn(args, fn, inputs):
|
|||
else trange(args.num_runs, file=sys.stdout, desc="Benchmark")
|
||||
)
|
||||
for _ in bench_range:
|
||||
fn(inputs)
|
||||
fn(benchmark_inputs)
|
||||
|
||||
if args.device != "cpu":
|
||||
torch.cuda.synchronize()
|
||||
|
|
@ -244,7 +249,7 @@ def measure_fn(args, fn, inputs):
|
|||
# Measure memory usage
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
measure_memory(is_gpu=(args.device != "cpu"), func=lambda: fn(inputs))
|
||||
measure_memory(is_gpu=(args.device != "cpu"), func=lambda: fn(inputs), monitor_type=args.monitor_type)
|
||||
|
||||
# Flush output so memory usage is printed
|
||||
sys.stdout.flush()
|
||||
|
|
@ -255,7 +260,7 @@ def run_hf_inference(args, inputs, model):
|
|||
def get_pred_ids(inputs):
|
||||
# Inference pass with predicted token ids generation
|
||||
predicted_ids = model.generate(**inputs)
|
||||
return predicted_ids, [""]
|
||||
return predicted_ids
|
||||
|
||||
def gen_and_dec(inputs):
|
||||
# Inference pass with generation and decoding
|
||||
|
|
@ -315,7 +320,7 @@ def run_hf_inference(args, inputs, model):
|
|||
|
||||
|
||||
def run_ort_inference(args, inputs, model):
|
||||
def prepare_ort_inputs(inputs):
|
||||
def prepare_ort_inputs(inputs, warmup=False):
|
||||
# Check that all model inputs will be provided
|
||||
model_inputs = set(map(lambda model_input: model_input.name, model.get_inputs()))
|
||||
user_inputs = set(inputs.keys())
|
||||
|
|
@ -324,6 +329,9 @@ def run_ort_inference(args, inputs, model):
|
|||
logger.error(f"The following model inputs are missing: {missing_inputs}")
|
||||
raise Exception("There are missing inputs to the model. Please add them and try again.")
|
||||
|
||||
if warmup and args.tune:
|
||||
inputs["min_length"] = inputs["max_length"]
|
||||
|
||||
# Remove unnecessary inputs from model inputs
|
||||
unnecessary_inputs = user_inputs - model_inputs
|
||||
if len(unnecessary_inputs):
|
||||
|
|
@ -352,6 +360,13 @@ def run_ort_inference(args, inputs, model):
|
|||
outputs = model.run(None, inputs)
|
||||
return outputs
|
||||
|
||||
def handle_output(output):
|
||||
if args.eos_token_id in output:
|
||||
first_end = np.where(output == args.eos_token_id)[0][0]
|
||||
return output[: first_end + 1]
|
||||
|
||||
return output
|
||||
|
||||
generate_fn = with_io_binding if args.device != "cpu" else without_io_binding
|
||||
ort_inputs = prepare_ort_inputs(inputs)
|
||||
|
||||
|
|
@ -367,7 +382,12 @@ def run_ort_inference(args, inputs, model):
|
|||
|
||||
# ORT evaluation
|
||||
logger.info("\nEvaluating ONNX Runtime...")
|
||||
time_fn(args, generate_fn, ort_inputs)
|
||||
ort_evaluate_inputs = ort_inputs
|
||||
if args.tune:
|
||||
ort_warmup_inputs = prepare_ort_inputs(inputs, warmup=True)
|
||||
ort_evaluate_inputs = (ort_warmup_inputs, ort_inputs)
|
||||
|
||||
time_fn(args, generate_fn, ort_evaluate_inputs)
|
||||
ort_outputs = generate_fn(ort_inputs)
|
||||
if args.device != "cpu":
|
||||
ort_outputs = ort_outputs.copy_outputs_to_cpu()
|
||||
|
|
@ -378,7 +398,10 @@ def run_ort_inference(args, inputs, model):
|
|||
logger.info(f"Transcription: {ort_outputs[0][0]}")
|
||||
else:
|
||||
# convert_to_onnx model produces generated ids
|
||||
logger.info(f"Generated token length: {len(ort_outputs[0][0])} tokens")
|
||||
actual_output = handle_output(ort_outputs[0][0])
|
||||
logger.info(f"Generated token length: {len(actual_output)} tokens")
|
||||
transcription = args.processor.batch_decode(ort_outputs[0], skip_special_tokens=True)[0]
|
||||
logger.info(f"Transcription: {transcription}")
|
||||
|
||||
measure_fn(args, generate_fn, ort_inputs)
|
||||
|
||||
|
|
@ -483,6 +506,12 @@ def parse_args():
|
|||
parser.add_argument("--pt-num-rows", type=int, default=1000, help="Number of rows for PyTorch profiler to display")
|
||||
parser.add_argument("--verbose", default=False, action="store_true")
|
||||
parser.add_argument("--log-folder", type=str, default=os.path.join("."), help="Folder to cache log files")
|
||||
parser.add_argument(
|
||||
"--tune",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Only used by ROCm EP, enable TunableOp tuning to select fastest kernel",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
|
@ -490,13 +519,21 @@ def parse_args():
|
|||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
|
||||
args.monitor_type = args.device
|
||||
# Set runtime properties
|
||||
if "ort" in args.benchmark_type:
|
||||
args.execution_provider = f"{args.device.upper()}ExecutionProvider"
|
||||
if args.execution_provider == "CUDAExecutionProvider":
|
||||
args.execution_provider = (args.execution_provider, {"device_id": args.device_id})
|
||||
elif args.execution_provider == "ROCMExecutionProvider":
|
||||
args.execution_provider = (args.execution_provider, {"device_id": args.device_id})
|
||||
args.execution_provider = (
|
||||
args.execution_provider,
|
||||
{
|
||||
"device_id": args.device_id,
|
||||
"tunable_op_enable": 1,
|
||||
"tunable_op_tuning_enable": 1 if args.tune else 0,
|
||||
},
|
||||
)
|
||||
args.device = "cuda"
|
||||
|
||||
# Check that model paths have been specified for any benchmarking with ORT
|
||||
|
|
@ -527,6 +564,7 @@ def main():
|
|||
setattr(args, "target_device", target_device) # noqa: B010
|
||||
setattr(args, "use_fp16", use_fp16) # noqa: B010
|
||||
setattr(args, "has_audio_stream", False) # noqa: B010
|
||||
setattr(args, "eos_token_id", config.eos_token_id) # noqa: B010
|
||||
|
||||
logger.info(f"Forced decoder prompt ids: {args.decoder_input_ids}")
|
||||
|
||||
|
|
|
|||
|
|
@ -109,6 +109,8 @@ def get_args():
|
|||
help="Number of mins to attempt the benchmark before moving on",
|
||||
)
|
||||
|
||||
parser.add_argument("--tune", default=False, action="store_true")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
setattr(args, "model_size", args.model_name.split("/")[-1].replace(".", "-")) # noqa: B010
|
||||
|
|
@ -292,6 +294,7 @@ def main():
|
|||
ort_decoder_input_ids_cmd = (
|
||||
["--decoder-input-ids", str(ort_forced_decoder_ids)] if args.language and args.task else []
|
||||
)
|
||||
ort_tune_cmd = ["--tune"] if args.tune else []
|
||||
|
||||
all_results = []
|
||||
for audio_file in os.listdir(args.audio_path):
|
||||
|
|
@ -395,31 +398,35 @@ def main():
|
|||
|
||||
# Benchmark ONNX Runtime
|
||||
if args.ort_model_path:
|
||||
benchmark_cmd = [ # noqa: RUF005
|
||||
"python3",
|
||||
"-m",
|
||||
"models.whisper.benchmark",
|
||||
"--audio-path",
|
||||
audio_path,
|
||||
"--benchmark-type",
|
||||
"ort",
|
||||
"--ort-model-path",
|
||||
args.ort_model_path,
|
||||
"--model-name",
|
||||
args.model_name,
|
||||
"--precision",
|
||||
args.precision,
|
||||
"--device",
|
||||
args.device,
|
||||
"--device-id",
|
||||
str(args.device_id),
|
||||
"--warmup-runs",
|
||||
str(args.warmup_runs),
|
||||
"--num-runs",
|
||||
str(args.num_runs),
|
||||
"--log-folder",
|
||||
args.log_folder,
|
||||
] + ort_decoder_input_ids_cmd
|
||||
benchmark_cmd = (
|
||||
[ # noqa: RUF005
|
||||
"python3",
|
||||
"-m",
|
||||
"models.whisper.benchmark",
|
||||
"--audio-path",
|
||||
audio_path,
|
||||
"--benchmark-type",
|
||||
"ort",
|
||||
"--ort-model-path",
|
||||
args.ort_model_path,
|
||||
"--model-name",
|
||||
args.model_name,
|
||||
"--precision",
|
||||
args.precision,
|
||||
"--device",
|
||||
args.device,
|
||||
"--device-id",
|
||||
str(args.device_id),
|
||||
"--warmup-runs",
|
||||
str(args.warmup_runs),
|
||||
"--num-runs",
|
||||
str(args.num_runs),
|
||||
"--log-folder",
|
||||
args.log_folder,
|
||||
]
|
||||
+ ort_decoder_input_ids_cmd
|
||||
+ ort_tune_cmd
|
||||
)
|
||||
logger.info("Benchmark ONNX Runtime")
|
||||
results = benchmark(args, benchmark_cmd, "onnxruntime", audio_file, duration)
|
||||
all_results.extend(results)
|
||||
|
|
|
|||
Loading…
Reference in a new issue