Support generation script with custom eos/pad token id (#14113)

### Description
<!-- Describe your changes. -->

When a custom decoder ONNX model is passed in, the user can specify the eos/pad
token id explicitly instead of having it populated from the torch config.


### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
Ye Wang 2023-01-04 10:51:53 -08:00 committed by GitHub
parent e5e3570ac5
commit 821baa5b83
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -363,6 +363,22 @@ def parse_arguments(argv: Optional[List[str]] = None) -> argparse.Namespace:
help="Vocab_size of the underlying model used to decide the shape of vocab mask",
)
beam_parameters_group.add_argument(
"--eos_token_id",
type=int,
required=False,
default=-1,
help="custom eos_token_id for generating model with existing onnx encoder/decoder",
)
beam_parameters_group.add_argument(
"--pad_token_id",
type=int,
required=False,
default=-1,
help="custom pad_token_id for generating model with existing onnx encoder/decoder",
)
test_group = parser.add_argument_group("Other options for testing parity and performance")
test_group.add_argument(
@ -1374,6 +1390,11 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati
if args.vocab_size != -1:
vocab_size = args.vocab_size
if args.eos_token_id != -1:
eos_token_id = args.eos_token_id
if args.pad_token_id != -1:
pad_token_id = args.pad_token_id
decoder_model = onnx.load_model(args.decoder_onnx, load_external_data=True)
decoder_model.graph.name = f"{args.model_type} decoder"