Update Whisper export with beam search (#19322)

### Description
This PR makes the following changes to the Whisper export with beam search.

- Fixes a bug when running `DecoderMaskedMultiHeadAttention` in the
Whisper with beam search model
- Sets the default PyTorch attention implementation to `eager` to allow
existing attention fusions to continue working
- Re-uses the cache directory when loading the PyTorch model to reduce the
disk space used
- Adds `--disable_auto_mixed_precision` to the example FP16 export
command

### Motivation and Context
- [This PR](https://github.com/microsoft/onnxruntime/pull/19112) added
the `is_unidirectional` parameter to `CheckInputs`, but it was not
provided when checking the inputs in `DecoderMaskedMultiHeadAttention`.
- [This PR](https://github.com/microsoft/onnxruntime/pull/19200)
explains the reasoning behind why `eager` is used to load the
`WhisperAttention` class.
- By re-using the cache directory for loading the PyTorch model, only
one copy of the PyTorch model is saved on disk instead of two copies.
- By providing this flag, the Whisper with beam search model contains fewer
Cast nodes for switching between FP16 and FP32 precision.
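
In code, the attention-implementation and cache-directory changes amount to the following minimal sketch, mirroring the diff below (`load_whisper_for_export` is an illustrative helper name, not a function in the repo):

```python
# Minimal sketch (not the exact code in this PR): pass attn_implementation="eager"
# on transformers >= 4.36, where a non-eager attention implementation may otherwise
# be selected, and re-use cache_dir so the PyTorch weights are stored only once.
from packaging import version
from transformers import WhisperForConditionalGeneration
from transformers import __version__ as transformers_version


def load_whisper_for_export(model_name_or_path: str, cache_dir: str):
    # Illustrative helper name; the real logic lives in WhisperHelper.load_model.
    extra_kwargs = {}
    if version.parse(transformers_version) >= version.parse("4.36.0"):
        extra_kwargs["attn_implementation"] = "eager"
    return WhisperForConditionalGeneration.from_pretrained(
        model_name_or_path, cache_dir=cache_dir, **extra_kwargs
    )
```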


@@ -74,6 +74,7 @@ Status DecoderMaskedMultiHeadAttention<T1, T2>::ComputeInternal(OpKernelContext*
   parameters.kv_data_in_flight = ParseEnvironmentVariableWithDefault<bool>(
       attention::kDecoderMaskedAttentionLoadKVDataInFlight, false);
+  bool is_unidirectional = false;
   bool is_dmmha_packing = (key == nullptr && value == nullptr);
   ORT_RETURN_IF_ERROR(multihead_attention_helper::CheckInputs<Tensor>(query,
                                                                       key,
@@ -88,6 +89,7 @@ Status DecoderMaskedMultiHeadAttention<T1, T2>::ComputeInternal(OpKernelContext*
                                                                       num_heads_,
                                                                       mask_filter_value_,
                                                                       scale_,
+                                                                      is_unidirectional,
                                                                       past_present_share_buffer_,
                                                                       is_dmmha_packing,  // dmmha_packing
                                                                       device_prop.maxThreadsPerBlock));


@@ -60,10 +60,10 @@ $ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/w
 Export + Optimize for FP16 and GPU
 ```
 # From source:
-$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --optimize_onnx --precision fp16 --use_gpu --provider cuda
+$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --optimize_onnx --precision fp16 --use_gpu --provider cuda --disable_auto_mixed_precision
 # From wheel:
-$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --optimize_onnx --precision fp16 --use_gpu --provider cuda
+$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --optimize_onnx --precision fp16 --use_gpu --provider cuda --disable_auto_mixed_precision
 ```
 Export + Quantize for INT8


@@ -478,7 +478,7 @@ def main(argv=None):
         # Wrap parity check in try-except to allow export to continue in case this produces an error
         try:
             with torch.no_grad():
-                max_diff = WhisperHelper.verify_onnx(args.model_name_or_path, ort_session, device)
+                max_diff = WhisperHelper.verify_onnx(args.model_name_or_path, cache_dir, ort_session, device)
                 if max_diff > 1e-4:
                     logger.warning("PyTorch and ONNX Runtime results are NOT close")
                 else:


@@ -12,7 +12,9 @@ from typing import Dict, Tuple, Union
 import numpy as np
 import torch
+from packaging import version
 from transformers import WhisperConfig, WhisperForConditionalGeneration, WhisperProcessor
+from transformers import __version__ as transformers_version
 from whisper_decoder import WhisperDecoder, WhisperDecoderHelper, WhisperDecoderInit
 from whisper_encoder import WhisperEncoder, WhisperEncoderHelper
 from whisper_encoder_decoder_init import WhisperEncoderDecoderInit, WhisperEncoderDecoderInitHelper
@@ -88,7 +90,10 @@ class WhisperHelper:
         Returns:
             Dict[str, torch.nn.Module]: mapping from name to modules for ONNX conversion.
         """
-        model = WhisperForConditionalGeneration.from_pretrained(model_name_or_path, cache_dir=cache_dir)
+        extra_kwargs = {}
+        if version.parse(transformers_version) >= version.parse("4.36.0"):
+            extra_kwargs["attn_implementation"] = "eager"
+        model = WhisperForConditionalGeneration.from_pretrained(model_name_or_path, cache_dir=cache_dir, **extra_kwargs)

         if state_dict_path:
             model.load_state_dict(torch.load(state_dict_path), strict=False)
@@ -262,11 +267,17 @@ class WhisperHelper:
     @staticmethod
     def verify_onnx(
         model_name_or_path: str,
+        cache_dir: str,
         ort_session: InferenceSession,
         device: torch.device,
     ):
         """Compare the result from PyTorch and ONNX Runtime to verify the ONNX model is good."""
-        pt_model = WhisperForConditionalGeneration.from_pretrained(model_name_or_path).to(device)
+        extra_kwargs = {}
+        if version.parse(transformers_version) >= version.parse("4.36.0"):
+            extra_kwargs["attn_implementation"] = "eager"
+        pt_model = WhisperForConditionalGeneration.from_pretrained(
+            model_name_or_path, cache_dir=cache_dir, **extra_kwargs
+        ).to(device)
         processor = WhisperProcessor.from_pretrained(model_name_or_path)
         config = WhisperConfig.from_pretrained(model_name_or_path)
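
With the signature change above, callers now pass the cache directory explicitly. A hedged usage sketch (the ONNX model path and cache directory are placeholders, not outputs guaranteed by this PR):

```python
# Illustrative call site for the updated verify_onnx signature.
import torch
from onnxruntime import InferenceSession
from whisper_helper import WhisperHelper

ort_session = InferenceSession("whispertiny/model.onnx", providers=["CUDAExecutionProvider"])
max_diff = WhisperHelper.verify_onnx(
    "openai/whisper-tiny",  # model_name_or_path
    "./cache_dir",          # cache_dir (new argument; placeholder path)
    ort_session,
    torch.device("cuda:0"),
)
```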