[ROCm] Update whisper benchmark script (#17391)

- update whisper benchmark for ROCm EP.
This commit is contained in:
PeixuanZuo 2023-09-18 13:34:39 +08:00 committed by GitHub
parent c969237321
commit af14ae8050
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 196 additions and 89 deletions

View file

@ -8,7 +8,10 @@ import csv
import logging
import os
import random
import sys
import time
import timeit
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from enum import Enum
@ -439,68 +442,127 @@ def get_gpu_info() -> Optional[List[Dict[str, Any]]]:
return None
def measure_memory(is_gpu, func):
class MemoryMonitor:
def __init__(self, keep_measuring=True):
self.keep_measuring = keep_measuring
class MemoryMonitor(ABC):
def __init__(self, keep_measuring=True):
self.keep_measuring = keep_measuring
def measure_cpu_usage(self):
import psutil
def measure_cpu_usage(self):
import psutil
max_usage = 0
max_usage = 0
while True:
max_usage = max(max_usage, psutil.Process(os.getpid()).memory_info().rss / 1024**2)
sleep(0.005) # 5ms
if not self.keep_measuring:
break
return max_usage
@abstractmethod
def measure_gpu_usage(self) -> Optional[List[Dict[str, Any]]]:
    """Measure GPU memory usage; concrete monitors must override this.

    Raises:
        NotImplementedError: always, when the base implementation is called.
    """
    raise NotImplementedError
class CudaMemoryMonitor(MemoryMonitor):
def __init__(self, keep_measuring=True):
super().__init__(keep_measuring)
def measure_gpu_usage(self) -> Optional[List[Dict[str, Any]]]:
from py3nvml.py3nvml import (
NVMLError,
nvmlDeviceGetCount,
nvmlDeviceGetHandleByIndex,
nvmlDeviceGetMemoryInfo,
nvmlDeviceGetName,
nvmlInit,
nvmlShutdown,
)
max_gpu_usage = []
gpu_name = []
try:
nvmlInit()
device_count = nvmlDeviceGetCount()
if not isinstance(device_count, int):
logger.error(f"nvmlDeviceGetCount result is not integer: {device_count}")
return None
max_gpu_usage = [0 for i in range(device_count)]
gpu_name = [nvmlDeviceGetName(nvmlDeviceGetHandleByIndex(i)) for i in range(device_count)]
while True:
max_usage = max(max_usage, psutil.Process(os.getpid()).memory_info().rss / 1024**2)
for i in range(device_count):
info = nvmlDeviceGetMemoryInfo(nvmlDeviceGetHandleByIndex(i))
if isinstance(info, str):
logger.error(f"nvmlDeviceGetMemoryInfo returns str: {info}")
return None
max_gpu_usage[i] = max(max_gpu_usage[i], info.used / 1024**2)
sleep(0.005) # 5ms
if not self.keep_measuring:
break
return max_usage
nvmlShutdown()
return [
{
"device_id": i,
"name": gpu_name[i],
"max_used_MB": max_gpu_usage[i],
}
for i in range(device_count)
]
except NVMLError as error:
logger.error("Error fetching GPU information using nvml: %s", error)
return None
def measure_gpu_usage(self) -> Optional[List[Dict[str, Any]]]:
from py3nvml.py3nvml import (
NVMLError,
nvmlDeviceGetCount,
nvmlDeviceGetHandleByIndex,
nvmlDeviceGetMemoryInfo,
nvmlDeviceGetName,
nvmlInit,
nvmlShutdown,
)
max_gpu_usage = []
gpu_name = []
try:
nvmlInit()
device_count = nvmlDeviceGetCount()
if not isinstance(device_count, int):
logger.error(f"nvmlDeviceGetCount result is not integer: {device_count}")
return None
class RocmMemoryMonitor(MemoryMonitor):
def __init__(self, keep_measuring=True):
super().__init__(keep_measuring)
rocm_smi_path = "/opt/rocm/libexec/rocm_smi"
if os.path.exists(rocm_smi_path):
if rocm_smi_path not in sys.path:
sys.path.append(rocm_smi_path)
try:
import rocm_smi
max_gpu_usage = [0 for i in range(device_count)]
gpu_name = [nvmlDeviceGetName(nvmlDeviceGetHandleByIndex(i)) for i in range(device_count)]
while True:
for i in range(device_count):
info = nvmlDeviceGetMemoryInfo(nvmlDeviceGetHandleByIndex(i))
if isinstance(info, str):
logger.error(f"nvmlDeviceGetMemoryInfo returns str: {info}")
return None
max_gpu_usage[i] = max(max_gpu_usage[i], info.used / 1024**2)
sleep(0.005) # 5ms
if not self.keep_measuring:
break
nvmlShutdown()
return [
{
"device_id": i,
"name": gpu_name[i],
"max_used_MB": max_gpu_usage[i],
}
for i in range(device_count)
]
except NVMLError as error:
logger.error("Error fetching GPU information using nvml: %s", error)
return None
self.rocm_smi = rocm_smi
self.rocm_smi.initializeRsmi()
except ImportError:
self.rocm_smi = None
monitor = MemoryMonitor(False)
def get_used_memory(self, dev):
    """Return the used-VRAM reading for device ``dev``, scaled by 1024 * 1024.

    Args:
        dev: Device index passed straight through to ``rocm_smi.getMemInfo``.

    Returns:
        The first element of ``rocm_smi.getMemInfo(dev, "VRAM")`` divided by
        1024 twice, or ``-1`` when no reading is available (``rocm_smi`` was
        not imported, or the query produced no value).
    """
    if self.rocm_smi is None:
        return -1

    used = self.rocm_smi.getMemInfo(dev, "VRAM")[0]
    # NOTE(review): rocm_smi is understood to report a failed query as None
    # rather than raising; confirm against the installed ROCm version.
    # Dividing None would raise TypeError inside the monitoring thread, so
    # fall back to the same "unavailable" sentinel used above.
    if used is None:
        return -1
    return used / 1024 / 1024
def measure_gpu_usage(self) -> Optional[List[Dict[str, Any]]]:
    """Poll every ROCm device and report the peak memory reading per device.

    Each device index in ``range(len(rocm_smi.listDevices()))`` is sampled
    through ``self.get_used_memory``. At least one sample is always taken;
    sampling then repeats about every 5 ms until ``self.keep_measuring`` is
    false.

    Returns:
        ``None`` when ``rocm_smi`` is unavailable. Otherwise a list with one
        dict per device holding ``device_id`` (the index), ``name``
        (``"GPU<index>"``) and ``max_used_MB`` (the largest value returned
        by ``get_used_memory`` for that device, never below 0).
    """
    if self.rocm_smi is None:
        return None

    device_count = len(self.rocm_smi.listDevices())
    max_gpu_usage = [0] * device_count

    # Sample first and test the flag afterwards, so a monitor constructed
    # with keep_measuring=False still takes exactly one reading.
    while True:
        for i in range(device_count):
            max_gpu_usage[i] = max(max_gpu_usage[i], self.get_used_memory(i))
        time.sleep(0.005)  # 5ms
        if not self.keep_measuring:
            break

    return [
        {
            "device_id": i,
            "name": f"GPU{i}",
            "max_used_MB": peak,
        }
        for i, peak in enumerate(max_gpu_usage)
    ]
def measure_memory(is_gpu, func, monitor_type="cuda"):
memory_monitor_type = None
if monitor_type == "rocm":
memory_monitor_type = RocmMemoryMonitor
else:
memory_monitor_type = CudaMemoryMonitor
monitor = memory_monitor_type(False)
if is_gpu:
memory_before_test = monitor.measure_gpu_usage()
@ -508,7 +570,7 @@ def measure_memory(is_gpu, func):
return None
with ThreadPoolExecutor() as executor:
monitor = MemoryMonitor()
monitor = memory_monitor_type()
mem_thread = executor.submit(monitor.measure_gpu_usage)
try:
fn_thread = executor.submit(func)

View file

@ -11,7 +11,7 @@ import numpy as np
import psutil
import torch
import whisper
from benchmark_helper import setup_logger
from benchmark_helper import measure_memory, setup_logger
from onnxruntime_extensions import get_library_path
from optimum.onnxruntime import ORTModelForSpeechSeq2Seq
from torch.profiler import ProfilerActivity, profile, record_function
@ -19,7 +19,6 @@ from tqdm import trange
from transformers import AutoModelForSpeechSeq2Seq, WhisperConfig, WhisperProcessor
import onnxruntime as ort
from onnxruntime.transformers.benchmark_helper import measure_memory
logger = logging.getLogger(__name__)
@ -123,6 +122,9 @@ def get_model(args: argparse.Namespace):
if args.verbose:
sess_options.log_verbosity_level = 1
sess_options.log_severity_level = 1
if args.tune:
ort.set_default_logger_severity(0)
ort.set_default_logger_verbosity(0)
else:
raise Exception(f"Cannot recognize {args.benchmark_type}")
@ -159,6 +161,9 @@ def get_model(args: argparse.Namespace):
def time_fn(args, fn, inputs):
warmup_inputs = inputs[0] if type(inputs) is tuple else inputs
benchmark_inputs = inputs[1] if type(inputs) is tuple else inputs
# Warm up
warmup_range = (
range(args.warmup_runs)
@ -167,11 +172,11 @@ def time_fn(args, fn, inputs):
)
if args.verbose:
outputs = fn(inputs)
outputs = fn(warmup_inputs)
logger.info(outputs)
for _ in warmup_range:
fn(inputs)
fn(warmup_inputs)
# Benchmark
if args.device != "cpu":
@ -184,7 +189,7 @@ def time_fn(args, fn, inputs):
else trange(args.num_runs, file=sys.stdout, desc="Benchmark")
)
for _ in bench_range:
fn(inputs)
fn(benchmark_inputs)
if args.device != "cpu":
torch.cuda.synchronize()
@ -244,7 +249,7 @@ def measure_fn(args, fn, inputs):
# Measure memory usage
gc.collect()
torch.cuda.empty_cache()
measure_memory(is_gpu=(args.device != "cpu"), func=lambda: fn(inputs))
measure_memory(is_gpu=(args.device != "cpu"), func=lambda: fn(inputs), monitor_type=args.monitor_type)
# Flush output so memory usage is printed
sys.stdout.flush()
@ -255,7 +260,7 @@ def run_hf_inference(args, inputs, model):
def get_pred_ids(inputs):
# Inference pass with predicted token ids generation
predicted_ids = model.generate(**inputs)
return predicted_ids, [""]
return predicted_ids
def gen_and_dec(inputs):
# Inference pass with generation and decoding
@ -315,7 +320,7 @@ def run_hf_inference(args, inputs, model):
def run_ort_inference(args, inputs, model):
def prepare_ort_inputs(inputs):
def prepare_ort_inputs(inputs, warmup=False):
# Check that all model inputs will be provided
model_inputs = set(map(lambda model_input: model_input.name, model.get_inputs()))
user_inputs = set(inputs.keys())
@ -324,6 +329,9 @@ def run_ort_inference(args, inputs, model):
logger.error(f"The following model inputs are missing: {missing_inputs}")
raise Exception("There are missing inputs to the model. Please add them and try again.")
if warmup and args.tune:
inputs["min_length"] = inputs["max_length"]
# Remove unnecessary inputs from model inputs
unnecessary_inputs = user_inputs - model_inputs
if len(unnecessary_inputs):
@ -352,6 +360,13 @@ def run_ort_inference(args, inputs, model):
outputs = model.run(None, inputs)
return outputs
def handle_output(output):
    """Cut ``output`` off just after its first end-of-sequence token.

    The EOS id is read from ``args.eos_token_id`` in the surrounding scope.
    If that id does not occur in ``output``, ``output`` is returned unchanged;
    otherwise the result is the prefix up to and including the first match.
    """
    eos_id = args.eos_token_id
    if eos_id not in output:
        return output

    # np.where yields one index array per axis; take the first axis and its
    # first hit to locate the earliest EOS.
    eos_positions = np.where(output == eos_id)[0]
    cut = eos_positions[0] + 1
    return output[:cut]
