mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-24 22:17:32 +00:00
Add ability to create ort format models from training offline utility (#16360)
This commit is contained in:
parent
0ad0d6ebbf
commit
42489a8a24
4 changed files with 92 additions and 30 deletions
|
|
@ -9,7 +9,8 @@ from typing import List, Optional, Union
|
|||
|
||||
import onnx
|
||||
|
||||
import onnxruntime.training.onnxblock as onnxblock
|
||||
from onnxruntime.tools.convert_onnx_models_to_ort import OptimizationStyle, convert_onnx_models_to_ort
|
||||
from onnxruntime.training import onnxblock
|
||||
|
||||
|
||||
class LossType(Enum):
|
||||
|
|
@ -58,7 +59,8 @@ def generate_artifacts(
|
|||
optimizer: The optimizer enum to be used for training. If None, no optimizer model is generated.
|
||||
artifact_directory: The directory to save the generated artifacts.
|
||||
If None, the current working directory is used.
|
||||
prefix: The prefix to be used for the generated artifacts. If not specified, no prefix is used.
|
||||
prefix (str): The prefix to be used for the generated artifacts. If not specified, no prefix is used.
|
||||
ort_format (bool): Whether to save the generated artifacts in ORT format or not. Default is False.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the loss provided is neither one of the supported losses nor an instance of `onnxblock.Block`
|
||||
|
|
@ -124,6 +126,10 @@ def generate_artifacts(
|
|||
training_model, eval_model = training_block.to_model_proto()
|
||||
model_params = training_block.parameters()
|
||||
|
||||
def _export_to_ort_format(model_path, output_dir, extra_options):
|
||||
if extra_options.get("ort_format", False):
|
||||
convert_onnx_models_to_ort(model_path, output_dir=output_dir, optimization_styles=[OptimizationStyle.Fixed])
|
||||
|
||||
if artifact_directory is None:
|
||||
artifact_directory = pathlib.Path.cwd()
|
||||
prefix = ""
|
||||
|
|
@ -137,12 +143,14 @@ def generate_artifacts(
|
|||
if os.path.exists(training_model_path):
|
||||
logging.info("Training model path %s already exists. Overwriting.", training_model_path)
|
||||
onnx.save(training_model, training_model_path)
|
||||
_export_to_ort_format(training_model_path, artifact_directory, extra_options)
|
||||
logging.info("Saved training model to %s", training_model_path)
|
||||
|
||||
eval_model_path = artifact_directory / f"{prefix}eval_model.onnx"
|
||||
if os.path.exists(eval_model_path):
|
||||
logging.info("Eval model path %s already exists. Overwriting.", eval_model_path)
|
||||
onnx.save(eval_model, eval_model_path)
|
||||
_export_to_ort_format(eval_model_path, artifact_directory, extra_options)
|
||||
logging.info("Saved eval model to %s", eval_model_path)
|
||||
|
||||
checkpoint_path = artifact_directory / f"{prefix}checkpoint"
|
||||
|
|
@ -173,4 +181,5 @@ def generate_artifacts(
|
|||
|
||||
optimizer_model_path = artifact_directory / f"{prefix}optimizer_model.onnx"
|
||||
onnx.save(optim_model, optimizer_model_path)
|
||||
_export_to_ort_format(optimizer_model_path, artifact_directory, extra_options)
|
||||
logging.info("Saved optimizer model to %s", optimizer_model_path)
|
||||
|
|
|
|||
|
|
@ -915,3 +915,26 @@ def test_label_encoder_composition():
|
|||
|
||||
all_nodes = [node.op_type for node in model.graph.node]
|
||||
assert "LabelEncoder" in all_nodes
|
||||
|
||||
|
||||
def test_save_ort_format():
|
||||
device = "cpu"
|
||||
batch_size, input_size, hidden_size, output_size = 64, 784, 500, 10
|
||||
_, base_model = _get_models(device, batch_size, input_size, hidden_size, output_size)
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
artifacts.generate_artifacts(
|
||||
base_model,
|
||||
requires_grad=["fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias"],
|
||||
loss=artifacts.LossType.CrossEntropyLoss,
|
||||
optimizer=artifacts.OptimType.AdamW,
|
||||
artifact_directory=temp_dir,
|
||||
ort_format=True,
|
||||
)
|
||||
|
||||
assert os.path.exists(os.path.join(temp_dir, "training_model.onnx"))
|
||||
assert os.path.exists(os.path.join(temp_dir, "training_model.ort"))
|
||||
assert os.path.exists(os.path.join(temp_dir, "eval_model.onnx"))
|
||||
assert os.path.exists(os.path.join(temp_dir, "eval_model.ort"))
|
||||
assert os.path.exists(os.path.join(temp_dir, "optimizer_model.onnx"))
|
||||
assert os.path.exists(os.path.join(temp_dir, "optimizer_model.ort"))
|
||||
|
|
|
|||
|
|
@ -5,7 +5,17 @@
|
|||
# This script is a stub that uses the model conversion script from the util subdirectory.
|
||||
# We do it this way so we can use relative imports in that script, which makes it easy to include
|
||||
# in the ORT python package (where it must use relative imports)
|
||||
from util.convert_onnx_models_to_ort import convert_onnx_models_to_ort
|
||||
from util.convert_onnx_models_to_ort import convert_onnx_models_to_ort, parse_args
|
||||
|
||||
if __name__ == "__main__":
|
||||
convert_onnx_models_to_ort()
|
||||
args = parse_args()
|
||||
convert_onnx_models_to_ort(
|
||||
args.model_path_or_dir,
|
||||
output_dir=args.output_dir,
|
||||
optimization_styles=args.optimization_style,
|
||||
custom_op_library_path=args.custom_op_library,
|
||||
target_platform=args.target_platform,
|
||||
save_optimized_onnx_model=args.save_optimized_onnx_model,
|
||||
allow_conversion_failures=args.allow_conversion_failures,
|
||||
enable_type_reduction=args.enable_type_reduction,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -2,13 +2,14 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import contextlib
|
||||
import enum
|
||||
import os
|
||||
import pathlib
|
||||
import tempfile
|
||||
import typing
|
||||
|
||||
import onnxruntime as ort
|
||||
|
||||
|
|
@ -32,7 +33,7 @@ def _optimization_suffix(optimization_level_str: str, optimization_style: Optimi
|
|||
|
||||
def _create_config_file_path(
|
||||
model_path_or_dir: pathlib.Path,
|
||||
output_dir: typing.Optional[pathlib.Path],
|
||||
output_dir: pathlib.Path | None,
|
||||
optimization_level_str: str,
|
||||
optimization_style: OptimizationStyle,
|
||||
enable_type_reduction: bool,
|
||||
|
|
@ -57,7 +58,7 @@ def _create_session_options(
|
|||
optimization_level: ort.GraphOptimizationLevel,
|
||||
output_model_path: pathlib.Path,
|
||||
custom_op_library: pathlib.Path,
|
||||
session_options_config_entries: typing.Dict[str, str],
|
||||
session_options_config_entries: dict[str, str],
|
||||
):
|
||||
so = ort.SessionOptions()
|
||||
so.optimized_model_filepath = str(output_model_path)
|
||||
|
|
@ -74,15 +75,15 @@ def _create_session_options(
|
|||
|
||||
def _convert(
|
||||
model_path_or_dir: pathlib.Path,
|
||||
output_dir: typing.Optional[pathlib.Path],
|
||||
output_dir: pathlib.Path | None,
|
||||
optimization_level_str: str,
|
||||
optimization_style: OptimizationStyle,
|
||||
custom_op_library: pathlib.Path,
|
||||
create_optimized_onnx_model: bool,
|
||||
allow_conversion_failures: bool,
|
||||
target_platform: str,
|
||||
session_options_config_entries: typing.Dict[str, str],
|
||||
) -> typing.List[pathlib.Path]:
|
||||
session_options_config_entries: dict[str, str],
|
||||
) -> list[pathlib.Path]:
|
||||
model_dir = model_path_or_dir if model_path_or_dir.is_dir() else model_path_or_dir.parent
|
||||
output_dir = output_dir or model_dir
|
||||
|
||||
|
|
@ -258,24 +259,33 @@ def parse_args():
|
|||
"processed.",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
parsed_args = parser.parse_args()
|
||||
parsed_args.optimization_style = [OptimizationStyle[style_str] for style_str in parsed_args.optimization_style]
|
||||
return parsed_args
|
||||
|
||||
|
||||
def convert_onnx_models_to_ort():
|
||||
args = parse_args()
|
||||
def convert_onnx_models_to_ort(
|
||||
model_path_or_dir: pathlib.Path,
|
||||
output_dir: pathlib.Path | None = None,
|
||||
optimization_styles: list[OptimizationStyle] | None = None,
|
||||
custom_op_library_path: pathlib.Path | None = None,
|
||||
target_platform: str | None = None,
|
||||
save_optimized_onnx_model: bool = False,
|
||||
allow_conversion_failures: bool = False,
|
||||
enable_type_reduction: bool = False,
|
||||
):
|
||||
if output_dir is not None:
|
||||
if not output_dir.is_dir():
|
||||
output_dir.mkdir(parents=True)
|
||||
output_dir = output_dir.resolve(strict=True)
|
||||
|
||||
output_dir = None
|
||||
if args.output_dir is not None:
|
||||
if not args.output_dir.is_dir():
|
||||
args.output_dir.mkdir(parents=True)
|
||||
output_dir = args.output_dir.resolve(strict=True)
|
||||
optimization_styles = optimization_styles or []
|
||||
|
||||
optimization_styles = [OptimizationStyle[style_str] for style_str in args.optimization_style]
|
||||
# setting optimization level is not expected to be needed by typical users, but it can be set with this
|
||||
# environment variable
|
||||
optimization_level_str = os.getenv("ORT_CONVERT_ONNX_MODELS_TO_ORT_OPTIMIZATION_LEVEL", "all")
|
||||
model_path_or_dir = args.model_path_or_dir.resolve()
|
||||
custom_op_library = args.custom_op_library.resolve() if args.custom_op_library else None
|
||||
model_path_or_dir = model_path_or_dir.resolve()
|
||||
custom_op_library = custom_op_library_path.resolve() if custom_op_library_path else None
|
||||
|
||||
if not model_path_or_dir.is_dir() and not model_path_or_dir.is_file():
|
||||
raise FileNotFoundError(f"Model path '{model_path_or_dir}' is not a file or directory.")
|
||||
|
|
@ -285,7 +295,7 @@ def convert_onnx_models_to_ort():
|
|||
|
||||
session_options_config_entries = {}
|
||||
|
||||
if args.target_platform == "arm":
|
||||
if target_platform is not None and target_platform == "arm":
|
||||
session_options_config_entries["session.qdqisint8allowed"] = "1"
|
||||
else:
|
||||
session_options_config_entries["session.qdqisint8allowed"] = "0"
|
||||
|
|
@ -303,9 +313,9 @@ def convert_onnx_models_to_ort():
|
|||
optimization_level_str=optimization_level_str,
|
||||
optimization_style=optimization_style,
|
||||
custom_op_library=custom_op_library,
|
||||
create_optimized_onnx_model=args.save_optimized_onnx_model,
|
||||
allow_conversion_failures=args.allow_conversion_failures,
|
||||
target_platform=args.target_platform,
|
||||
create_optimized_onnx_model=save_optimized_onnx_model,
|
||||
allow_conversion_failures=allow_conversion_failures,
|
||||
target_platform=target_platform,
|
||||
session_options_config_entries=session_options_config_entries,
|
||||
)
|
||||
|
||||
|
|
@ -335,8 +345,8 @@ def convert_onnx_models_to_ort():
|
|||
optimization_style=OptimizationStyle.Fixed,
|
||||
custom_op_library=custom_op_library,
|
||||
create_optimized_onnx_model=False, # not useful as they would be created in a temp directory
|
||||
allow_conversion_failures=args.allow_conversion_failures,
|
||||
target_platform=args.target_platform,
|
||||
allow_conversion_failures=allow_conversion_failures,
|
||||
target_platform=target_platform,
|
||||
session_options_config_entries=session_options_config_entries_for_second_conversion,
|
||||
)
|
||||
|
||||
|
|
@ -351,11 +361,21 @@ def convert_onnx_models_to_ort():
|
|||
output_dir,
|
||||
optimization_level_str,
|
||||
optimization_style,
|
||||
args.enable_type_reduction,
|
||||
enable_type_reduction,
|
||||
)
|
||||
|
||||
create_config_from_models(converted_models, config_file, args.enable_type_reduction)
|
||||
create_config_from_models(converted_models, config_file, enable_type_reduction)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
convert_onnx_models_to_ort()
|
||||
args = parse_args()
|
||||
convert_onnx_models_to_ort(
|
||||
args.model_path_or_dir,
|
||||
output_dir=args.output_dir,
|
||||
optimization_styles=args.optimization_style,
|
||||
custom_op_library_path=args.custom_op_library,
|
||||
target_platform=args.target_platform,
|
||||
save_optimized_onnx_model=args.save_optimized_onnx_model,
|
||||
allow_conversion_failures=args.allow_conversion_failures,
|
||||
enable_type_reduction=args.enable_type_reduction,
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in a new issue