Update Whisper export with beam search (#19322)

### Description
This PR updates the Whisper export with beam search by adding the
following.

- Fixes a bug when running `DecoderMaskedMultiHeadAttention` in the
Whisper with beam search model
- Sets the default PyTorch attention implementation to `eager` to allow
existing attention fusions to continue working
- Re-uses the cache directory when loading the PyTorch model to reduce
memory used on disk
- Adds `--disable_auto_mixed_precision` to the example FP16 export
command

### Motivation and Context
- [This PR](https://github.com/microsoft/onnxruntime/pull/19112) added
the `is_unidirectional` parameter to `CheckInputs`, but it was not
provided when checking the inputs in `DecoderMaskedMultiHeadAttention`.
- [This PR](https://github.com/microsoft/onnxruntime/pull/19200)
explains the reasoning behind why `eager` is used to load the
`WhisperAttention` class.
- By re-using the cache directory for loading the PyTorch model, only
one copy of the PyTorch model is saved on disk instead of two copies.
- By providing this flag, there will be less Cast nodes in the Whisper
with beam search model to switch between FP16 and FP32 precision.
This commit is contained in:
kunal-vaishnavi 2024-01-30 11:59:15 -08:00 committed by GitHub
parent 3454f86e70
commit febec1c586
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 18 additions and 5 deletions

View file

@ -74,6 +74,7 @@ Status DecoderMaskedMultiHeadAttention<T1, T2>::ComputeInternal(OpKernelContext*
parameters.kv_data_in_flight = ParseEnvironmentVariableWithDefault<bool>(
attention::kDecoderMaskedAttentionLoadKVDataInFlight, false);
bool is_unidirectional = false;
bool is_dmmha_packing = (key == nullptr && value == nullptr);
ORT_RETURN_IF_ERROR(multihead_attention_helper::CheckInputs<Tensor>(query,
key,
@ -88,6 +89,7 @@ Status DecoderMaskedMultiHeadAttention<T1, T2>::ComputeInternal(OpKernelContext*
num_heads_,
mask_filter_value_,
scale_,
is_unidirectional,
past_present_share_buffer_,
is_dmmha_packing, // dmmha_packing
device_prop.maxThreadsPerBlock));

View file

@ -60,10 +60,10 @@ $ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/w
Export + Optimize for FP16 and GPU
```
# From source:
$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --optimize_onnx --precision fp16 --use_gpu --provider cuda
$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --optimize_onnx --precision fp16 --use_gpu --provider cuda --disable_auto_mixed_precision
# From wheel:
$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --optimize_onnx --precision fp16 --use_gpu --provider cuda
$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --optimize_onnx --precision fp16 --use_gpu --provider cuda --disable_auto_mixed_precision
```
Export + Quantize for INT8

View file

@ -478,7 +478,7 @@ def main(argv=None):
# Wrap parity check in try-except to allow export to continue in case this produces an error
try:
with torch.no_grad():
max_diff = WhisperHelper.verify_onnx(args.model_name_or_path, ort_session, device)
max_diff = WhisperHelper.verify_onnx(args.model_name_or_path, cache_dir, ort_session, device)
if max_diff > 1e-4:
logger.warning("PyTorch and ONNX Runtime results are NOT close")
else:

View file

@ -12,7 +12,9 @@ from typing import Dict, Tuple, Union
import numpy as np
import torch
from packaging import version
from transformers import WhisperConfig, WhisperForConditionalGeneration, WhisperProcessor
from transformers import __version__ as transformers_version
from whisper_decoder import WhisperDecoder, WhisperDecoderHelper, WhisperDecoderInit
from whisper_encoder import WhisperEncoder, WhisperEncoderHelper
from whisper_encoder_decoder_init import WhisperEncoderDecoderInit, WhisperEncoderDecoderInitHelper
@ -88,7 +90,10 @@ class WhisperHelper:
Returns:
Dict[str, torch.nn.Module]: mapping from name to modules for ONNX conversion.
"""
model = WhisperForConditionalGeneration.from_pretrained(model_name_or_path, cache_dir=cache_dir)
extra_kwargs = {}
if version.parse(transformers_version) >= version.parse("4.36.0"):
extra_kwargs["attn_implementation"] = "eager"
model = WhisperForConditionalGeneration.from_pretrained(model_name_or_path, cache_dir=cache_dir, **extra_kwargs)
if state_dict_path:
model.load_state_dict(torch.load(state_dict_path), strict=False)
@ -262,11 +267,17 @@ class WhisperHelper:
@staticmethod
def verify_onnx(
model_name_or_path: str,
cache_dir: str,
ort_session: InferenceSession,
device: torch.device,
):
"""Compare the result from PyTorch and ONNX Runtime to verify the ONNX model is good."""
pt_model = WhisperForConditionalGeneration.from_pretrained(model_name_or_path).to(device)
extra_kwargs = {}
if version.parse(transformers_version) >= version.parse("4.36.0"):
extra_kwargs["attn_implementation"] = "eager"
pt_model = WhisperForConditionalGeneration.from_pretrained(
model_name_or_path, cache_dir=cache_dir, **extra_kwargs
).to(device)
processor = WhisperProcessor.from_pretrained(model_name_or_path)
config = WhisperConfig.from_pretrained(model_name_or_path)