From 3f7f90aed02a0d8d99c48fa89201759477794b8d Mon Sep 17 00:00:00 2001 From: kunal-vaishnavi <115581922+kunal-vaishnavi@users.noreply.github.com> Date: Fri, 16 Jun 2023 18:56:52 -0700 Subject: [PATCH] Stabilize Whisper export with beam search (#16297) ### Description This PR stabilizes the Whisper export with beam search by adding the following: - Remove unused ONNX models and extra folders generated during the export process - Specify the Whisper with beam search model's IR version for E2E integration - Parity check for Whisper with beam search model between PyTorch and ORT - Remove previously exported Whisper with beam search model before saving newly exported model ### Motivation and Context - Removing the unused ONNX models and extra folders frees up disk space after exporting and makes it easier to copy and move the output folder to other environments. - Specifying the IR version fixes an issue with generating the ONNX E2E model - Adding a parity check helps detect runtime issues during the export process - Removing the previously exported Whisper with beam search model prevents the data file size from doubling when the newly exported model is saved with the same filename --- .../cuda/transformers/beam_search.cc | 2 +- .../models/whisper/convert_to_onnx.py | 33 +++++++-- .../models/whisper/whisper_chain.py | 33 ++++++--- .../models/whisper/whisper_decoder.py | 9 +-- .../whisper/whisper_encoder_decoder_init.py | 3 +- .../models/whisper/whisper_helper.py | 72 ++++++++++++++++--- 6 files changed, 116 insertions(+), 36 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc b/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc index 759102b099..a3e86e34ae 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc @@ -26,8 +26,8 @@ ONNX_OPERATOR_KERNEL_EX( .InputMemoryType(OrtMemTypeCPUInput, 6) // 'repetition_penalty' needs to be on CPU 
.InputMemoryType(OrtMemTypeCPUInput, 9) // 'attention_mask' needs to be on CPU .InputMemoryType(OrtMemTypeCPUInput, 10) // 'decoder_input_ids' needs to be on CPU - .OutputMemoryType(OrtMemTypeCPUOutput, 0) // 'sequences' output on CPU .InputMemoryType(OrtMemTypeCPUInput, 11) // 'logits_processor' needs to be on CPU + .OutputMemoryType(OrtMemTypeCPUOutput, 0) // 'sequences' output on CPU .OutputMemoryType(OrtMemTypeCPUOutput, 1) // 'sequences_scores' output on CPU .TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(), DataTypeImpl::GetTensorType<MLFloat16>()}), diff --git a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py index 427a1df3c9..68e81408de 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py @@ -332,12 +332,7 @@ def export_onnx_models( use_gpu=use_gpu, provider=provider, ) - - with torch.no_grad(): - max_diff = WhisperHelper.verify_onnx(model, ort_session, device, use_int32_inputs) - logger.info(f"PyTorch and OnnxRuntime results max difference = {max_diff}") - if max_diff > 1e-4: - logger.warning("PyTorch and OnnxRuntime results are NOT close") + assert ort_session is not None output_paths.append(output_path) @@ -398,6 +393,32 @@ def main(argv=None): chain_model(args) output_paths.append(args.beam_model_output_dir) + # Check chained model + ort_session = create_onnxruntime_session( + args.beam_model_output_dir, + use_gpu=args.use_gpu, + provider=["CUDAExecutionProvider", "CPUExecutionProvider"] if args.use_gpu else ["CPUExecutionProvider"], + ) + device = torch.device("cuda:0" if args.use_gpu else "cpu") + + # Wrap parity check in try-except to allow export to continue in case this produces an error + try: + with torch.no_grad(): + max_diff = WhisperHelper.verify_onnx(args.model_name_or_path, ort_session, device) + logger.info(f"Max difference between PyTorch and 
ONNX Runtime token ids = {max_diff}") + if max_diff > 1e-4: + logger.warning("PyTorch and ONNX Runtime results are NOT close") + except Exception as e: + logger.warning( + f"An error occurred while trying to verify parity between PyTorch and ONNX Runtime: {e}", exc_info=True + ) + + # Remove extra ONNX models saved in output directory + for fle in os.listdir(output_dir): + if "_beamsearch" not in fle: + os.remove(os.path.join(output_dir, fle)) + output_paths = [args.beam_model_output_dir] + logger.info(f"Done! Outputs: {output_paths}") diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py index f0aa419c4c..ca7da098f2 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py @@ -1,3 +1,4 @@ +import logging import os import sys @@ -12,6 +13,8 @@ from convert_generation import ( # noqa: E402 update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha, ) +logger = logging.getLogger(__name__) + def chain_model(args): # Load encoder/decoder and insert necessary (but unused) graph inputs expected by BeamSearch op @@ -33,15 +36,10 @@ def chain_model(args): "repetition_penalty_fp16" if args.precision == Precision.FLOAT16 else "input_features", "vocab_mask" if args.use_prefix_vocab_mask else "", "prefix_vocab_mask" if args.use_prefix_vocab_mask else "", - "", + "", # attention mask + "decoder_input_ids" if args.use_forced_decoder_ids else "", + "logits_processor" if args.use_logits_processor else "", ] - if args.use_forced_decoder_ids: - beam_inputs.append("decoder_input_ids") - else: - beam_inputs.append("") - - if args.use_logits_processor: - beam_inputs.append("logits_processor") beam_outputs = ["sequences"] input_features_cast_node, len_pen_cast_node, rep_pen_cast_node = None, None, None @@ -128,9 +126,9 @@ def chain_model(args): if hasattr(args, "use_gpu") and 
args.use_gpu: if update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha(decoder_model.graph): - print("*****Updated whisper decoder subgraph successfully!!!*****") + logger.info("Updated whisper decoder subgraph to use DecoderMaskedMultiHeadAttention successfully!") else: - print("*****DecoderMaskedMultiHeadAttention is not applied to whisper decoder*****") + logger.warning("DecoderMaskedMultiHeadAttention could not be applied to whisper decoder subgraph") # Initializers/opsets # Delete shared data between decoder/encoder and move to larger graph initializers @@ -150,8 +148,21 @@ def chain_model(args): else [node] ) beam_graph = helper.make_graph(graph_nodes, "beam-search-test", graph_inputs, graph_outputs, initializers) - beam_model = helper.make_model(beam_graph, producer_name="onnxruntime.transformers", opset_imports=opset_import) + assert decoder_model.ir_version == encoder_model.ir_version + logger.info(f"Using IR version {decoder_model.ir_version} for chained model") + # Set IR version of chained model to IR version of subgraphs in order to generate a working E2E model + beam_model = helper.make_model_gen_version( + beam_graph, + producer_name="onnxruntime.transformers", + opset_imports=opset_import, + ir_version=decoder_model.ir_version, + ) + + if os.path.isfile(args.beam_model_output_dir): + logger.info(f"Overwriting {args.beam_model_output_dir} and {args.beam_model_output_dir + '.data'}") + os.remove(args.beam_model_output_dir) + os.remove(args.beam_model_output_dir + ".data") onnx.save( beam_model, args.beam_model_output_dir, diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py index 21073a6a2c..7d6d038ffa 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py @@ -28,20 +28,18 @@ logger = logging.getLogger(__name__) class 
WhisperDecoderInit(torch.nn.Module): - """A Whisper decoder with LM head to create initial past key values. + """A Whisper decoder to create initial past key values. This model is only called once during starting decoding. """ def __init__( self, decoder: torch.nn.Module, - lm_head: torch.nn.Module, config: WhisperConfig, decoder_start_token_id: int = None, ): super().__init__() self.decoder = decoder - self.lm_head = lm_head self.config = config self.decoder_start_token_id = ( decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id @@ -70,12 +68,11 @@ class WhisperDecoderInit(torch.nn.Module): class WhisperDecoder(torch.nn.Module): - """A Whisper decoder with LM head and past key values""" + """A Whisper decoder with past key values""" - def __init__(self, decoder, lm_head, config): + def __init__(self, decoder, config): super().__init__() self.decoder = decoder - self.lm_head = lm_head self.config = config def forward(self, decoder_input_ids, *past): diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py index 8ea507b4c6..094ddcebe3 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py @@ -35,14 +35,13 @@ class WhisperEncoderDecoderInit(torch.nn.Module): self, encoder: torch.nn.Module, decoder: torch.nn.Module, - lm_head: torch.nn.Module, config: WhisperConfig, decoder_start_token_id: Optional[int] = None, ): super().__init__() self.config = config self.whisper_encoder = WhisperEncoder(encoder, config) - self.whisper_decoder_init = WhisperDecoderInit(decoder, lm_head, config, decoder_start_token_id) + self.whisper_decoder_init = WhisperDecoderInit(decoder, config, decoder_start_token_id) def forward( self, diff --git 
a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py index 2dbf428a89..3037ded659 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py @@ -10,8 +10,10 @@ import sys from pathlib import Path from typing import Dict, Tuple, Union +import numpy as np import torch -from transformers import WhisperForConditionalGeneration +from datasets import load_dataset +from transformers import WhisperConfig, WhisperForConditionalGeneration, WhisperProcessor from whisper_decoder import WhisperDecoder, WhisperDecoderHelper, WhisperDecoderInit from whisper_encoder import WhisperEncoder, WhisperEncoderHelper from whisper_encoder_decoder_init import WhisperEncoderDecoderInit, WhisperEncoderDecoderInitHelper @@ -62,7 +64,7 @@ class WhisperHelper: if os.path.isdir(model_name_or_path): model_name = Path(model_name_or_path).parts[-1] else: - model_name.split("/")[-1] + model_name = model_name.split("/")[-1] model_name += suffix @@ -91,14 +93,13 @@ class WhisperHelper: if state_dict_path: model.load_state_dict(torch.load(state_dict_path), strict=False) - decoder = WhisperDecoder(model, None, model.config) + decoder = WhisperDecoder(model, model.config) decoder.eval().to(device) if merge_encoder_and_decoder_init: encoder_decoder_init = WhisperEncoderDecoderInit( model, model, - None, model.config, decoder_start_token_id=None, ) @@ -106,7 +107,7 @@ class WhisperHelper: else: encoder = WhisperEncoder(model.model.encoder, model.config) encoder.eval().to(device) - decoder_init = WhisperDecoderInit(model.decoder, None, model.config) + decoder_init = WhisperDecoderInit(model.decoder, model.config) decoder_init.eval().to(device) return { "encoder": encoder, @@ -261,11 +262,62 @@ class WhisperHelper: @staticmethod def verify_onnx( - model: Union[WhisperEncoder, WhisperDecoder, WhisperDecoderInit, 
WhisperEncoderDecoderInit], + model_name_or_path: str, ort_session: InferenceSession, device: torch.device, - use_int32_inputs: bool, ): - """Compare the result from PyTorch and OnnxRuntime to verify the ONNX model is good.""" - # Not implemented for Whisper currently - return 0 + """Compare the result from PyTorch and ONNX Runtime to verify the ONNX model is good.""" + pt_model = WhisperForConditionalGeneration.from_pretrained(model_name_or_path).to(device) + processor = WhisperProcessor.from_pretrained(model_name_or_path) + config = WhisperConfig.from_pretrained(model_name_or_path) + ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + input_features = processor([ds[0]["audio"]["array"]], return_tensors="pt").input_features + + batch_size, max_length, min_length, num_beams, num_return_sequences = 1, 26, 0, 5, 1 + length_penalty, repetition_penalty = 1.0, 1.0 + inputs = { + "input_features": input_features.to(device), + "max_length": max_length, + "min_length": min_length, + "num_beams": num_beams, + "num_return_sequences": num_return_sequences, + "length_penalty": length_penalty, + "repetition_penalty": repetition_penalty, + "early_stopping": True, + "use_cache": True, + } + pt_outputs = pt_model.generate(**inputs).detach().cpu().numpy() + + del inputs["early_stopping"] + del inputs["use_cache"] + ort_names = list(map(lambda entry: entry.name, ort_session.get_inputs())) + ort_dtypes = list(map(lambda entry: entry.type, ort_session.get_inputs())) + ort_to_np = { + "tensor(float)": np.float32, + "tensor(float16)": np.float16, + "tensor(int64)": np.int64, + "tensor(int32)": np.int32, + "tensor(int8)": np.int8, + "tensor(uint8)": np.uint8, + } + + for name, dtype in zip(ort_names, ort_dtypes): + if name == "input_features": + inputs[name] = inputs[name].detach().cpu().numpy() + elif name == "vocab_mask": + inputs[name] = np.ones(config.vocab_size, dtype=ort_to_np[dtype]) + elif name == "prefix_vocab_mask": + inputs[name] = 
np.ones((batch_size, config.vocab_size), dtype=ort_to_np[dtype]) + elif name == "decoder_input_ids": + inputs[name] = np.array([[config.decoder_start_token_id, 50259, 50359, 50363]], dtype=ort_to_np[dtype]) + elif name == "logits_processor": + inputs[name] = np.array([1], dtype=ort_to_np[dtype]) + else: + inputs[name] = np.array([inputs[name]], dtype=ort_to_np[dtype]) + ort_outputs = ort_session.run(None, inputs)[0][0] + + if pt_outputs.shape != ort_outputs.shape: + logger.warning("PyTorch and ONNX Runtime outputs do not have the same shape") + + diff = pt_outputs - ort_outputs + return max(diff.min(), diff.max(), key=abs)