From 090312af711ddae1a3936c8dde180551b3ccefbb Mon Sep 17 00:00:00 2001 From: Prathik Rao Date: Mon, 1 May 2023 22:08:11 -0700 Subject: [PATCH] add local state dict option (#15759) ### Description Adds an option to load a local state dictionary for Whisper model export. ### Motivation and Context This is useful for demonstrating the workflow of using ORT Training to obtain model weights, downloading those weights onto a local GPU-enabled device, exporting the custom model with `convert_to_onnx.py`, and then feeding the resulting .onnx file into an ORT InferenceSession. --- .../transformers/models/whisper/convert_to_onnx.py | 12 +++++++++++- .../transformers/models/whisper/whisper_helper.py | 8 +++++++- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py index e8df6bc785..5dd848cec6 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py @@ -159,6 +159,13 @@ def parse_arguments(): parser.add_argument("--no_repeat_ngram_size", type=int, default=3, help="default to 3") + parser.add_argument( + "--state_dict_path", + type=str, + default="", + help="filepath to load pre-trained model with custom state dictionary (e.g. 
pytorch_model.bin)", + ) + args = parser.parse_args() return args @@ -181,10 +188,13 @@ def export_onnx_models( quantize_embedding_layer: bool = False, quantize_per_channel: bool = False, quantize_reduce_range: bool = False, + state_dict_path: str = "", ): device = torch.device("cuda:0" if use_gpu else "cpu") - models = WhisperHelper.load_model(model_name_or_path, cache_dir, device, merge_encoder_and_decoder_init) + models = WhisperHelper.load_model( + model_name_or_path, cache_dir, device, merge_encoder_and_decoder_init, state_dict_path + ) config = models["decoder"].config if (not use_external_data_format) and (config.num_layers > 24): diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py index c795e36498..9f18984d61 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py @@ -71,7 +71,11 @@ class WhisperHelper: @staticmethod def load_model( - model_name_or_path: str, cache_dir: str, device: torch.device, merge_encoder_and_decoder_init: bool = True + model_name_or_path: str, + cache_dir: str, + device: torch.device, + merge_encoder_and_decoder_init: bool = True, + state_dict_path: str = "", ) -> Dict[str, torch.nn.Module]: """Load model given a pretrained name or path, then build models for ONNX conversion. @@ -84,6 +88,8 @@ class WhisperHelper: Dict[str, torch.nn.Module]: mapping from name to modules for ONNX conversion. """ model = WhisperForConditionalGeneration.from_pretrained(model_name_or_path, cache_dir=cache_dir) + if state_dict_path: + model.load_state_dict(torch.load(state_dict_path), strict=False) decoder = WhisperDecoder(model, None, model.config) decoder.eval().to(device)