generate_fn = with_io_binding if args.device != "cpu" else without_io_binding
ort_inputs = prepare_ort_inputs(inputs)
@ -367,7 +382,12 @@ def run_ort_inference(args, inputs, model):
# ORT evaluation
logger.info("\nEvaluating ONNX Runtime...")
time_fn(args, generate_fn, ort_inputs)
ort_evaluate_inputs = ort_inputs
if args.tune:
ort_warmup_inputs = prepare_ort_inputs(inputs, warmup=True)
ort_evaluate_inputs = (ort_warmup_inputs, ort_inputs)
time_fn(args, generate_fn, ort_evaluate_inputs)
ort_outputs = generate_fn(ort_inputs)
if args.device != "cpu":
ort_outputs = ort_outputs.copy_outputs_to_cpu()
@ -378,7 +398,10 @@ def run_ort_inference(args, inputs, model):
logger.info(f"Transcription: {ort_outputs[0][0]}")
else:
# convert_to_onnx model produces generated ids
logger.info(f"Generated token length: {len(ort_outputs[0][0])} tokens")
actual_output = handle_output(ort_outputs[0][0])
logger.info(f"Generated token length: {len(actual_output)} tokens")
transcription = args.processor.batch_decode(ort_outputs[0], skip_special_tokens=True)[0]
logger.info(f"Transcription: {transcription}")
measure_fn(args, generate_fn, ort_inputs)
@ -483,6 +506,12 @@ def parse_args():
parser.add_argument("--pt-num-rows", type=int, default=1000, help="Number of rows for PyTorch profiler to display")
parser.add_argument("--verbose", default=False, action="store_true")
parser.add_argument("--log-folder", type=str, default=os.path.join("."), help="Folder to cache log files")
parser.add_argument(
"--tune",
default=False,
action="store_true",
help="Only used by ROCm EP, enable TunableOp tuning to select fastest kernel",
)
args = parser.parse_args()
@ -490,13 +519,21 @@ def parse_args():
np.random.seed(args.seed)
torch.manual_seed(args.seed)
args.monitor_type = args.device
# Set runtime properties
if "ort" in args.benchmark_type:
args.execution_provider = f"{args.device.upper()}ExecutionProvider"
if args.execution_provider == "CUDAExecutionProvider":
args.execution_provider = (args.execution_provider, {"device_id": args.device_id})
elif args.execution_provider == "ROCMExecutionProvider":
args.execution_provider = (args.execution_provider, {"device_id": args.device_id})
args.execution_provider = (
args.execution_provider,
{
"device_id": args.device_id,
"tunable_op_enable": 1,
"tunable_op_tuning_enable": 1 if args.tune else 0,
},
)
args.device = "cuda"
# Check that model paths have been specified for any benchmarking with ORT
@ -527,6 +564,7 @@ def main():
setattr(args, "target_device", target_device) # noqa: B010
setattr(args, "use_fp16", use_fp16) # noqa: B010
setattr(args, "has_audio_stream", False) # noqa: B010
setattr(args, "eos_token_id", config.eos_token_id) # noqa: B010
logger.info(f"Forced decoder prompt ids: {args.decoder_input_ids}")

