mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-22 22:01:08 +00:00
add option to load pretrained weights for T5 model (#15951)
### Description <!-- Describe your changes. --> Adds an option to pass in a pretrained weights file during T5 inference ONNX export. Mimics the changes made to Whisper: https://github.com/microsoft/onnxruntime/pull/15759 ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> Required for the ONNX Runtime demo being presented at BUILD.
This commit is contained in:
parent
e96f10d27b
commit
a0ccb95f3c
2 changed files with 15 additions and 1 deletions
12
onnxruntime/python/tools/transformers/models/t5/convert_to_onnx.py
Normal file → Executable file
12
onnxruntime/python/tools/transformers/models/t5/convert_to_onnx.py
Normal file → Executable file
|
|
@ -127,6 +127,13 @@ def parse_arguments():
|
|||
)
|
||||
parser.set_defaults(use_int64_inputs=False)
|
||||
|
||||
parser.add_argument(
|
||||
"--state_dict_path",
|
||||
type=str,
|
||||
default="",
|
||||
help="filepath to load pre-trained model with custom state dictionary (e.g. pytorch_model.bin)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
return args
|
||||
|
|
@ -147,10 +154,13 @@ def export_onnx_models(
|
|||
disable_auto_mixed_precision: bool = False,
|
||||
use_int32_inputs: bool = True,
|
||||
model_type: str = "t5",
|
||||
state_dict_path: str = "",
|
||||
):
|
||||
device = torch.device("cuda:0" if use_gpu else "cpu")
|
||||
|
||||
models = T5Helper.load_model(model_name_or_path, cache_dir, device, merge_encoder_and_decoder_init, model_type)
|
||||
models = T5Helper.load_model(
|
||||
model_name_or_path, cache_dir, device, merge_encoder_and_decoder_init, model_type, state_dict_path
|
||||
)
|
||||
config = models["decoder"].config
|
||||
|
||||
if (not use_external_data_format) and (config.num_layers > 24):
|
||||
|
|
|
|||
4
onnxruntime/python/tools/transformers/models/t5/t5_helper.py
Normal file → Executable file
4
onnxruntime/python/tools/transformers/models/t5/t5_helper.py
Normal file → Executable file
|
|
@ -66,6 +66,7 @@ class T5Helper:
|
|||
device: torch.device,
|
||||
merge_encoder_and_decoder_init: bool = True,
|
||||
model_type: str = "t5",
|
||||
state_dict_path: str = "",
|
||||
) -> Dict[str, torch.nn.Module]:
|
||||
"""Load model given a pretrained name or path, then build models for ONNX conversion.
|
||||
|
||||
|
|
@ -85,6 +86,9 @@ class T5Helper:
|
|||
else:
|
||||
raise ValueError("only support mode_type=t5 or mt5")
|
||||
|
||||
if state_dict_path:
|
||||
model.load_state_dict(torch.load(state_dict_path))
|
||||
|
||||
decoder = T5Decoder(model.decoder, model.lm_head, model.config)
|
||||
decoder.eval().to(device)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue