Add ability to create ort format models from training offline utility (#16360)

This commit is contained in:
Baiju Meswani 2023-06-21 18:51:43 -07:00 committed by GitHub
parent 0ad0d6ebbf
commit 42489a8a24
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 92 additions and 30 deletions

View file

@ -9,7 +9,8 @@ from typing import List, Optional, Union
import onnx
import onnxruntime.training.onnxblock as onnxblock
from onnxruntime.tools.convert_onnx_models_to_ort import OptimizationStyle, convert_onnx_models_to_ort
from onnxruntime.training import onnxblock
class LossType(Enum):
@ -58,7 +59,8 @@ def generate_artifacts(
optimizer: The optimizer enum to be used for training. If None, no optimizer model is generated.
artifact_directory: The directory to save the generated artifacts.
If None, the current working directory is used.
prefix: The prefix to be used for the generated artifacts. If not specified, no prefix is used.
prefix (str): The prefix to be used for the generated artifacts. If not specified, no prefix is used.
ort_format (bool): Whether to save the generated artifacts in ORT format or not. Default is False.
Raises:
RuntimeError: If the loss provided is neither one of the supported losses nor an instance of `onnxblock.Block`
@ -124,6 +126,10 @@ def generate_artifacts(
training_model, eval_model = training_block.to_model_proto()
model_params = training_block.parameters()
def _export_to_ort_format(model_path, output_dir, extra_options):
    """Optionally convert a saved ONNX model to ORT format.

    Args:
        model_path: Path of the already-saved ``.onnx`` model to convert.
        output_dir: Directory handed to the converter as its output location.
        extra_options: Mapping of optional settings; conversion only happens
            when its ``"ort_format"`` entry is present and truthy.

    Returns:
        None. When conversion is not requested this is a no-op.
    """
    ort_requested = extra_options.get("ort_format", False)
    if not ort_requested:
        return

    # Only the "Fixed" optimization style is requested for the generated artifacts.
    styles = [OptimizationStyle.Fixed]
    convert_onnx_models_to_ort(model_path, output_dir=output_dir, optimization_styles=styles)
if artifact_directory is None:
artifact_directory = pathlib.Path.cwd()
prefix = ""
@ -137,12 +143,14 @@ def generate_artifacts(
if os.path.exists(training_model_path):
logging.info("Training model path %s already exists. Overwriting.", training_model_path)
onnx.save(training_model, training_model_path)
_export_to_ort_format(training_model_path, artifact_directory, extra_options)
logging.info("Saved training model to %s", training_model_path)
eval_model_path = artifact_directory / f"{prefix}eval_model.onnx"
if os.path.exists(eval_model_path):
logging.info("Eval model path %s already exists. Overwriting.", eval_model_path)
onnx.save(eval_model, eval_model_path)
_export_to_ort_format(eval_model_path, artifact_directory, extra_options)
logging.info("Saved eval model to %s", eval_model_path)
checkpoint_path = artifact_directory / f"{prefix}checkpoint"
@ -173,4 +181,5 @@ def generate_artifacts(
optimizer_model_path = artifact_directory / f"{prefix}optimizer_model.onnx"
onnx.save(optim_model, optimizer_model_path)
_export_to_ort_format(optimizer_model_path, artifact_directory, extra_options)
logging.info("Saved optimizer model to %s", optimizer_model_path)

View file

@ -915,3 +915,26 @@ def test_label_encoder_composition():
all_nodes = [node.op_type for node in model.graph.node]
assert "LabelEncoder" in all_nodes
def test_save_ort_format():
    """Generating artifacts with ``ort_format=True`` writes an ``.ort`` file beside each ``.onnx`` file."""
    # Positional arguments: device, batch_size, input_size, hidden_size, output_size.
    _, base_model = _get_models("cpu", 64, 784, 500, 10)

    trainable_params = ["fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias"]
    artifact_stems = ("training_model", "eval_model", "optimizer_model")
    expected_suffixes = (".onnx", ".ort")

    with tempfile.TemporaryDirectory() as temp_dir:
        artifacts.generate_artifacts(
            base_model,
            requires_grad=trainable_params,
            loss=artifacts.LossType.CrossEntropyLoss,
            optimizer=artifacts.OptimType.AdamW,
            artifact_directory=temp_dir,
            ort_format=True,
        )

        # Every artifact must exist in both its ONNX form and its converted ORT form.
        for stem in artifact_stems:
            for suffix in expected_suffixes:
                assert os.path.exists(os.path.join(temp_dir, stem + suffix))

View file

@ -5,7 +5,17 @@
# This script is a stub that uses the model conversion script from the util subdirectory.
# We do it this way so we can use relative imports in that script, which makes it easy to include
# in the ORT python package (where it must use relative imports)
from util.convert_onnx_models_to_ort import convert_onnx_models_to_ort
from util.convert_onnx_models_to_ort import convert_onnx_models_to_ort, parse_args
if __name__ == "__main__":
convert_onnx_models_to_ort()
args = parse_args()
convert_onnx_models_to_ort(
args.model_path_or_dir,
output_dir=args.output_dir,
optimization_styles=args.optimization_style,
custom_op_library_path=args.custom_op_library,
target_platform=args.target_platform,
save_optimized_onnx_model=args.save_optimized_onnx_model,
allow_conversion_failures=args.allow_conversion_failures,
enable_type_reduction=args.enable_type_reduction,
)

View file

@ -2,13 +2,14 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
from __future__ import annotations
import argparse
import contextlib
import enum
import os
import pathlib
import tempfile
import typing
import onnxruntime as ort
@ -32,7 +33,7 @@ def _optimization_suffix(optimization_level_str: str, optimization_style: Optimi
def _create_config_file_path(
model_path_or_dir: pathlib.Path,
output_dir: typing.Optional[pathlib.Path],
output_dir: pathlib.Path | None,
optimization_level_str: str,
optimization_style: OptimizationStyle,
enable_type_reduction: bool,
@ -57,7 +58,7 @@ def _create_session_options(
optimization_level: ort.GraphOptimizationLevel,
output_model_path: pathlib.Path,
custom_op_library: pathlib.Path,
session_options_config_entries: typing.Dict[str, str],
session_options_config_entries: dict[str, str],
):
so = ort.SessionOptions()
so.optimized_model_filepath = str(output_model_path)
@ -74,15 +75,15 @@ def _create_session_options(
def _convert(
model_path_or_dir: pathlib.Path,
output_dir: typing.Optional[pathlib.Path],
output_dir: pathlib.Path | None,
optimization_level_str: str,
optimization_style: OptimizationStyle,
custom_op_library: pathlib.Path,
create_optimized_onnx_model: bool,
allow_conversion_failures: bool,
target_platform: str,
session_options_config_entries: typing.Dict[str, str],
) -> typing.List[pathlib.Path]:
session_options_config_entries: dict[str, str],
) -> list[pathlib.Path]:
model_dir = model_path_or_dir if model_path_or_dir.is_dir() else model_path_or_dir.parent
output_dir = output_dir or model_dir
@ -258,24 +259,33 @@ def parse_args():
"processed.",
)
return parser.parse_args()
parsed_args = parser.parse_args()
parsed_args.optimization_style = [OptimizationStyle[style_str] for style_str in parsed_args.optimization_style]
return parsed_args
def convert_onnx_models_to_ort():
args = parse_args()
def convert_onnx_models_to_ort(
model_path_or_dir: pathlib.Path,
output_dir: pathlib.Path | None = None,
optimization_styles: list[OptimizationStyle] | None = None,
custom_op_library_path: pathlib.Path | None = None,
target_platform: str | None = None,
save_optimized_onnx_model: bool = False,
allow_conversion_failures: bool = False,
enable_type_reduction: bool = False,
):
if output_dir is not None:
if not output_dir.is_dir():
output_dir.mkdir(parents=True)
output_dir = output_dir.resolve(strict=True)
output_dir = None
if args.output_dir is not None:
if not args.output_dir.is_dir():
args.output_dir.mkdir(parents=True)
output_dir = args.output_dir.resolve(strict=True)
optimization_styles = optimization_styles or []
optimization_styles = [OptimizationStyle[style_str] for style_str in args.optimization_style]
# setting optimization level is not expected to be needed by typical users, but it can be set with this
# environment variable
optimization_level_str = os.getenv("ORT_CONVERT_ONNX_MODELS_TO_ORT_OPTIMIZATION_LEVEL", "all")
model_path_or_dir = args.model_path_or_dir.resolve()
custom_op_library = args.custom_op_library.resolve() if args.custom_op_library else None
model_path_or_dir = model_path_or_dir.resolve()
custom_op_library = custom_op_library_path.resolve() if custom_op_library_path else None
if not model_path_or_dir.is_dir() and not model_path_or_dir.is_file():
raise FileNotFoundError(f"Model path '{model_path_or_dir}' is not a file or directory.")
@ -285,7 +295,7 @@ def convert_onnx_models_to_ort():
session_options_config_entries = {}
if args.target_platform == "arm":
if target_platform is not None and target_platform == "arm":
session_options_config_entries["session.qdqisint8allowed"] = "1"
else:
session_options_config_entries["session.qdqisint8allowed"] = "0"
@ -303,9 +313,9 @@ def convert_onnx_models_to_ort():
optimization_level_str=optimization_level_str,
optimization_style=optimization_style,
custom_op_library=custom_op_library,
create_optimized_onnx_model=args.save_optimized_onnx_model,
allow_conversion_failures=args.allow_conversion_failures,
target_platform=args.target_platform,
create_optimized_onnx_model=save_optimized_onnx_model,
allow_conversion_failures=allow_conversion_failures,
target_platform=target_platform,
session_options_config_entries=session_options_config_entries,
)
@ -335,8 +345,8 @@ def convert_onnx_models_to_ort():
optimization_style=OptimizationStyle.Fixed,
custom_op_library=custom_op_library,
create_optimized_onnx_model=False, # not useful as they would be created in a temp directory
allow_conversion_failures=args.allow_conversion_failures,
target_platform=args.target_platform,
allow_conversion_failures=allow_conversion_failures,
target_platform=target_platform,
session_options_config_entries=session_options_config_entries_for_second_conversion,
)
@ -351,11 +361,21 @@ def convert_onnx_models_to_ort():
output_dir,
optimization_level_str,
optimization_style,
args.enable_type_reduction,
enable_type_reduction,
)
create_config_from_models(converted_models, config_file, args.enable_type_reduction)
create_config_from_models(converted_models, config_file, enable_type_reduction)
if __name__ == "__main__":
convert_onnx_models_to_ort()
args = parse_args()
convert_onnx_models_to_ort(
args.model_path_or_dir,
output_dir=args.output_dir,
optimization_styles=args.optimization_style,
custom_op_library_path=args.custom_op_library,
target_platform=args.target_platform,
save_optimized_onnx_model=args.save_optimized_onnx_model,
allow_conversion_failures=args.allow_conversion_failures,
enable_type_reduction=args.enable_type_reduction,
)