diff --git a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py
index 288d79f624..2821f6b89b 100644
--- a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py
+++ b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py
@@ -316,7 +316,6 @@ def export_onnx_models(
                         use_external_data_format=use_external_data_format,
                         per_channel=quantize_per_channel,
                         reduce_range=quantize_reduce_range,
-                        optimize_model=False,
                         extra_options={"MatMulConstBOnly": True},
                     )
             else:
@@ -374,6 +373,7 @@ def main(argv=None):
         args.provider,
     )
 
+    max_diff = 0
     if args.chain_model:
         logger.info("Chaining model ... :")
         args.beam_model_output_dir = WhisperHelper.get_onnx_path(
@@ -418,6 +418,7 @@ def main(argv=None):
         output_paths = [args.beam_model_output_dir]
 
     logger.info(f"Done! Outputs: {output_paths}")
+    return max_diff
 
 
 if __name__ == "__main__":
diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py
index 5c41ce9bd3..3a81700a7f 100644
--- a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py
+++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py
@@ -12,7 +12,6 @@ from typing import Dict, Tuple, Union
 
 import numpy as np
 import torch
-from datasets import load_dataset
 from transformers import WhisperConfig, WhisperForConditionalGeneration, WhisperProcessor
 from whisper_decoder import WhisperDecoder, WhisperDecoderHelper, WhisperDecoderInit
 from whisper_encoder import WhisperEncoder, WhisperEncoderHelper
@@ -270,6 +269,30 @@ class WhisperHelper:
         pt_model = WhisperForConditionalGeneration.from_pretrained(model_name_or_path).to(device)
         processor = WhisperProcessor.from_pretrained(model_name_or_path)
         config = WhisperConfig.from_pretrained(model_name_or_path)
+
+        # Import `datasets` lazily so that importing this module does not require it,
+        # and install it on demand when it is missing.
+        try:
+            from datasets import load_dataset
+        except ImportError as e:
+            import importlib
+            import subprocess
+            import sys
+
+            logger.error(f"An error occurred while importing `datasets`: {e}", exc_info=True)
+            # Run pip through the current interpreter and without a shell, so the package is
+            # installed into this environment and a failed install raises instead of passing silently.
+            install_cmd = [sys.executable, "-m", "pip", "install", "datasets"]
+            logger.warning(
+                "Could not import `datasets`. Attempting to install `datasets` via "
+                f"`{' '.join(install_cmd)}`."
+            )
+            subprocess.run(install_cmd, check=True)
+            # Make the import system notice the freshly installed package.
+            importlib.invalidate_caches()
+
+            from datasets import load_dataset  # noqa: F811
+
         ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
         input_features = processor([ds[0]["audio"]["array"]], return_tensors="pt").input_features
 
diff --git a/onnxruntime/test/python/transformers/test_generation.py b/onnxruntime/test/python/transformers/test_generation.py
index 24536d7b43..55c5143582 100644
--- a/onnxruntime/test/python/transformers/test_generation.py
+++ b/onnxruntime/test/python/transformers/test_generation.py
@@ -6,6 +6,7 @@
 # --------------------------------------------------------------------------
 
 import os
+import shutil
 import unittest
 
 import onnx
@@ -19,10 +20,12 @@ if find_transformers_source() and find_transformers_source(["models", "t5"]):
     from benchmark_helper import Precision
     from convert_generation import main as run
     from models.t5.convert_to_onnx import export_onnx_models as export_t5_onnx_models
+    from models.whisper.convert_to_onnx import main as run_whisper
 else:
     from onnxruntime.transformers.benchmark_helper import Precision
     from onnxruntime.transformers.convert_generation import main as run
     from onnxruntime.transformers.models.t5.convert_to_onnx import export_onnx_models as export_t5_onnx_models
+    from onnxruntime.transformers.models.whisper.convert_to_onnx import main as run_whisper
 
 
 class TestBeamSearchGpt(unittest.TestCase):
@@ -281,5 +284,96 @@ class TestBeamSearchT5(unittest.TestCase):
         )
 
 
