Add Whisper export with beam search test cases (#17228)

### Description
This PR adds test cases for the custom export of [Whisper with beam
search](https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/python/tools/transformers/models/whisper).



### Motivation and Context
This PR checks that Whisper can be exported and runs with parity.
This commit is contained in:
kunal-vaishnavi 2023-08-20 00:58:08 -07:00 committed by GitHub
parent 9445539e2c
commit 4bea5ec513
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 112 additions and 2 deletions

View file

@ -316,7 +316,6 @@ def export_onnx_models(
use_external_data_format=use_external_data_format,
per_channel=quantize_per_channel,
reduce_range=quantize_reduce_range,
optimize_model=False,
extra_options={"MatMulConstBOnly": True},
)
else:
@ -374,6 +373,7 @@ def main(argv=None):
args.provider,
)
max_diff = 0
if args.chain_model:
logger.info("Chaining model ... :")
args.beam_model_output_dir = WhisperHelper.get_onnx_path(
@ -418,6 +418,7 @@ def main(argv=None):
output_paths = [args.beam_model_output_dir]
logger.info(f"Done! Outputs: {output_paths}")
return max_diff
if __name__ == "__main__":

View file

@ -12,7 +12,6 @@ from typing import Dict, Tuple, Union
import numpy as np
import torch
from datasets import load_dataset
from transformers import WhisperConfig, WhisperForConditionalGeneration, WhisperProcessor
from whisper_decoder import WhisperDecoder, WhisperDecoderHelper, WhisperDecoderInit
from whisper_encoder import WhisperEncoder, WhisperEncoderHelper
@ -270,6 +269,18 @@ class WhisperHelper:
pt_model = WhisperForConditionalGeneration.from_pretrained(model_name_or_path).to(device)
processor = WhisperProcessor.from_pretrained(model_name_or_path)
config = WhisperConfig.from_pretrained(model_name_or_path)
# Try to import `datasets` pip package
try:
from datasets import load_dataset
except Exception as e:
logger.error(f"An error occurred while importing `datasets`: {e}", exc_info=True)
install_cmd = "pip install datasets"
logger.warning(f"Could not import `datasets`. Attempting to install `datasets` via `{install_cmd}`.")
os.system(install_cmd)
from datasets import load_dataset # noqa: F811
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
input_features = processor([ds[0]["audio"]["array"]], return_tensors="pt").input_features

View file

@ -6,6 +6,7 @@
# --------------------------------------------------------------------------
import os
import shutil
import unittest
import onnx
@ -19,10 +20,12 @@ if find_transformers_source() and find_transformers_source(["models", "t5"]):
from benchmark_helper import Precision
from convert_generation import main as run
from models.t5.convert_to_onnx import export_onnx_models as export_t5_onnx_models
from models.whisper.convert_to_onnx import main as run_whisper
else:
from onnxruntime.transformers.benchmark_helper import Precision
from onnxruntime.transformers.convert_generation import main as run
from onnxruntime.transformers.models.t5.convert_to_onnx import export_onnx_models as export_t5_onnx_models
from onnxruntime.transformers.models.whisper.convert_to_onnx import main as run_whisper
class TestBeamSearchGpt(unittest.TestCase):
@ -281,5 +284,100 @@ class TestBeamSearchT5(unittest.TestCase):
)
class TestBeamSearchWhisper(unittest.TestCase):
"""Test BeamSearch for Whisper"""
def setUp(self):
self.model_name = "openai/whisper-tiny"
self.pytorch_folder = "cache_models"
self.onnx_folder = "onnx_models"
self.decoder_onnx_path = os.path.join(".", self.onnx_folder, "whisper-tiny_decoder.onnx")
self.encoder_onnx_path = os.path.join(".", self.onnx_folder, "whisper-tiny_encoder_decoder_init.onnx")
self.beam_search_onnx_path = os.path.join(".", self.onnx_folder, "whisper-tiny_beamsearch.onnx")
self.enable_cuda = torch.cuda.is_available() and "CUDAExecutionProvider" in get_available_providers()
self.base_arguments = [
"-m",
self.model_name,
"--output",
self.onnx_folder,
"--use_external_data_format",
]
self.fp32_cpu_arguments = [
"--precision",
"fp32",
"--optimize_onnx",
]
self.fp16_cuda_arguments = [
"--precision",
"fp16",
"--provider",
"cuda",
"--optimize_onnx",
"--use_gpu",
]
self.int8_cpu_arguments = [
"--precision",
"int8",
"--quantize_embedding_layer",
]
def tearDown(self):
pytorch_dir = os.path.join(".", self.pytorch_folder)
if os.path.exists(pytorch_dir):
shutil.rmtree(pytorch_dir)
onnx_dir = os.path.join(".", self.onnx_folder)
if os.path.exists(onnx_dir):
shutil.rmtree(onnx_dir)
def remove_onnx_files(self):
if os.path.exists(self.beam_search_onnx_path):
os.remove(self.beam_search_onnx_path)
os.remove(self.beam_search_onnx_path + ".data")
if os.path.exists(self.decoder_onnx_path):
os.remove(self.decoder_onnx_path)
os.remove(self.decoder_onnx_path + ".data")
if os.path.exists(self.encoder_onnx_path):
os.remove(self.encoder_onnx_path)
os.remove(self.encoder_onnx_path + ".data")
def run_export(self, arguments):
max_diff = run_whisper(arguments)
self.assertTrue(os.path.exists(self.beam_search_onnx_path), "Whisper model was not exported")
self.remove_onnx_files()
self.assertTrue(max_diff == 0, f"ORT and PyTorch have a parity mismatch of {max_diff}")
def run_configs(self, optional_arguments):
# FP32 CPU
arguments = self.base_arguments + self.fp32_cpu_arguments + optional_arguments
self.run_export(arguments)
if self.enable_cuda:
# FP16 CUDA
arguments = self.base_arguments + self.fp16_cuda_arguments + optional_arguments
self.run_export(arguments)
# INT8 CPU
arguments = self.base_arguments + self.int8_cpu_arguments + optional_arguments
self.run_export(arguments)
@pytest.mark.slow
def test_required_args(self):
optional_args = []
self.run_configs(optional_args)
@pytest.mark.slow
def test_forced_decoder_ids(self):
decoder_input_ids = ["--use_forced_decoder_ids"]
self.run_configs(decoder_input_ids)
@pytest.mark.slow
def test_logits_processor(self):
logits_processor = ["--use_logits_processor"]
self.run_configs(logits_processor)
if __name__ == "__main__":
unittest.main()