mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-30 23:18:20 +00:00
Add Whisper export with beam search test cases (#17228)
### Description
This PR adds test cases for the custom export of [Whisper with beam search](https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/python/tools/transformers/models/whisper).

### Motivation and Context
These tests check that Whisper can be exported and that the exported model runs with parity between ONNX Runtime and PyTorch.
This commit is contained in:
parent
9445539e2c
commit
4bea5ec513
3 changed files with 112 additions and 2 deletions
|
|
@ -316,7 +316,6 @@ def export_onnx_models(
|
|||
use_external_data_format=use_external_data_format,
|
||||
per_channel=quantize_per_channel,
|
||||
reduce_range=quantize_reduce_range,
|
||||
optimize_model=False,
|
||||
extra_options={"MatMulConstBOnly": True},
|
||||
)
|
||||
else:
|
||||
|
|
@ -374,6 +373,7 @@ def main(argv=None):
|
|||
args.provider,
|
||||
)
|
||||
|
||||
max_diff = 0
|
||||
if args.chain_model:
|
||||
logger.info("Chaining model ... :")
|
||||
args.beam_model_output_dir = WhisperHelper.get_onnx_path(
|
||||
|
|
@ -418,6 +418,7 @@ def main(argv=None):
|
|||
output_paths = [args.beam_model_output_dir]
|
||||
|
||||
logger.info(f"Done! Outputs: {output_paths}")
|
||||
return max_diff
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -12,7 +12,6 @@ from typing import Dict, Tuple, Union
|
|||
|
||||
import numpy as np
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from transformers import WhisperConfig, WhisperForConditionalGeneration, WhisperProcessor
|
||||
from whisper_decoder import WhisperDecoder, WhisperDecoderHelper, WhisperDecoderInit
|
||||
from whisper_encoder import WhisperEncoder, WhisperEncoderHelper
|
||||
|
|
@ -270,6 +269,18 @@ class WhisperHelper:
|
|||
pt_model = WhisperForConditionalGeneration.from_pretrained(model_name_or_path).to(device)
|
||||
processor = WhisperProcessor.from_pretrained(model_name_or_path)
|
||||
config = WhisperConfig.from_pretrained(model_name_or_path)
|
||||
|
||||
# Try to import `datasets` pip package
|
||||
try:
|
||||
from datasets import load_dataset
|
||||
except Exception as e:
|
||||
logger.error(f"An error occurred while importing `datasets`: {e}", exc_info=True)
|
||||
install_cmd = "pip install datasets"
|
||||
logger.warning(f"Could not import `datasets`. Attempting to install `datasets` via `{install_cmd}`.")
|
||||
os.system(install_cmd)
|
||||
|
||||
from datasets import load_dataset # noqa: F811
|
||||
|
||||
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||
input_features = processor([ds[0]["audio"]["array"]], return_tensors="pt").input_features
|
||||
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@
|
|||
# --------------------------------------------------------------------------
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import unittest
|
||||
|
||||
import onnx
|
||||
|
|
@ -19,10 +20,12 @@ if find_transformers_source() and find_transformers_source(["models", "t5"]):
|
|||
from benchmark_helper import Precision
|
||||
from convert_generation import main as run
|
||||
from models.t5.convert_to_onnx import export_onnx_models as export_t5_onnx_models
|
||||
from models.whisper.convert_to_onnx import main as run_whisper
|
||||
else:
|
||||
from onnxruntime.transformers.benchmark_helper import Precision
|
||||
from onnxruntime.transformers.convert_generation import main as run
|
||||
from onnxruntime.transformers.models.t5.convert_to_onnx import export_onnx_models as export_t5_onnx_models
|
||||
from onnxruntime.transformers.models.whisper.convert_to_onnx import main as run_whisper
|
||||
|
||||
|
||||
class TestBeamSearchGpt(unittest.TestCase):
|
||||
|
|
@ -281,5 +284,100 @@ class TestBeamSearchT5(unittest.TestCase):
|
|||
)
|
||||
|
||||
|
||||
class TestBeamSearchWhisper(unittest.TestCase):
    """Test BeamSearch for Whisper.

    Each test calls the Whisper export entry point (``run_whisper``) with a list of
    command-line arguments, checks that the beam search ONNX file was written, and
    checks that the value returned by the export is zero. That value is reported in
    the failure message as the parity mismatch between ORT and PyTorch.
    """

    def setUp(self):
        """Set up model/output locations and the argument lists for each configuration."""
        self.model_name = "openai/whisper-tiny"
        self.pytorch_folder = "cache_models"
        self.onnx_folder = "onnx_models"
        self.decoder_onnx_path = os.path.join(".", self.onnx_folder, "whisper-tiny_decoder.onnx")
        self.encoder_onnx_path = os.path.join(".", self.onnx_folder, "whisper-tiny_encoder_decoder_init.onnx")
        self.beam_search_onnx_path = os.path.join(".", self.onnx_folder, "whisper-tiny_beamsearch.onnx")
        # The CUDA configuration is only exercised when both PyTorch and ORT can use the GPU.
        self.enable_cuda = torch.cuda.is_available() and "CUDAExecutionProvider" in get_available_providers()

        # Arguments shared by every configuration.
        self.base_arguments = [
            "-m",
            self.model_name,
            "--output",
            self.onnx_folder,
            "--use_external_data_format",
        ]
        self.fp32_cpu_arguments = [
            "--precision",
            "fp32",
            "--optimize_onnx",
        ]
        self.fp16_cuda_arguments = [
            "--precision",
            "fp16",
            "--provider",
            "cuda",
            "--optimize_onnx",
            "--use_gpu",
        ]
        self.int8_cpu_arguments = [
            "--precision",
            "int8",
            "--quantize_embedding_layer",
        ]

    def tearDown(self):
        """Remove the PyTorch cache folder and the ONNX output folder if they exist."""
        pytorch_dir = os.path.join(".", self.pytorch_folder)
        if os.path.exists(pytorch_dir):
            shutil.rmtree(pytorch_dir)
        onnx_dir = os.path.join(".", self.onnx_folder)
        if os.path.exists(onnx_dir):
            shutil.rmtree(onnx_dir)

    def remove_onnx_files(self):
        """Delete the exported ONNX models and their ``.data`` external-data files.

        Each file is removed only if it is present. A model saved without a ``.data``
        sidecar therefore does not raise ``FileNotFoundError``, and a leftover sidecar
        whose model is already gone is still cleaned up.
        """
        onnx_paths = (self.beam_search_onnx_path, self.decoder_onnx_path, self.encoder_onnx_path)
        for onnx_path in onnx_paths:
            for path in (onnx_path, onnx_path + ".data"):
                if os.path.exists(path):
                    os.remove(path)

    def run_export(self, arguments):
        """Run the Whisper export with ``arguments`` and verify the result.

        Fails if the beam search model file is missing after the export, or if the
        value returned by the export is not zero.
        """
        max_diff = run_whisper(arguments)
        self.assertTrue(os.path.exists(self.beam_search_onnx_path), "Whisper model was not exported")
        # Clean up before the next configuration so a stale file cannot satisfy its existence check.
        self.remove_onnx_files()
        self.assertEqual(max_diff, 0, f"ORT and PyTorch have a parity mismatch of {max_diff}")

    def run_configs(self, optional_arguments):
        """Run the export for every precision/provider configuration.

        ``optional_arguments`` is appended to the arguments of each configuration.
        FP32 CPU and INT8 CPU always run; FP16 CUDA runs only when ``self.enable_cuda`` is set.
        """
        # FP32 CPU
        arguments = self.base_arguments + self.fp32_cpu_arguments + optional_arguments
        self.run_export(arguments)

        if self.enable_cuda:
            # FP16 CUDA
            arguments = self.base_arguments + self.fp16_cuda_arguments + optional_arguments
            self.run_export(arguments)

        # INT8 CPU
        arguments = self.base_arguments + self.int8_cpu_arguments + optional_arguments
        self.run_export(arguments)

    @pytest.mark.slow
    def test_required_args(self):
        """Export with only the required arguments."""
        optional_args = []
        self.run_configs(optional_args)

    @pytest.mark.slow
    def test_forced_decoder_ids(self):
        """Export with ``--use_forced_decoder_ids``."""
        decoder_input_ids = ["--use_forced_decoder_ids"]
        self.run_configs(decoder_input_ids)

    @pytest.mark.slow
    def test_logits_processor(self):
        """Export with ``--use_logits_processor``."""
        logits_processor = ["--use_logits_processor"]
        self.run_configs(logits_processor)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
|||
Loading…
Reference in a new issue