View file

@ -109,6 +109,8 @@ def get_args():
help="Number of mins to attempt the benchmark before moving on",
)
parser.add_argument("--tune", default=False, action="store_true")
args = parser.parse_args()
setattr(args, "model_size", args.model_name.split("/")[-1].replace(".", "-")) # noqa: B010
@ -292,6 +294,7 @@ def main():
ort_decoder_input_ids_cmd = (
["--decoder-input-ids", str(ort_forced_decoder_ids)] if args.language and args.task else []
)
ort_tune_cmd = ["--tune"] if args.tune else []
all_results = []
for audio_file in os.listdir(args.audio_path):
@ -395,31 +398,35 @@ def main():
# Benchmark ONNX Runtime
if args.ort_model_path:
benchmark_cmd = [ # noqa: RUF005
"python3",
"-m",
"models.whisper.benchmark",
"--audio-path",
audio_path,
"--benchmark-type",
"ort",
"--ort-model-path",
args.ort_model_path,
"--model-name",
args.model_name,
"--precision",
args.precision,
"--device",
args.device,
"--device-id",
str(args.device_id),
"--warmup-runs",
str(args.warmup_runs),
"--num-runs",
str(args.num_runs),
"--log-folder",
args.log_folder,
] + ort_decoder_input_ids_cmd
benchmark_cmd = (
[ # noqa: RUF005
"python3",
"-m",
"models.whisper.benchmark",
"--audio-path",
audio_path,
"--benchmark-type",
"ort",
"--ort-model-path",
args.ort_model_path,
"--model-name",
args.model_name,
"--precision",
args.precision,
"--device",
args.device,
"--device-id",
str(args.device_id),
"--warmup-runs",
str(args.warmup_runs),
"--num-runs",
str(args.num_runs),
"--log-folder",
args.log_folder,
]
+ ort_decoder_input_ids_cmd
+ ort_tune_cmd
)
logger.info("Benchmark ONNX Runtime")
results = benchmark(args, benchmark_cmd, "onnxruntime", audio_file, duration)
all_results.extend(results)