mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
Update Whisper export with beam search (#19322)
### Description This PR updates the Whisper export with beam search by adding the following. - Fixes a bug when running `DecoderMaskedMultiHeadAttention` in the Whisper with beam search model - Sets the default PyTorch attention implementation to `eager` to allow existing attention fusions to continue working - Re-uses the cache directory when loading the PyTorch model to reduce memory used on disk - Adds `--disable_auto_mixed_precision` to the example FP16 export command ### Motivation and Context - [This PR](https://github.com/microsoft/onnxruntime/pull/19112) added the `is_unidirectional` parameter to `CheckInputs`, but it was not provided when checking the inputs in `DecoderMaskedMultiHeadAttention`. - [This PR](https://github.com/microsoft/onnxruntime/pull/19200) explains the reasoning behind why `eager` is used to load the `WhisperAttention` class. - By re-using the cache directory for loading the PyTorch model, only one copy of the PyTorch model is saved on disk instead of two copies. - By providing this flag, there will be less Cast nodes in the Whisper with beam search model to switch between FP16 and FP32 precision.
This commit is contained in:
parent
3454f86e70
commit
febec1c586
4 changed files with 18 additions and 5 deletions
|
|
@ -74,6 +74,7 @@ Status DecoderMaskedMultiHeadAttention<T1, T2>::ComputeInternal(OpKernelContext*
|
||||||
parameters.kv_data_in_flight = ParseEnvironmentVariableWithDefault<bool>(
|
parameters.kv_data_in_flight = ParseEnvironmentVariableWithDefault<bool>(
|
||||||
attention::kDecoderMaskedAttentionLoadKVDataInFlight, false);
|
attention::kDecoderMaskedAttentionLoadKVDataInFlight, false);
|
||||||
|
|
||||||
|
bool is_unidirectional = false;
|
||||||
bool is_dmmha_packing = (key == nullptr && value == nullptr);
|
bool is_dmmha_packing = (key == nullptr && value == nullptr);
|
||||||
ORT_RETURN_IF_ERROR(multihead_attention_helper::CheckInputs<Tensor>(query,
|
ORT_RETURN_IF_ERROR(multihead_attention_helper::CheckInputs<Tensor>(query,
|
||||||
key,
|
key,
|
||||||
|
|
@ -88,6 +89,7 @@ Status DecoderMaskedMultiHeadAttention<T1, T2>::ComputeInternal(OpKernelContext*
|
||||||
num_heads_,
|
num_heads_,
|
||||||
mask_filter_value_,
|
mask_filter_value_,
|
||||||
scale_,
|
scale_,
|
||||||
|
is_unidirectional,
|
||||||
past_present_share_buffer_,
|
past_present_share_buffer_,
|
||||||
is_dmmha_packing, // dmmha_packing
|
is_dmmha_packing, // dmmha_packing
|
||||||
device_prop.maxThreadsPerBlock));
|
device_prop.maxThreadsPerBlock));
|
||||||
|
|
|
||||||
|
|
@ -60,10 +60,10 @@ $ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/w
|
||||||
Export + Optimize for FP16 and GPU
|
Export + Optimize for FP16 and GPU
|
||||||
```
|
```
|
||||||
# From source:
|
# From source:
|
||||||
$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --optimize_onnx --precision fp16 --use_gpu --provider cuda
|
$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --optimize_onnx --precision fp16 --use_gpu --provider cuda --disable_auto_mixed_precision
|
||||||
|
|
||||||
# From wheel:
|
# From wheel:
|
||||||
$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --optimize_onnx --precision fp16 --use_gpu --provider cuda
|
$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --optimize_onnx --precision fp16 --use_gpu --provider cuda --disable_auto_mixed_precision
|
||||||
```
|
```
|
||||||
|
|
||||||
Export + Quantize for INT8
|
Export + Quantize for INT8
|
||||||
|
|
|
||||||
|
|
@ -478,7 +478,7 @@ def main(argv=None):
|
||||||
# Wrap parity check in try-except to allow export to continue in case this produces an error
|
# Wrap parity check in try-except to allow export to continue in case this produces an error
|
||||||
try:
|
try:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
max_diff = WhisperHelper.verify_onnx(args.model_name_or_path, ort_session, device)
|
max_diff = WhisperHelper.verify_onnx(args.model_name_or_path, cache_dir, ort_session, device)
|
||||||
if max_diff > 1e-4:
|
if max_diff > 1e-4:
|
||||||
logger.warning("PyTorch and ONNX Runtime results are NOT close")
|
logger.warning("PyTorch and ONNX Runtime results are NOT close")
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,9 @@ from typing import Dict, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
from packaging import version
|
||||||
from transformers import WhisperConfig, WhisperForConditionalGeneration, WhisperProcessor
|
from transformers import WhisperConfig, WhisperForConditionalGeneration, WhisperProcessor
|
||||||
|
from transformers import __version__ as transformers_version
|
||||||
from whisper_decoder import WhisperDecoder, WhisperDecoderHelper, WhisperDecoderInit
|
from whisper_decoder import WhisperDecoder, WhisperDecoderHelper, WhisperDecoderInit
|
||||||
from whisper_encoder import WhisperEncoder, WhisperEncoderHelper
|
from whisper_encoder import WhisperEncoder, WhisperEncoderHelper
|
||||||
from whisper_encoder_decoder_init import WhisperEncoderDecoderInit, WhisperEncoderDecoderInitHelper
|
from whisper_encoder_decoder_init import WhisperEncoderDecoderInit, WhisperEncoderDecoderInitHelper
|
||||||
|
|
@ -88,7 +90,10 @@ class WhisperHelper:
|
||||||
Returns:
|
Returns:
|
||||||
Dict[str, torch.nn.Module]: mapping from name to modules for ONNX conversion.
|
Dict[str, torch.nn.Module]: mapping from name to modules for ONNX conversion.
|
||||||
"""
|
"""
|
||||||
model = WhisperForConditionalGeneration.from_pretrained(model_name_or_path, cache_dir=cache_dir)
|
extra_kwargs = {}
|
||||||
|
if version.parse(transformers_version) >= version.parse("4.36.0"):
|
||||||
|
extra_kwargs["attn_implementation"] = "eager"
|
||||||
|
model = WhisperForConditionalGeneration.from_pretrained(model_name_or_path, cache_dir=cache_dir, **extra_kwargs)
|
||||||
if state_dict_path:
|
if state_dict_path:
|
||||||
model.load_state_dict(torch.load(state_dict_path), strict=False)
|
model.load_state_dict(torch.load(state_dict_path), strict=False)
|
||||||
|
|
||||||
|
|
@ -262,11 +267,17 @@ class WhisperHelper:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def verify_onnx(
|
def verify_onnx(
|
||||||
model_name_or_path: str,
|
model_name_or_path: str,
|
||||||
|
cache_dir: str,
|
||||||
ort_session: InferenceSession,
|
ort_session: InferenceSession,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
):
|
):
|
||||||
"""Compare the result from PyTorch and ONNX Runtime to verify the ONNX model is good."""
|
"""Compare the result from PyTorch and ONNX Runtime to verify the ONNX model is good."""
|
||||||
pt_model = WhisperForConditionalGeneration.from_pretrained(model_name_or_path).to(device)
|
extra_kwargs = {}
|
||||||
|
if version.parse(transformers_version) >= version.parse("4.36.0"):
|
||||||
|
extra_kwargs["attn_implementation"] = "eager"
|
||||||
|
pt_model = WhisperForConditionalGeneration.from_pretrained(
|
||||||
|
model_name_or_path, cache_dir=cache_dir, **extra_kwargs
|
||||||
|
).to(device)
|
||||||
processor = WhisperProcessor.from_pretrained(model_name_or_path)
|
processor = WhisperProcessor.from_pretrained(model_name_or_path)
|
||||||
config = WhisperConfig.from_pretrained(model_name_or_path)
|
config = WhisperConfig.from_pretrained(model_name_or_path)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue