From d0cca91cfb1b4737efb2f198429439f91d2c2ee7 Mon Sep 17 00:00:00 2001 From: petermcaughan Date: Thu, 6 Apr 2023 11:01:21 -0700 Subject: [PATCH] Fix token_id values for whisper export (#15362) ### Description The current ONNX export of Whisper utilizes hard-coded values for token_ids when configuring the BeamSearch node. This PR removes these literals and instead takes these values straight from the WhisperConfig. ### Motivation and Context Hard-coding these values can cause some parity issues when comparing to default PyTorch behavior - this change to take from WhisperConfig resolves these. Co-authored-by: Peter McAughan --- .../tools/transformers/models/whisper/whisper_chain.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py index b9e17d6ef1..c60260e3ac 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py @@ -22,7 +22,6 @@ def chain_model(args): add_attention_mask(decoder_model) config = WhisperConfig.from_pretrained(args.model_name_or_path) - pad_token_id = config.pad_token_id beam_inputs = [ "input_features", @@ -42,9 +41,9 @@ def chain_model(args): node.domain = "com.microsoft" node.attribute.extend( [ - helper.make_attribute("eos_token_id", 50256), - helper.make_attribute("pad_token_id", pad_token_id), - helper.make_attribute("decoder_start_token_id", 50257), + helper.make_attribute("eos_token_id", config.eos_token_id), + helper.make_attribute("pad_token_id", config.pad_token_id), + helper.make_attribute("decoder_start_token_id", config.decoder_start_token_id), helper.make_attribute("no_repeat_ngram_size", args.no_repeat_ngram_size), helper.make_attribute("early_stopping", True), helper.make_attribute("model_type", 2),