Stabilize Whisper export with beam search (#16297)

### Description
This PR stabilizes the Whisper export with beam search by making the
following changes:
- Remove unused ONNX models and extra folders generated during the
export process
- Specify the Whisper with beam search model's IR version for E2E
integration
- Add a parity check for the Whisper with beam search model between
PyTorch and ONNX Runtime (ORT)
- Remove previously exported Whisper with beam search model before
saving newly exported model


### Motivation and Context
- Removing the unused ONNX models and extra folders frees up disk space
after exporting and makes it easier to copy and move the output folder
to other environments.
- Specifying the IR version fixes an issue with generating the ONNX E2E
model.
- Adding a parity check helps detect runtime issues during the export
process.
- Removing the previously exported Whisper with beam search model
prevents the data file size from doubling when the newly exported model
is saved with the same filename.
This commit is contained in:
kunal-vaishnavi 2023-06-16 18:56:52 -07:00 committed by GitHub
parent dd660c054e
commit 3f7f90aed0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 116 additions and 36 deletions

View file

@ -26,8 +26,8 @@ ONNX_OPERATOR_KERNEL_EX(
.InputMemoryType(OrtMemTypeCPUInput, 6) // 'repetition_penalty' needs to be on CPU
.InputMemoryType(OrtMemTypeCPUInput, 9) // 'attention_mask' needs to be on CPU
.InputMemoryType(OrtMemTypeCPUInput, 10) // 'decoder_input_ids' needs to be on CPU
.OutputMemoryType(OrtMemTypeCPUOutput, 0) // 'sequences' output on CPU
.InputMemoryType(OrtMemTypeCPUInput, 11) // 'logits_processor' needs to be on CPU
.OutputMemoryType(OrtMemTypeCPUOutput, 0) // 'sequences' output on CPU
.OutputMemoryType(OrtMemTypeCPUOutput, 1) // 'sequences_scores' output on CPU
.TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),
DataTypeImpl::GetTensorType<MLFloat16>()}),

View file

@ -332,12 +332,7 @@ def export_onnx_models(
use_gpu=use_gpu,
provider=provider,
)
with torch.no_grad():
max_diff = WhisperHelper.verify_onnx(model, ort_session, device, use_int32_inputs)
logger.info(f"PyTorch and OnnxRuntime results max difference = {max_diff}")
if max_diff > 1e-4:
logger.warning("PyTorch and OnnxRuntime results are NOT close")
assert ort_session is not None
output_paths.append(output_path)
@ -398,6 +393,32 @@ def main(argv=None):
chain_model(args)
output_paths.append(args.beam_model_output_dir)
# Check chained model
ort_session = create_onnxruntime_session(
args.beam_model_output_dir,
use_gpu=args.use_gpu,
provider=["CUDAExecutionProvider", "CPUExecutionProvider"] if args.use_gpu else ["CPUExecutionProvider"],
)
device = torch.device("cuda:0" if args.use_gpu else "cpu")
# Wrap parity check in try-except to allow export to continue in case this produces an error
try:
with torch.no_grad():
max_diff = WhisperHelper.verify_onnx(args.model_name_or_path, ort_session, device)
logger.info(f"Max difference between PyTorch and ONNX Runtime token ids = {max_diff}")
if max_diff > 1e-4:
logger.warning("PyTorch and ONNX Runtime results are NOT close")
except Exception as e:
logger.warning(
f"An error occurred while trying to verify parity between PyTorch and ONNX Runtime: {e}", exc_info=True
)
# Remove extra ONNX models saved in output directory
for fle in os.listdir(output_dir):
if "_beamsearch" not in fle:
os.remove(os.path.join(output_dir, fle))
output_paths = [args.beam_model_output_dir]
logger.info(f"Done! Outputs: {output_paths}")

View file

