mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
Fix token_id values for whisper export (#15362)
### Description

The current ONNX export of Whisper uses hard-coded values for the token IDs when configuring the BeamSearch node. This PR removes these literals and instead takes the values directly from the WhisperConfig.

### Motivation and Context

Hard-coding these values can cause parity issues when comparing against default PyTorch behavior; taking them from the WhisperConfig instead resolves these issues.

Co-authored-by: Peter McAughan <petermca@microsoft.com>
This commit is contained in:
parent
55495cc809
commit
d0cca91cfb
1 changed file with 3 additions and 4 deletions
|
|
@@ -22,7 +22,6 @@ def chain_model(args):
|
|||
add_attention_mask(decoder_model)
|
||||
|
||||
config = WhisperConfig.from_pretrained(args.model_name_or_path)
|
||||
pad_token_id = config.pad_token_id
|
||||
|
||||
beam_inputs = [
|
||||
"input_features",
|
||||
|
|
@@ -42,9 +41,9 @@ def chain_model(args):
|
|||
node.domain = "com.microsoft"
|
||||
node.attribute.extend(
|
||||
[
|
||||
helper.make_attribute("eos_token_id", 50256),
|
||||
helper.make_attribute("pad_token_id", pad_token_id),
|
||||
helper.make_attribute("decoder_start_token_id", 50257),
|
||||
helper.make_attribute("eos_token_id", config.eos_token_id),
|
||||
helper.make_attribute("pad_token_id", config.pad_token_id),
|
||||
helper.make_attribute("decoder_start_token_id", config.decoder_start_token_id),
|
||||
helper.make_attribute("no_repeat_ngram_size", args.no_repeat_ngram_size),
|
||||
helper.make_attribute("early_stopping", True),
|
||||
helper.make_attribute("model_type", 2),
|
||||
|
|
|
|||
Loading…
Reference in a new issue