diff --git a/onnxruntime/python/tools/transformers/convert_generation.py b/onnxruntime/python/tools/transformers/convert_generation.py index 596c556cfc..8c3ecceeff 100644 --- a/onnxruntime/python/tools/transformers/convert_generation.py +++ b/onnxruntime/python/tools/transformers/convert_generation.py @@ -363,6 +363,22 @@ def parse_arguments(argv: Optional[List[str]] = None) -> argparse.Namespace: help="Vocab_size of the underlying model used to decide the shape of vocab mask", ) + beam_parameters_group.add_argument( + "--eos_token_id", + type=int, + required=False, + default=-1, + help="custom eos_token_id for generating model with existing onnx encoder/decoder", + ) + + beam_parameters_group.add_argument( + "--pad_token_id", + type=int, + required=False, + default=-1, + help="custom pad_token_id for generating model with existing onnx encoder/decoder", + ) + test_group = parser.add_argument_group("Other options for testing parity and performance") test_group.add_argument( @@ -1374,6 +1390,11 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati if args.vocab_size != -1: vocab_size = args.vocab_size + if args.eos_token_id != -1: + eos_token_id = args.eos_token_id + if args.pad_token_id != -1: + pad_token_id = args.pad_token_id + decoder_model = onnx.load_model(args.decoder_onnx, load_external_data=True) decoder_model.graph.name = f"{args.model_type} decoder"