From 42489a8a241e0cefc0efaaf07d192b7f840cc5f9 Mon Sep 17 00:00:00 2001 From: Baiju Meswani Date: Wed, 21 Jun 2023 18:51:43 -0700 Subject: [PATCH] Add ability to create ort format models from training offline utility (#16360) --- .../orttraining/python/training/artifacts.py | 13 +++- .../test/python/orttraining_test_onnxblock.py | 23 ++++++ tools/python/convert_onnx_models_to_ort.py | 14 +++- .../python/util/convert_onnx_models_to_ort.py | 72 ++++++++++++------- 4 files changed, 92 insertions(+), 30 deletions(-) diff --git a/orttraining/orttraining/python/training/artifacts.py b/orttraining/orttraining/python/training/artifacts.py index 60d5d92e15..bb584321eb 100644 --- a/orttraining/orttraining/python/training/artifacts.py +++ b/orttraining/orttraining/python/training/artifacts.py @@ -9,7 +9,8 @@ from typing import List, Optional, Union import onnx -import onnxruntime.training.onnxblock as onnxblock +from onnxruntime.tools.convert_onnx_models_to_ort import OptimizationStyle, convert_onnx_models_to_ort +from onnxruntime.training import onnxblock class LossType(Enum): @@ -58,7 +59,8 @@ def generate_artifacts( optimizer: The optimizer enum to be used for training. If None, no optimizer model is generated. artifact_directory: The directory to save the generated artifacts. If None, the current working directory is used. - prefix: The prefix to be used for the generated artifacts. If not specified, no prefix is used. + prefix (str): The prefix to be used for the generated artifacts. If not specified, no prefix is used. + ort_format (bool): Whether to save the generated artifacts in ORT format or not. Default is False. Raises: RuntimeError: If the loss provided is neither one of the supported losses nor an instance of `onnxblock.Block` @@ -124,6 +126,10 @@ def generate_artifacts( training_model, eval_model = training_block.to_model_proto() model_params = training_block.parameters() + def _export_to_ort_format(model_path, output_dir, extra_options): + if extra_options.get("ort_format", False): + convert_onnx_models_to_ort(model_path, output_dir=output_dir, optimization_styles=[OptimizationStyle.Fixed]) + if artifact_directory is None: artifact_directory = pathlib.Path.cwd() prefix = "" @@ -137,12 +143,14 @@ def generate_artifacts( if os.path.exists(training_model_path): logging.info("Training model path %s already exists. Overwriting.", training_model_path) onnx.save(training_model, training_model_path) + _export_to_ort_format(training_model_path, artifact_directory, extra_options) logging.info("Saved training model to %s", training_model_path) eval_model_path = artifact_directory / f"{prefix}eval_model.onnx" if os.path.exists(eval_model_path): logging.info("Eval model path %s already exists. Overwriting.", eval_model_path) onnx.save(eval_model, eval_model_path) + _export_to_ort_format(eval_model_path, artifact_directory, extra_options) logging.info("Saved eval model to %s", eval_model_path) checkpoint_path = artifact_directory / f"{prefix}checkpoint" @@ -173,4 +181,5 @@ def generate_artifacts( optimizer_model_path = artifact_directory / f"{prefix}optimizer_model.onnx" onnx.save(optim_model, optimizer_model_path) + _export_to_ort_format(optimizer_model_path, artifact_directory, extra_options) logging.info("Saved optimizer model to %s", optimizer_model_path) diff --git a/orttraining/orttraining/test/python/orttraining_test_onnxblock.py b/orttraining/orttraining/test/python/orttraining_test_onnxblock.py index ce919f4706..2d19474755 100644 --- a/orttraining/orttraining/test/python/orttraining_test_onnxblock.py +++ b/orttraining/orttraining/test/python/orttraining_test_onnxblock.py @@ -915,3 +915,26 @@ def test_label_encoder_composition(): all_nodes = [node.op_type for node in model.graph.node] assert "LabelEncoder" in all_nodes + + +def test_save_ort_format(): + device = "cpu" + batch_size, input_size, hidden_size, output_size = 64, 784, 500, 10 + _, base_model = _get_models(device, batch_size, input_size, hidden_size, output_size) + + with tempfile.TemporaryDirectory() as temp_dir: + artifacts.generate_artifacts( + base_model, + requires_grad=["fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias"], + loss=artifacts.LossType.CrossEntropyLoss, + optimizer=artifacts.OptimType.AdamW, + artifact_directory=temp_dir, + ort_format=True, + ) + + assert os.path.exists(os.path.join(temp_dir, "training_model.onnx")) + assert os.path.exists(os.path.join(temp_dir, "training_model.ort")) + assert os.path.exists(os.path.join(temp_dir, "eval_model.onnx")) + assert os.path.exists(os.path.join(temp_dir, "eval_model.ort")) + assert os.path.exists(os.path.join(temp_dir, "optimizer_model.onnx")) + assert os.path.exists(os.path.join(temp_dir, "optimizer_model.ort")) diff --git a/tools/python/convert_onnx_models_to_ort.py b/tools/python/convert_onnx_models_to_ort.py index 131cbcbe70..ac137feb0c 100644 --- a/tools/python/convert_onnx_models_to_ort.py +++ b/tools/python/convert_onnx_models_to_ort.py @@ -5,7 +5,17 @@ # This script is a stub that uses the model conversion script from the util subdirectory. # We do it this way so we can use relative imports in that script, which makes it easy to include # in the ORT python package (where it must use relative imports) -from util.convert_onnx_models_to_ort import convert_onnx_models_to_ort +from util.convert_onnx_models_to_ort import convert_onnx_models_to_ort, parse_args if __name__ == "__main__": - convert_onnx_models_to_ort() + args = parse_args() + convert_onnx_models_to_ort( + args.model_path_or_dir, + output_dir=args.output_dir, + optimization_styles=args.optimization_style, + custom_op_library_path=args.custom_op_library, + target_platform=args.target_platform, + save_optimized_onnx_model=args.save_optimized_onnx_model, + allow_conversion_failures=args.allow_conversion_failures, + enable_type_reduction=args.enable_type_reduction, + ) diff --git a/tools/python/util/convert_onnx_models_to_ort.py b/tools/python/util/convert_onnx_models_to_ort.py index 2db974643c..18bba78661 100644 --- a/tools/python/util/convert_onnx_models_to_ort.py +++ b/tools/python/util/convert_onnx_models_to_ort.py @@ -2,13 +2,14 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. +from __future__ import annotations + import argparse import contextlib import enum import os import pathlib import tempfile -import typing import onnxruntime as ort @@ -32,7 +33,7 @@ def _optimization_suffix(optimization_level_str: str, optimization_style: Optimi def _create_config_file_path( model_path_or_dir: pathlib.Path, - output_dir: typing.Optional[pathlib.Path], + output_dir: pathlib.Path | None, optimization_level_str: str, optimization_style: OptimizationStyle, enable_type_reduction: bool, @@ -57,7 +58,7 @@ def _create_session_options( optimization_level: ort.GraphOptimizationLevel, output_model_path: pathlib.Path, custom_op_library: pathlib.Path, - session_options_config_entries: typing.Dict[str, str], + session_options_config_entries: dict[str, str], ): so = ort.SessionOptions() so.optimized_model_filepath = str(output_model_path) @@ -74,15 +75,15 @@ def _create_session_options( def _convert( model_path_or_dir: pathlib.Path, - output_dir: typing.Optional[pathlib.Path], + output_dir: pathlib.Path | None, optimization_level_str: str, optimization_style: OptimizationStyle, custom_op_library: pathlib.Path, create_optimized_onnx_model: bool, allow_conversion_failures: bool, target_platform: str, - session_options_config_entries: typing.Dict[str, str], -) -> typing.List[pathlib.Path]: + session_options_config_entries: dict[str, str], +) -> list[pathlib.Path]: model_dir = model_path_or_dir if model_path_or_dir.is_dir() else model_path_or_dir.parent output_dir = output_dir or model_dir @@ -258,24 +259,33 @@ def parse_args(): "processed.", ) - return parser.parse_args() + parsed_args = parser.parse_args() + parsed_args.optimization_style = [OptimizationStyle[style_str] for style_str in parsed_args.optimization_style] + return parsed_args -def convert_onnx_models_to_ort(): - args = parse_args() +def convert_onnx_models_to_ort( + model_path_or_dir: pathlib.Path, + output_dir: pathlib.Path | None = None, + optimization_styles: list[OptimizationStyle] | None = None, + custom_op_library_path: pathlib.Path | None = None, + target_platform: str | None = None, + save_optimized_onnx_model: bool = False, + allow_conversion_failures: bool = False, + enable_type_reduction: bool = False, +): + if output_dir is not None: + if not output_dir.is_dir(): + output_dir.mkdir(parents=True) + output_dir = output_dir.resolve(strict=True) - output_dir = None - if args.output_dir is not None: - if not args.output_dir.is_dir(): - args.output_dir.mkdir(parents=True) - output_dir = args.output_dir.resolve(strict=True) + optimization_styles = optimization_styles or [] - optimization_styles = [OptimizationStyle[style_str] for style_str in args.optimization_style] # setting optimization level is not expected to be needed by typical users, but it can be set with this # environment variable optimization_level_str = os.getenv("ORT_CONVERT_ONNX_MODELS_TO_ORT_OPTIMIZATION_LEVEL", "all") - model_path_or_dir = args.model_path_or_dir.resolve() - custom_op_library = args.custom_op_library.resolve() if args.custom_op_library else None + model_path_or_dir = model_path_or_dir.resolve() + custom_op_library = custom_op_library_path.resolve() if custom_op_library_path else None if not model_path_or_dir.is_dir() and not model_path_or_dir.is_file(): raise FileNotFoundError(f"Model path '{model_path_or_dir}' is not a file or directory.") @@ -285,7 +295,7 @@ def convert_onnx_models_to_ort(): session_options_config_entries = {} - if args.target_platform == "arm": + if target_platform is not None and target_platform == "arm": session_options_config_entries["session.qdqisint8allowed"] = "1" else: session_options_config_entries["session.qdqisint8allowed"] = "0" @@ -303,9 +313,9 @@ def convert_onnx_models_to_ort(): optimization_level_str=optimization_level_str, optimization_style=optimization_style, custom_op_library=custom_op_library, - create_optimized_onnx_model=args.save_optimized_onnx_model, - allow_conversion_failures=args.allow_conversion_failures, - target_platform=args.target_platform, + create_optimized_onnx_model=save_optimized_onnx_model, + allow_conversion_failures=allow_conversion_failures, + target_platform=target_platform, session_options_config_entries=session_options_config_entries, ) @@ -335,8 +345,8 @@ def convert_onnx_models_to_ort(): optimization_style=OptimizationStyle.Fixed, custom_op_library=custom_op_library, create_optimized_onnx_model=False, # not useful as they would be created in a temp directory - allow_conversion_failures=args.allow_conversion_failures, - target_platform=args.target_platform, + allow_conversion_failures=allow_conversion_failures, + target_platform=target_platform, session_options_config_entries=session_options_config_entries_for_second_conversion, ) @@ -351,11 +361,21 @@ def convert_onnx_models_to_ort(): output_dir, optimization_level_str, optimization_style, - args.enable_type_reduction, + enable_type_reduction, ) - create_config_from_models(converted_models, config_file, args.enable_type_reduction) + create_config_from_models(converted_models, config_file, enable_type_reduction) if __name__ == "__main__": - convert_onnx_models_to_ort() + args = parse_args() + convert_onnx_models_to_ort( + args.model_path_or_dir, + output_dir=args.output_dir, + optimization_styles=args.optimization_style, + custom_op_library_path=args.custom_op_library, + target_platform=args.target_platform, + save_optimized_onnx_model=args.save_optimized_onnx_model, + allow_conversion_failures=args.allow_conversion_failures, + enable_type_reduction=args.enable_type_reduction, + )