add option to load pretrained weights for T5 model (#15951)

### Description
<!-- Describe your changes. -->

Adds an option to pass in a pretrained weights file during T5 inference ONNX
export. Mimics the changes made to Whisper:
https://github.com/microsoft/onnxruntime/pull/15759

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->

Required for the ONNX Runtime demo being presented at BUILD.
This commit is contained in:
Prathik Rao 2023-05-15 22:52:35 -07:00 committed by GitHub
parent e96f10d27b
commit a0ccb95f3c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 15 additions and 1 deletions

View file

@ -127,6 +127,13 @@ def parse_arguments():
)
parser.set_defaults(use_int64_inputs=False)
parser.add_argument(
"--state_dict_path",
type=str,
default="",
help="filepath to load pre-trained model with custom state dictionary (e.g. pytorch_model.bin)",
)
args = parser.parse_args()
return args
@ -147,10 +154,13 @@ def export_onnx_models(
disable_auto_mixed_precision: bool = False,
use_int32_inputs: bool = True,
model_type: str = "t5",
state_dict_path: str = "",
):
device = torch.device("cuda:0" if use_gpu else "cpu")
models = T5Helper.load_model(model_name_or_path, cache_dir, device, merge_encoder_and_decoder_init, model_type)
models = T5Helper.load_model(
model_name_or_path, cache_dir, device, merge_encoder_and_decoder_init, model_type, state_dict_path
)
config = models["decoder"].config
if (not use_external_data_format) and (config.num_layers > 24):

View file

@ -66,6 +66,7 @@ class T5Helper:
device: torch.device,
merge_encoder_and_decoder_init: bool = True,
model_type: str = "t5",
state_dict_path: str = "",
) -> Dict[str, torch.nn.Module]:
"""Load model given a pretrained name or path, then build models for ONNX conversion.
@ -85,6 +86,9 @@ class T5Helper:
else:
raise ValueError("only support mode_type=t5 or mt5")
if state_dict_path:
model.load_state_dict(torch.load(state_dict_path))
decoder = T5Decoder(model.decoder, model.lm_head, model.config)
decoder.eval().to(device)