mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-03 03:58:54 +00:00
Support generation script with custom eos/pad token id (#14113)
### Description <!-- Describe your changes. --> When a custom decoder ONNX model is passed in, the user can specify the EOS/pad token IDs explicitly instead of having them populated from the torch config. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
parent
e5e3570ac5
commit
821baa5b83
1 changed file with 21 additions and 0 deletions
|
|
@@ -363,6 +363,22 @@ def parse_arguments(argv: Optional[List[str]] = None) -> argparse.Namespace:
|
|||
help="Vocab_size of the underlying model used to decide the shape of vocab mask",
|
||||
)
|
||||
|
||||
beam_parameters_group.add_argument(
|
||||
"--eos_token_id",
|
||||
type=int,
|
||||
required=False,
|
||||
default=-1,
|
||||
help="custom eos_token_id for generating model with existing onnx encoder/decoder",
|
||||
)
|
||||
|
||||
beam_parameters_group.add_argument(
|
||||
"--pad_token_id",
|
||||
type=int,
|
||||
required=False,
|
||||
default=-1,
|
||||
help="custom pad_token_id for generating model with existing onnx encoder/decoder",
|
||||
)
|
||||
|
||||
test_group = parser.add_argument_group("Other options for testing parity and performance")
|
||||
|
||||
test_group.add_argument(
|
||||
|
|
@@ -1374,6 +1390,11 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati
|
|||
if args.vocab_size != -1:
|
||||
vocab_size = args.vocab_size
|
||||
|
||||
if args.eos_token_id != -1:
|
||||
eos_token_id = args.eos_token_id
|
||||
if args.pad_token_id != -1:
|
||||
pad_token_id = args.pad_token_id
|
||||
|
||||
decoder_model = onnx.load_model(args.decoder_onnx, load_external_data=True)
|
||||
decoder_model.graph.name = f"{args.model_type} decoder"
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue