Fix token_id values for whisper export (#15362)

### Description
The current ONNX export of Whisper utilizes hard-coded values for
token_ids when configuring the BeamSearch node. This PR removes these
literals and instead takes these values straight from the WhisperConfig.



### Motivation and Context
Hard-coding these values can cause some parity issues when comparing to
default PyTorch behavior - this change to take from WhisperConfig
resolves these.

Co-authored-by: Peter McAughan <petermca@microsoft.com>
This commit is contained in:
petermcaughan 2023-04-06 11:01:21 -07:00 committed by GitHub
parent 55495cc809
commit d0cca91cfb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -22,7 +22,6 @@ def chain_model(args):
add_attention_mask(decoder_model)
config = WhisperConfig.from_pretrained(args.model_name_or_path)
pad_token_id = config.pad_token_id
beam_inputs = [
"input_features",
@ -42,9 +41,9 @@ def chain_model(args):
node.domain = "com.microsoft"
node.attribute.extend(
[
helper.make_attribute("eos_token_id", 50256),
helper.make_attribute("pad_token_id", pad_token_id),
helper.make_attribute("decoder_start_token_id", 50257),
helper.make_attribute("eos_token_id", config.eos_token_id),
helper.make_attribute("pad_token_id", config.pad_token_id),
helper.make_attribute("decoder_start_token_id", config.decoder_start_token_id),
helper.make_attribute("no_repeat_ngram_size", args.no_repeat_ngram_size),
helper.make_attribute("early_stopping", True),
helper.make_attribute("model_type", 2),