diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py index b9e17d6ef1..c60260e3ac 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py @@ -22,7 +22,6 @@ def chain_model(args): add_attention_mask(decoder_model) config = WhisperConfig.from_pretrained(args.model_name_or_path) - pad_token_id = config.pad_token_id beam_inputs = [ "input_features", @@ -42,9 +41,9 @@ def chain_model(args): node.domain = "com.microsoft" node.attribute.extend( [ - helper.make_attribute("eos_token_id", 50256), - helper.make_attribute("pad_token_id", pad_token_id), - helper.make_attribute("decoder_start_token_id", 50257), + helper.make_attribute("eos_token_id", config.eos_token_id), + helper.make_attribute("pad_token_id", config.pad_token_id), + helper.make_attribute("decoder_start_token_id", config.decoder_start_token_id), helper.make_attribute("no_repeat_ngram_size", args.no_repeat_ngram_size), helper.make_attribute("early_stopping", True), helper.make_attribute("model_type", 2),