mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-25 22:26:24 +00:00
Stabilize Whisper export with beam search (#16297)
### Description

This PR stabilizes the Whisper export with beam search by adding the following:

- Remove unused ONNX models and extra folders generated during the export process
- Specify the Whisper with beam search model's IR version for E2E integration
- Parity check for Whisper with beam search model between PyTorch and ORT
- Remove previously exported Whisper with beam search model before saving newly exported model

### Motivation and Context

- Removing the unused ONNX models and extra folders frees up disk space after exporting and makes it easier to copy and move the output folder to other environments.
- Specifying the IR version fixes an issue with generating the ONNX E2E model.
- Adding a parity check helps detect runtime issues during the export process.
- Removing the previously exported Whisper with beam search model prevents the data file size from doubling when the newly exported model is saved with the same filename.
This commit is contained in:
parent
dd660c054e
commit
3f7f90aed0
6 changed files with 116 additions and 36 deletions
|
|
@ -26,8 +26,8 @@ ONNX_OPERATOR_KERNEL_EX(
|
|||
.InputMemoryType(OrtMemTypeCPUInput, 6) // 'repetition_penalty' needs to be on CPU
|
||||
.InputMemoryType(OrtMemTypeCPUInput, 9) // 'attention_mask' needs to be on CPU
|
||||
.InputMemoryType(OrtMemTypeCPUInput, 10) // 'decoder_input_ids' needs to be on CPU
|
||||
.OutputMemoryType(OrtMemTypeCPUOutput, 0) // 'sequences' output on CPU
|
||||
.InputMemoryType(OrtMemTypeCPUInput, 11) // 'logits_processor' needs to be on CPU
|
||||
.OutputMemoryType(OrtMemTypeCPUOutput, 0) // 'sequences' output on CPU
|
||||
.OutputMemoryType(OrtMemTypeCPUOutput, 1) // 'sequences_scores' output on CPU
|
||||
.TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),
|
||||
DataTypeImpl::GetTensorType<MLFloat16>()}),
|
||||
|
|
|
|||
|
|
@ -332,12 +332,7 @@ def export_onnx_models(
|
|||
use_gpu=use_gpu,
|
||||
provider=provider,
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
max_diff = WhisperHelper.verify_onnx(model, ort_session, device, use_int32_inputs)
|
||||
logger.info(f"PyTorch and OnnxRuntime results max difference = {max_diff}")
|
||||
if max_diff > 1e-4:
|
||||
logger.warning("PyTorch and OnnxRuntime results are NOT close")
|
||||
assert ort_session is not None
|
||||
|
||||
output_paths.append(output_path)
|
||||
|
||||
|
|
@ -398,6 +393,32 @@ def main(argv=None):
|
|||
chain_model(args)
|
||||
output_paths.append(args.beam_model_output_dir)
|
||||
|
||||
# Check chained model
|
||||
ort_session = create_onnxruntime_session(
|
||||
args.beam_model_output_dir,
|
||||
use_gpu=args.use_gpu,
|
||||
provider=["CUDAExecutionProvider", "CPUExecutionProvider"] if args.use_gpu else ["CPUExecutionProvider"],
|
||||
)
|
||||
device = torch.device("cuda:0" if args.use_gpu else "cpu")
|
||||
|
||||
# Wrap parity check in try-except to allow export to continue in case this produces an error
|
||||
try:
|
||||
with torch.no_grad():
|
||||
max_diff = WhisperHelper.verify_onnx(args.model_name_or_path, ort_session, device)
|
||||
logger.info(f"Max difference between PyTorch and ONNX Runtime token ids = {max_diff}")
|
||||
if max_diff > 1e-4:
|
||||
logger.warning("PyTorch and ONNX Runtime results are NOT close")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"An error occurred while trying to verify parity between PyTorch and ONNX Runtime: {e}", exc_info=True
|
||||
)
|
||||
|
||||
# Remove extra ONNX models saved in output directory
|
||||
for fle in os.listdir(output_dir):
|
||||
if "_beamsearch" not in fle:
|
||||
os.remove(os.path.join(output_dir, fle))
|
||||
output_paths = [args.beam_model_output_dir]
|
||||
|
||||
logger.info(f"Done! Outputs: {output_paths}")
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
|
@ -12,6 +13,8 @@ from convert_generation import ( # noqa: E402
|
|||
update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def chain_model(args):
|
||||
# Load encoder/decoder and insert necessary (but unused) graph inputs expected by BeamSearch op
|
||||
|
|
@ -33,15 +36,10 @@ def chain_model(args):
|
|||
"repetition_penalty_fp16" if args.precision == Precision.FLOAT16 else "input_features",
|
||||
"vocab_mask" if args.use_prefix_vocab_mask else "",
|
||||
"prefix_vocab_mask" if args.use_prefix_vocab_mask else "",
|
||||
"",
|
||||
"", # attention mask
|
||||
"decoder_input_ids" if args.use_forced_decoder_ids else "",
|
||||
"logits_processor" if args.use_logits_processor else "",
|
||||
]
|
||||
if args.use_forced_decoder_ids:
|
||||
beam_inputs.append("decoder_input_ids")
|
||||
else:
|
||||
beam_inputs.append("")
|
||||
|
||||
if args.use_logits_processor:
|
||||
beam_inputs.append("logits_processor")
|
||||
beam_outputs = ["sequences"]
|
||||
|
||||
input_features_cast_node, len_pen_cast_node, rep_pen_cast_node = None, None, None
|
||||
|
|
@ -128,9 +126,9 @@ def chain_model(args):
|
|||
|
||||
if hasattr(args, "use_gpu") and args.use_gpu:
|
||||
if update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha(decoder_model.graph):
|
||||
print("*****Updated whisper decoder subgraph successfully!!!*****")
|
||||
logger.info("Updated whisper decoder subgraph to use DecoderMaskedMultiHeadAttention successfully!")
|
||||
else:
|
||||
print("*****DecoderMaskedMultiHeadAttention is not applied to whisper decoder*****")
|
||||
logger.warning("DecoderMaskedMultiHeadAttention could not be applied to whisper decoder subgraph")
|
||||
|
||||
# Initializers/opsets
|
||||
# Delete shared data between decoder/encoder and move to larger graph initializers
|
||||
|
|
@ -150,8 +148,21 @@ def chain_model(args):
|
|||
else [node]
|
||||
)
|
||||
beam_graph = helper.make_graph(graph_nodes, "beam-search-test", graph_inputs, graph_outputs, initializers)
|
||||
beam_model = helper.make_model(beam_graph, producer_name="onnxruntime.transformers", opset_imports=opset_import)
|
||||
assert decoder_model.ir_version == encoder_model.ir_version
|
||||
logger.info(f"Using IR version {decoder_model.ir_version} for chained model")
|
||||
|
||||
# Set IR version of chained model to IR version of subgraphs in order to generate a working E2E model
|
||||
beam_model = helper.make_model_gen_version(
|
||||
beam_graph,
|
||||
producer_name="onnxruntime.transformers",
|
||||
opset_imports=opset_import,
|
||||
ir_version=decoder_model.ir_version,
|
||||
)
|
||||
|
||||
if os.path.isfile(args.beam_model_output_dir):
|
||||
logger.info(f"Overwriting {args.beam_model_output_dir} and {args.beam_model_output_dir + '.data'}")
|
||||
os.remove(args.beam_model_output_dir)
|
||||
os.remove(args.beam_model_output_dir + ".data")
|
||||
onnx.save(
|
||||
beam_model,
|
||||
args.beam_model_output_dir,
|
||||
|
|
|
|||
|
|
@ -28,20 +28,18 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class WhisperDecoderInit(torch.nn.Module):
|
||||
"""A Whisper decoder with LM head to create initial past key values.
|
||||
"""A Whisper decoder to create initial past key values.
|
||||
This model is only called once during starting decoding.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
decoder: torch.nn.Module,
|
||||
lm_head: torch.nn.Module,
|
||||
config: WhisperConfig,
|
||||
decoder_start_token_id: int = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.decoder = decoder
|
||||
self.lm_head = lm_head
|
||||
self.config = config
|
||||
self.decoder_start_token_id = (
|
||||
decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
|
||||
|
|
@ -70,12 +68,11 @@ class WhisperDecoderInit(torch.nn.Module):
|
|||
|
||||
|
||||
class WhisperDecoder(torch.nn.Module):
|
||||
"""A Whisper decoder with LM head and past key values"""
|
||||
"""A Whisper decoder with past key values"""
|
||||
|
||||
def __init__(self, decoder, lm_head, config):
|
||||
def __init__(self, decoder, config):
|
||||
super().__init__()
|
||||
self.decoder = decoder
|
||||
self.lm_head = lm_head
|
||||
self.config = config
|
||||
|
||||
def forward(self, decoder_input_ids, *past):
|
||||
|
|
|
|||
|
|
@ -35,14 +35,13 @@ class WhisperEncoderDecoderInit(torch.nn.Module):
|
|||
self,
|
||||
encoder: torch.nn.Module,
|
||||
decoder: torch.nn.Module,
|
||||
lm_head: torch.nn.Module,
|
||||
config: WhisperConfig,
|
||||
decoder_start_token_id: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.whisper_encoder = WhisperEncoder(encoder, config)
|
||||
self.whisper_decoder_init = WhisperDecoderInit(decoder, lm_head, config, decoder_start_token_id)
|
||||
self.whisper_decoder_init = WhisperDecoderInit(decoder, config, decoder_start_token_id)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -10,8 +10,10 @@ import sys
|
|||
from pathlib import Path
|
||||
from typing import Dict, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import WhisperForConditionalGeneration
|
||||
from datasets import load_dataset
|
||||
from transformers import WhisperConfig, WhisperForConditionalGeneration, WhisperProcessor
|
||||
from whisper_decoder import WhisperDecoder, WhisperDecoderHelper, WhisperDecoderInit
|
||||
from whisper_encoder import WhisperEncoder, WhisperEncoderHelper
|
||||
from whisper_encoder_decoder_init import WhisperEncoderDecoderInit, WhisperEncoderDecoderInitHelper
|
||||
|
|
@ -62,7 +64,7 @@ class WhisperHelper:
|
|||
if os.path.isdir(model_name_or_path):
|
||||
model_name = Path(model_name_or_path).parts[-1]
|
||||
else:
|
||||
model_name.split("/")[-1]
|
||||
model_name = model_name.split("/")[-1]
|
||||
|
||||
model_name += suffix
|
||||
|
||||
|
|
@ -91,14 +93,13 @@ class WhisperHelper:
|
|||
if state_dict_path:
|
||||
model.load_state_dict(torch.load(state_dict_path), strict=False)
|
||||
|
||||
decoder = WhisperDecoder(model, None, model.config)
|
||||
decoder = WhisperDecoder(model, model.config)
|
||||
decoder.eval().to(device)
|
||||
|
||||
if merge_encoder_and_decoder_init:
|
||||
encoder_decoder_init = WhisperEncoderDecoderInit(
|
||||
model,
|
||||
model,
|
||||
None,
|
||||
model.config,
|
||||
decoder_start_token_id=None,
|
||||
)
|
||||
|
|
@ -106,7 +107,7 @@ class WhisperHelper:
|
|||
else:
|
||||
encoder = WhisperEncoder(model.model.encoder, model.config)
|
||||
encoder.eval().to(device)
|
||||
decoder_init = WhisperDecoderInit(model.decoder, None, model.config)
|
||||
decoder_init = WhisperDecoderInit(model.decoder, model.config)
|
||||
decoder_init.eval().to(device)
|
||||
return {
|
||||
"encoder": encoder,
|
||||
|
|
@ -261,11 +262,62 @@ class WhisperHelper:
|
|||
|
||||
@staticmethod
|
||||
def verify_onnx(
|
||||
model: Union[WhisperEncoder, WhisperDecoder, WhisperDecoderInit, WhisperEncoderDecoderInit],
|
||||
model_name_or_path: str,
|
||||
ort_session: InferenceSession,
|
||||
device: torch.device,
|
||||
use_int32_inputs: bool,
|
||||
):
|
||||
"""Compare the result from PyTorch and OnnxRuntime to verify the ONNX model is good."""
|
||||
# Not implemented for Whisper currently
|
||||
return 0
|
||||
"""Compare the result from PyTorch and ONNX Runtime to verify the ONNX model is good."""
|
||||
pt_model = WhisperForConditionalGeneration.from_pretrained(model_name_or_path).to(device)
|
||||
processor = WhisperProcessor.from_pretrained(model_name_or_path)
|
||||
config = WhisperConfig.from_pretrained(model_name_or_path)
|
||||
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||
input_features = processor([ds[0]["audio"]["array"]], return_tensors="pt").input_features
|
||||
|
||||
batch_size, max_length, min_length, num_beams, num_return_sequences = 1, 26, 0, 5, 1
|
||||
length_penalty, repetition_penalty = 1.0, 1.0
|
||||
inputs = {
|
||||
"input_features": input_features.to(device),
|
||||
"max_length": max_length,
|
||||
"min_length": min_length,
|
||||
"num_beams": num_beams,
|
||||
"num_return_sequences": num_return_sequences,
|
||||
"length_penalty": length_penalty,
|
||||
"repetition_penalty": repetition_penalty,
|
||||
"early_stopping": True,
|
||||
"use_cache": True,
|
||||
}
|
||||
pt_outputs = pt_model.generate(**inputs).detach().cpu().numpy()
|
||||
|
||||
del inputs["early_stopping"]
|
||||
del inputs["use_cache"]
|
||||
ort_names = list(map(lambda entry: entry.name, ort_session.get_inputs()))
|
||||
ort_dtypes = list(map(lambda entry: entry.type, ort_session.get_inputs()))
|
||||
ort_to_np = {
|
||||
"tensor(float)": np.float32,
|
||||
"tensor(float16)": np.float16,
|
||||
"tensor(int64)": np.int64,
|
||||
"tensor(int32)": np.int32,
|
||||
"tensor(int8)": np.int8,
|
||||
"tensor(uint8)": np.uint8,
|
||||
}
|
||||
|
||||
for name, dtype in zip(ort_names, ort_dtypes):
|
||||
if name == "input_features":
|
||||
inputs[name] = inputs[name].detach().cpu().numpy()
|
||||
elif name == "vocab_mask":
|
||||
inputs[name] = np.ones(config.vocab_size, dtype=ort_to_np[dtype])
|
||||
elif name == "prefix_vocab_mask":
|
||||
inputs[name] = np.ones((batch_size, config.vocab_size), dtype=ort_to_np[dtype])
|
||||
elif name == "decoder_input_ids":
|
||||
inputs[name] = np.array([[config.decoder_start_token_id, 50259, 50359, 50363]], dtype=ort_to_np[dtype])
|
||||
elif name == "logits_processor":
|
||||
inputs[name] = np.array([1], dtype=ort_to_np[dtype])
|
||||
else:
|
||||
inputs[name] = np.array([inputs[name]], dtype=ort_to_np[dtype])
|
||||
ort_outputs = ort_session.run(None, inputs)[0][0]
|
||||
|
||||
if pt_outputs.shape != ort_outputs.shape:
|
||||
logger.warning("PyTorch and ONNX Runtime outputs do not have the same shape")
|
||||
|
||||
diff = pt_outputs - ort_outputs
|
||||
return max(diff.min(), diff.max(), key=abs)
|
||||
|
|
|
|||
Loading…
Reference in a new issue