From a0ccb95f3cabcd50ff1c85948819be083e25eb25 Mon Sep 17 00:00:00 2001 From: Prathik Rao Date: Mon, 15 May 2023 22:52:35 -0700 Subject: [PATCH] add option to load pretrained weights for T5 model (#15951) ### Description Adds option to pass in pretrained weights file during T5 inference onnx export. Mimics the changes made to whisper: https://github.com/microsoft/onnxruntime/pull/15759 ### Motivation and Context Required for ONNX Runtime demo being presented at BUILD. --- .../tools/transformers/models/t5/convert_to_onnx.py | 12 +++++++++++- .../python/tools/transformers/models/t5/t5_helper.py | 4 ++++ 2 files changed, 15 insertions(+), 1 deletion(-) mode change 100644 => 100755 onnxruntime/python/tools/transformers/models/t5/convert_to_onnx.py mode change 100644 => 100755 onnxruntime/python/tools/transformers/models/t5/t5_helper.py diff --git a/onnxruntime/python/tools/transformers/models/t5/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/t5/convert_to_onnx.py old mode 100644 new mode 100755 index eff24f58a0..230885ab6c --- a/onnxruntime/python/tools/transformers/models/t5/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/t5/convert_to_onnx.py @@ -127,6 +127,13 @@ def parse_arguments(): ) parser.set_defaults(use_int64_inputs=False) + parser.add_argument( + "--state_dict_path", + type=str, + default="", + help="filepath to load pre-trained model with custom state dictionary (e.g. pytorch_model.bin)", + ) + args = parser.parse_args() return args @@ -147,10 +154,13 @@ def export_onnx_models( disable_auto_mixed_precision: bool = False, use_int32_inputs: bool = True, model_type: str = "t5", + state_dict_path: str = "", ): device = torch.device("cuda:0" if use_gpu else "cpu") - models = T5Helper.load_model(model_name_or_path, cache_dir, device, merge_encoder_and_decoder_init, model_type) + models = T5Helper.load_model( + model_name_or_path, cache_dir, device, merge_encoder_and_decoder_init, model_type, state_dict_path + ) config = models["decoder"].config if (not use_external_data_format) and (config.num_layers > 24): diff --git a/onnxruntime/python/tools/transformers/models/t5/t5_helper.py b/onnxruntime/python/tools/transformers/models/t5/t5_helper.py old mode 100644 new mode 100755 index 4abf45ed1e..b1b494707a --- a/onnxruntime/python/tools/transformers/models/t5/t5_helper.py +++ b/onnxruntime/python/tools/transformers/models/t5/t5_helper.py @@ -66,6 +66,7 @@ class T5Helper: device: torch.device, merge_encoder_and_decoder_init: bool = True, model_type: str = "t5", + state_dict_path: str = "", ) -> Dict[str, torch.nn.Module]: """Load model given a pretrained name or path, then build models for ONNX conversion. @@ -85,6 +86,9 @@ class T5Helper: else: raise ValueError("only support mode_type=t5 or mt5") + if state_dict_path: + model.load_state_dict(torch.load(state_dict_path)) + decoder = T5Decoder(model.decoder, model.lm_head, model.config) decoder.eval().to(device)