+class TestBeamSearchWhisper(unittest.TestCase):
+    """Test BeamSearch for Whisper"""
+
+    def setUp(self):
+        self.model_name = "openai/whisper-tiny"
+        self.pytorch_folder = "cache_models"
+        self.onnx_folder = "onnx_models"
+        self.decoder_onnx_path = os.path.join(".", self.onnx_folder, "whisper-tiny_decoder.onnx")
+        self.encoder_onnx_path = os.path.join(".", self.onnx_folder, "whisper-tiny_encoder_decoder_init.onnx")
+        self.beam_search_onnx_path = os.path.join(".", self.onnx_folder, "whisper-tiny_beamsearch.onnx")
+        self.enable_cuda = torch.cuda.is_available() and "CUDAExecutionProvider" in get_available_providers()
+
+        self.base_arguments = [
+            "-m",
+            self.model_name,
+            "--output",
+            self.onnx_folder,
+            "--use_external_data_format",
+        ]
+        self.fp32_cpu_arguments = [
+            "--precision",
+            "fp32",
+            "--optimize_onnx",
+        ]
+        self.fp16_cuda_arguments = [
+            "--precision",
+            "fp16",
+            "--provider",
+            "cuda",
+            "--optimize_onnx",
+            "--use_gpu",
+        ]
+        self.int8_cpu_arguments = [
+            "--precision",
+            "int8",
+            "--quantize_embedding_layer",
+        ]
+
+    def tearDown(self):
+        # Remove the model folders left behind by the export.
+        for folder in (self.pytorch_folder, self.onnx_folder):
+            folder_path = os.path.join(".", folder)
+            if os.path.exists(folder_path):
+                shutil.rmtree(folder_path)
+
+    def remove_onnx_files(self):
+        """Delete the exported ONNX models and their external data files, skipping any that are absent."""
+        for onnx_path in (self.beam_search_onnx_path, self.decoder_onnx_path, self.encoder_onnx_path):
+            # Check the model and its `.data` file separately: either one may be missing.
+            for path in (onnx_path, onnx_path + ".data"):
+                if os.path.exists(path):
+                    os.remove(path)
+
+    def run_export(self, arguments):
+        """Run the Whisper export with `arguments`, then check the output file and the returned parity value."""
+        max_diff = run_whisper(arguments)
+        self.assertTrue(os.path.exists(self.beam_search_onnx_path), "Whisper model was not exported")
+        self.remove_onnx_files()
+        self.assertEqual(max_diff, 0, f"ORT and PyTorch have a parity mismatch of {max_diff}")
+
+    def run_configs(self, optional_arguments):
+        """Run the export for every supported precision/provider combination with `optional_arguments` appended."""
+        # FP32 CPU
+        arguments = self.base_arguments + self.fp32_cpu_arguments + optional_arguments
+        self.run_export(arguments)
+
+        if self.enable_cuda:
+            # FP16 CUDA
+            arguments = self.base_arguments + self.fp16_cuda_arguments + optional_arguments
+            self.run_export(arguments)
+
+        # INT8 CPU
+        arguments = self.base_arguments + self.int8_cpu_arguments + optional_arguments
+        self.run_export(arguments)
+
+    @pytest.mark.slow
+    def test_required_args(self):
+        optional_args = []
+        self.run_configs(optional_args)
+
+    @pytest.mark.slow
+    def test_forced_decoder_ids(self):
+        decoder_input_ids = ["--use_forced_decoder_ids"]
+        self.run_configs(decoder_input_ids)
+
+    @pytest.mark.slow
+    def test_logits_processor(self):
+        logits_processor = ["--use_logits_processor"]
+        self.run_configs(logits_processor)
+
+
 if __name__ == "__main__":
     unittest.main()