mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-01 23:30:35 +00:00
add local state dict option (#15759)
### Description Adds an option to load local state dictionary for whisper model export. ### Motivation and Context This is useful to demonstrate workflow of using ORT Training to get model weights, downloading said weights onto a local gpu-enabled device, exporting the custom model using `convert_to_onnx.py`, and then nicely feeding the .onnx file into ORT InferenceSession.
This commit is contained in:
parent
391f897983
commit
090312af71
2 changed files with 18 additions and 2 deletions
|
|
@ -159,6 +159,13 @@ def parse_arguments():
|
|||
|
||||
parser.add_argument("--no_repeat_ngram_size", type=int, default=3, help="default to 3")
|
||||
|
||||
parser.add_argument(
|
||||
"--state_dict_path",
|
||||
type=str,
|
||||
default="",
|
||||
help="filepath to load pre-trained model with custom state dictionary (e.g. pytorch_model.bin)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
return args
|
||||
|
|
@ -181,10 +188,13 @@ def export_onnx_models(
|
|||
quantize_embedding_layer: bool = False,
|
||||
quantize_per_channel: bool = False,
|
||||
quantize_reduce_range: bool = False,
|
||||
state_dict_path: str = "",
|
||||
):
|
||||
device = torch.device("cuda:0" if use_gpu else "cpu")
|
||||
|
||||
models = WhisperHelper.load_model(model_name_or_path, cache_dir, device, merge_encoder_and_decoder_init)
|
||||
models = WhisperHelper.load_model(
|
||||
model_name_or_path, cache_dir, device, merge_encoder_and_decoder_init, state_dict_path
|
||||
)
|
||||
config = models["decoder"].config
|
||||
|
||||
if (not use_external_data_format) and (config.num_layers > 24):
|
||||
|
|
|
|||
|
|
@ -71,7 +71,11 @@ class WhisperHelper:
|
|||
|
||||
@staticmethod
|
||||
def load_model(
|
||||
model_name_or_path: str, cache_dir: str, device: torch.device, merge_encoder_and_decoder_init: bool = True
|
||||
model_name_or_path: str,
|
||||
cache_dir: str,
|
||||
device: torch.device,
|
||||
merge_encoder_and_decoder_init: bool = True,
|
||||
state_dict_path: str = "",
|
||||
) -> Dict[str, torch.nn.Module]:
|
||||
"""Load model given a pretrained name or path, then build models for ONNX conversion.
|
||||
|
||||
|
|
@ -84,6 +88,8 @@ class WhisperHelper:
|
|||
Dict[str, torch.nn.Module]: mapping from name to modules for ONNX conversion.
|
||||
"""
|
||||
model = WhisperForConditionalGeneration.from_pretrained(model_name_or_path, cache_dir=cache_dir)
|
||||
if state_dict_path:
|
||||
model.load_state_dict(torch.load(state_dict_path), strict=False)
|
||||
|
||||
decoder = WhisperDecoder(model, None, model.config)
|
||||
decoder.eval().to(device)
|
||||
|
|
|
|||
Loading…
Reference in a new issue