add local state dict option (#15759)

### Description

Adds an option to load a local state dictionary when exporting the Whisper model.

### Motivation and Context

This is useful for demonstrating the workflow of using ORT Training to
obtain model weights, downloading those weights onto a local GPU-enabled
device, exporting the custom model with `convert_to_onnx.py`, and then
feeding the resulting .onnx file into an ORT InferenceSession.
This commit is contained in:
Prathik Rao 2023-05-01 22:08:11 -07:00 committed by GitHub
parent 391f897983
commit 090312af71
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 18 additions and 2 deletions

View file

@ -159,6 +159,13 @@ def parse_arguments():
parser.add_argument("--no_repeat_ngram_size", type=int, default=3, help="default to 3")
parser.add_argument(
"--state_dict_path",
type=str,
default="",
help="filepath to load pre-trained model with custom state dictionary (e.g. pytorch_model.bin)",
)
args = parser.parse_args()
return args
@ -181,10 +188,13 @@ def export_onnx_models(
quantize_embedding_layer: bool = False,
quantize_per_channel: bool = False,
quantize_reduce_range: bool = False,
state_dict_path: str = "",
):
device = torch.device("cuda:0" if use_gpu else "cpu")
models = WhisperHelper.load_model(model_name_or_path, cache_dir, device, merge_encoder_and_decoder_init)
models = WhisperHelper.load_model(
model_name_or_path, cache_dir, device, merge_encoder_and_decoder_init, state_dict_path
)
config = models["decoder"].config
if (not use_external_data_format) and (config.num_layers > 24):

View file

@ -71,7 +71,11 @@ class WhisperHelper:
@staticmethod
def load_model(
model_name_or_path: str, cache_dir: str, device: torch.device, merge_encoder_and_decoder_init: bool = True
model_name_or_path: str,
cache_dir: str,
device: torch.device,
merge_encoder_and_decoder_init: bool = True,
state_dict_path: str = "",
) -> Dict[str, torch.nn.Module]:
"""Load model given a pretrained name or path, then build models for ONNX conversion.
@ -84,6 +88,8 @@ class WhisperHelper:
Dict[str, torch.nn.Module]: mapping from name to modules for ONNX conversion.
"""
model = WhisperForConditionalGeneration.from_pretrained(model_name_or_path, cache_dir=cache_dir)
if state_dict_path:
model.load_state_dict(torch.load(state_dict_path), strict=False)
decoder = WhisperDecoder(model, None, model.config)
decoder.eval().to(device)