Update Whisper export with beam search (#19322)
### Description
This PR updates the Whisper export with beam search by adding the following:
- Fixes a bug when running `DecoderMaskedMultiHeadAttention` in the Whisper with beam search model
- Sets the default PyTorch attention implementation to `eager` so that the existing attention fusions continue to work (see the loading sketch below)
- Re-uses the cache directory when loading the PyTorch model to reduce the disk space used
- Adds `--disable_auto_mixed_precision` to the example FP16 export command

### Motivation and Context
- [This PR](https://github.com/microsoft/onnxruntime/pull/19112) added the `is_unidirectional` parameter to `CheckInputs`, but it was not provided when checking the inputs in `DecoderMaskedMultiHeadAttention`.
- [This PR](https://github.com/microsoft/onnxruntime/pull/19200) explains the reasoning behind using `eager` to load the `WhisperAttention` class.
- By re-using the cache directory when loading the PyTorch model, only one copy of the PyTorch model is saved on disk instead of two.
- With this flag, the Whisper with beam search model contains fewer Cast nodes switching between FP16 and FP32 precision.
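The `eager` and cache-directory changes reduce to the loading pattern used in the `WhisperHelper` diff further down. A minimal standalone sketch of that pattern is shown here; the model name and cache path are placeholders, so treat it as an illustration rather than the export script itself:

```python
from packaging import version
from transformers import WhisperForConditionalGeneration
from transformers import __version__ as transformers_version

# Placeholder values for illustration only.
model_name_or_path = "openai/whisper-tiny"
cache_dir = "./whisper_cache"

# On transformers >= 4.36.0, explicitly request the eager attention implementation
# so the exported graph still matches ONNX Runtime's attention fusion patterns.
extra_kwargs = {}
if version.parse(transformers_version) >= version.parse("4.36.0"):
    extra_kwargs["attn_implementation"] = "eager"

# Re-using the same cache_dir for every load keeps a single copy of the PyTorch
# weights on disk across export and parity verification.
model = WhisperForConditionalGeneration.from_pretrained(
    model_name_or_path, cache_dir=cache_dir, **extra_kwargs
)
```

The version gate exists because newer transformers releases can default to a non-eager (SDPA) attention implementation whose exported graph the existing fusions may not match; see the linked PR above for the full reasoning.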
This commit is contained in:
parent 3454f86e70
commit febec1c586
4 changed files with 18 additions and 5 deletions
Changes in the `DecoderMaskedMultiHeadAttention` CUDA kernel:

@@ -74,6 +74,7 @@ Status DecoderMaskedMultiHeadAttention<T1, T2>::ComputeInternal(OpKernelContext*
   parameters.kv_data_in_flight = ParseEnvironmentVariableWithDefault<bool>(
       attention::kDecoderMaskedAttentionLoadKVDataInFlight, false);
 
+  bool is_unidirectional = false;
   bool is_dmmha_packing = (key == nullptr && value == nullptr);
   ORT_RETURN_IF_ERROR(multihead_attention_helper::CheckInputs<Tensor>(query,
                                                                       key,
@@ -88,6 +89,7 @@ Status DecoderMaskedMultiHeadAttention<T1, T2>::ComputeInternal(OpKernelContext*
                                                                       num_heads_,
                                                                       mask_filter_value_,
                                                                       scale_,
+                                                                      is_unidirectional,
                                                                       past_present_share_buffer_,
                                                                       is_dmmha_packing,  // dmmha_packing
                                                                       device_prop.maxThreadsPerBlock));

Changes to the Whisper export example commands:

@@ -60,10 +60,10 @@
 Export + Optimize for FP16 and GPU
 ```
 # From source:
-$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --optimize_onnx --precision fp16 --use_gpu --provider cuda
+$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --optimize_onnx --precision fp16 --use_gpu --provider cuda --disable_auto_mixed_precision
 
 # From wheel:
-$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --optimize_onnx --precision fp16 --use_gpu --provider cuda
+$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --optimize_onnx --precision fp16 --use_gpu --provider cuda --disable_auto_mixed_precision
 ```
 
 Export + Quantize for INT8

Changes to the export script's `main()` parity check:

@@ -478,7 +478,7 @@ def main(argv=None):
     # Wrap parity check in try-except to allow export to continue in case this produces an error
     try:
         with torch.no_grad():
-            max_diff = WhisperHelper.verify_onnx(args.model_name_or_path, ort_session, device)
+            max_diff = WhisperHelper.verify_onnx(args.model_name_or_path, cache_dir, ort_session, device)
         if max_diff > 1e-4:
             logger.warning("PyTorch and ONNX Runtime results are NOT close")
         else:

Changes to `WhisperHelper` (model loading and parity verification):

@@ -12,7 +12,9 @@ from typing import Dict, Tuple, Union
 
 import numpy as np
 import torch
+from packaging import version
 from transformers import WhisperConfig, WhisperForConditionalGeneration, WhisperProcessor
+from transformers import __version__ as transformers_version
 from whisper_decoder import WhisperDecoder, WhisperDecoderHelper, WhisperDecoderInit
 from whisper_encoder import WhisperEncoder, WhisperEncoderHelper
 from whisper_encoder_decoder_init import WhisperEncoderDecoderInit, WhisperEncoderDecoderInitHelper
@@ -88,7 +90,10 @@ class WhisperHelper:
         Returns:
             Dict[str, torch.nn.Module]: mapping from name to modules for ONNX conversion.
         """
-        model = WhisperForConditionalGeneration.from_pretrained(model_name_or_path, cache_dir=cache_dir)
+        extra_kwargs = {}
+        if version.parse(transformers_version) >= version.parse("4.36.0"):
+            extra_kwargs["attn_implementation"] = "eager"
+        model = WhisperForConditionalGeneration.from_pretrained(model_name_or_path, cache_dir=cache_dir, **extra_kwargs)
         if state_dict_path:
             model.load_state_dict(torch.load(state_dict_path), strict=False)
 
@@ -262,11 +267,17 @@ class WhisperHelper:
     @staticmethod
     def verify_onnx(
         model_name_or_path: str,
+        cache_dir: str,
         ort_session: InferenceSession,
         device: torch.device,
     ):
         """Compare the result from PyTorch and ONNX Runtime to verify the ONNX model is good."""
-        pt_model = WhisperForConditionalGeneration.from_pretrained(model_name_or_path).to(device)
+        extra_kwargs = {}
+        if version.parse(transformers_version) >= version.parse("4.36.0"):
+            extra_kwargs["attn_implementation"] = "eager"
+        pt_model = WhisperForConditionalGeneration.from_pretrained(
+            model_name_or_path, cache_dir=cache_dir, **extra_kwargs
+        ).to(device)
         processor = WhisperProcessor.from_pretrained(model_name_or_path)
         config = WhisperConfig.from_pretrained(model_name_or_path)
 
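For reference, the updated `verify_onnx` signature can also be exercised outside `main()`. The sketch below is a hypothetical standalone parity check; the ONNX file name, cache path, and the import path of the helper module are assumptions, not part of this change:

```python
import torch
from onnxruntime import InferenceSession

# Assumes this script runs next to whisper_helper.py; adjust the import to your layout.
from whisper_helper import WhisperHelper

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Placeholder ONNX model produced by the beam search export; swap in your own path.
ort_session = InferenceSession(
    "whisper-tiny_beamsearch.onnx",
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
    if device.type == "cuda"
    else ["CPUExecutionProvider"],
)

# cache_dir is now passed through so the PyTorch reference model re-uses the weights
# already downloaded during export instead of fetching a second copy.
max_diff = WhisperHelper.verify_onnx("openai/whisper-tiny", "./whisper_cache", ort_session, device)
if max_diff > 1e-4:
    print("PyTorch and ONNX Runtime results are NOT close")
else:
    print("PyTorch and ONNX Runtime results are close")
```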