@ -1,3 +1,4 @@
import logging
import os
import sys
@ -12,6 +13,8 @@ from convert_generation import ( # noqa: E402
update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha,
)
logger = logging.getLogger(__name__)
def chain_model(args):
# Load encoder/decoder and insert necessary (but unused) graph inputs expected by BeamSearch op
@ -33,15 +36,10 @@ def chain_model(args):
"repetition_penalty_fp16" if args.precision == Precision.FLOAT16 else "input_features",
"vocab_mask" if args.use_prefix_vocab_mask else "",
"prefix_vocab_mask" if args.use_prefix_vocab_mask else "",
"",
"", # attention mask
"decoder_input_ids" if args.use_forced_decoder_ids else "",
"logits_processor" if args.use_logits_processor else "",
]
if args.use_forced_decoder_ids:
beam_inputs.append("decoder_input_ids")
else:
beam_inputs.append("")
if args.use_logits_processor:
beam_inputs.append("logits_processor")
beam_outputs = ["sequences"]
input_features_cast_node, len_pen_cast_node, rep_pen_cast_node = None, None, None
@ -128,9 +126,9 @@ def chain_model(args):
if hasattr(args, "use_gpu") and args.use_gpu:
if update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha(decoder_model.graph):
print("*****Updated whisper decoder subgraph successfully!!!*****")
logger.info("Updated whisper decoder subgraph to use DecoderMaskedMultiHeadAttention successfully!")
else:
print("*****DecoderMaskedMultiHeadAttention is not applied to whisper decoder*****")
logger.warning("DecoderMaskedMultiHeadAttention could not be applied to whisper decoder subgraph")
# Initializers/opsets
# Delete shared data between decoder/encoder and move to larger graph initializers
@ -150,8 +148,21 @@ def chain_model(args):
else [node]
)
beam_graph = helper.make_graph(graph_nodes, "beam-search-test", graph_inputs, graph_outputs, initializers)
beam_model = helper.make_model(beam_graph, producer_name="onnxruntime.transformers", opset_imports=opset_import)
assert decoder_model.ir_version == encoder_model.ir_version
logger.info(f"Using IR version {decoder_model.ir_version} for chained model")
# Set IR version of chained model to IR version of subgraphs in order to generate a working E2E model
beam_model = helper.make_model_gen_version(
beam_graph,
producer_name="onnxruntime.transformers",
opset_imports=opset_import,
ir_version=decoder_model.ir_version,
)
if os.path.isfile(args.beam_model_output_dir):
logger.info(f"Overwriting {args.beam_model_output_dir} and {args.beam_model_output_dir + '.data'}")
os.remove(args.beam_model_output_dir)
os.remove(args.beam_model_output_dir + ".data")
onnx.save(
beam_model,
args.beam_model_output_dir,

View file

@ -28,20 +28,18 @@ logger = logging.getLogger(__name__)
class WhisperDecoderInit(torch.nn.Module):
"""A Whisper decoder with LM head to create initial past key values.
"""A Whisper decoder to create initial past key values.
This model is only called once during starting decoding.
"""
def __init__(
self,
decoder: torch.nn.Module,
lm_head: torch.nn.Module,
config: WhisperConfig,
decoder_start_token_id: int = None,
):
super().__init__()
self.decoder = decoder
self.lm_head = lm_head
self.config = config
self.decoder_start_token_id = (
decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
@ -70,12 +68,11 @@ class WhisperDecoderInit(torch.nn.Module):
class WhisperDecoder(torch.nn.Module):
"""A Whisper decoder with LM head and past key values"""
"""A Whisper decoder with past key values"""
def __init__(self, decoder, lm_head, config):
def __init__(self, decoder, config):
super().__init__()
self.decoder = decoder
self.lm_head = lm_head
self.config = config
def forward(self, decoder_input_ids, *past):

View file

@ -35,14 +35,13 @@ class WhisperEncoderDecoderInit(torch.nn.Module):
self,
encoder: torch.nn.Module,
decoder: torch.nn.Module,
lm_head: torch.nn.Module,
config: WhisperConfig,
decoder_start_token_id: Optional[int] = None,
):
super().__init__()
self.config = config
self.whisper_encoder = WhisperEncoder(encoder, config)
self.whisper_decoder_init = WhisperDecoderInit(decoder, lm_head, config, decoder_start_token_id)
self.whisper_decoder_init = WhisperDecoderInit(decoder, config, decoder_start_token_id)
def forward(
self,

View file

@ -10,8 +10,10 @@ import sys
from pathlib import Path
from typing import Dict, Tuple, Union
import numpy as np
import torch
from transformers import WhisperForConditionalGeneration
from datasets import load_dataset
from transformers import WhisperConfig, WhisperForConditionalGeneration, WhisperProcessor
from whisper_decoder import WhisperDecoder, WhisperDecoderHelper, WhisperDecoderInit
from whisper_encoder import WhisperEncoder, WhisperEncoderHelper
from whisper_encoder_decoder_init import WhisperEncoderDecoderInit, WhisperEncoderDecoderInitHelper
@ -62,7 +64,7 @@ class WhisperHelper:
if os.path.isdir(model_name_or_path):
model_name = Path(model_name_or_path).parts[-1]
else:
model_name.split("/")[-1]
model_name = model_name.split("/")[-1]
model_name += suffix
@ -91,14 +93,13 @@ class WhisperHelper:
if state_dict_path:
model.load_state_dict(torch.load(state_dict_path), strict=False)
decoder = WhisperDecoder(model, None, model.config)
decoder = WhisperDecoder(model, model.config)
decoder.eval().to(device)
if merge_encoder_and_decoder_init:
encoder_decoder_init = WhisperEncoderDecoderInit(
model,
model,
None,
model.config,
decoder_start_token_id=None,
)
@ -106,7 +107,7 @@ class WhisperHelper:
else:
encoder = WhisperEncoder(model.model.encoder, model.config)
encoder.eval().to(device)
decoder_init = WhisperDecoderInit(model.decoder, None, model.config)
decoder_init = WhisperDecoderInit(model.decoder, model.config)
decoder_init.eval().to(device)
return {
"encoder": encoder,
@ -261,11 +262,62 @@ class WhisperHelper:
@staticmethod
def verify_onnx(
model: Union[WhisperEncoder, WhisperDecoder, WhisperDecoderInit, WhisperEncoderDecoderInit],
model_name_or_path: str,
ort_session: InferenceSession,
device: torch.device,
use_int32_inputs: bool,
):
"""Compare the result from PyTorch and OnnxRuntime to verify the ONNX model is good."""
# Not implemented for Whisper currently
return 0
"""Compare the result from PyTorch and ONNX Runtime to verify the ONNX model is good."""
pt_model = WhisperForConditionalGeneration.from_pretrained(model_name_or_path).to(device)
processor = WhisperProcessor.from_pretrained(model_name_or_path)
config = WhisperConfig.from_pretrained(model_name_or_path)
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
input_features = processor([ds[0]["audio"]["array"]], return_tensors="pt").input_features
batch_size, max_length, min_length, num_beams, num_return_sequences = 1, 26, 0, 5, 1
length_penalty, repetition_penalty = 1.0, 1.0
inputs = {
"input_features": input_features.to(device),
"max_length": max_length,
"min_length": min_length,
"num_beams": num_beams,
"num_return_sequences": num_return_sequences,
"length_penalty": length_penalty,
"repetition_penalty": repetition_penalty,
"early_stopping": True,
"use_cache": True,
}
pt_outputs = pt_model.generate(**inputs).detach().cpu().numpy()
del inputs["early_stopping"]
del inputs["use_cache"]
ort_names = list(map(lambda entry: entry.name, ort_session.get_inputs()))
ort_dtypes = list(map(lambda entry: entry.type, ort_session.get_inputs()))
ort_to_np = {
"tensor(float)": np.float32,
"tensor(float16)": np.float16,
"tensor(int64)": np.int64,
"tensor(int32)": np.int32,
"tensor(int8)": np.int8,
"tensor(uint8)": np.uint8,
}
for name, dtype in zip(ort_names, ort_dtypes):
if name == "input_features":
inputs[name] = inputs[name].detach().cpu().numpy()
elif name == "vocab_mask":
inputs[name] = np.ones(config.vocab_size, dtype=ort_to_np[dtype])
elif name == "prefix_vocab_mask":
inputs[name] = np.ones((batch_size, config.vocab_size), dtype=ort_to_np[dtype])
elif name == "decoder_input_ids":
inputs[name] = np.array([[config.decoder_start_token_id, 50259, 50359, 50363]], dtype=ort_to_np[dtype])
elif name == "logits_processor":
inputs[name] = np.array([1], dtype=ort_to_np[dtype])
else:
inputs[name] = np.array([inputs[name]], dtype=ort_to_np[dtype])
ort_outputs = ort_session.run(None, inputs)[0][0]
if pt_outputs.shape != ort_outputs.shape:
logger.warning("PyTorch and ONNX Runtime outputs do not have the same shape")
diff = pt_outputs - ort_outputs
return max(diff.min(), diff.max(), key=abs)