mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-17 21:10:43 +00:00
Add the ability to get and set the tuning results of an inference session. Also add a tool that manipulates an onnx file to embed the results into the model file, so that they are loaded automatically on session initialization.
169 lines
6 KiB
Python
169 lines
6 KiB
Python
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
# Licensed under the MIT License.
|
|
|
|
import argparse
|
|
import copy
|
|
import json
|
|
import sys
|
|
from collections import OrderedDict
|
|
from pprint import pprint
|
|
from typing import Any, Dict, List
|
|
|
|
import onnx
|
|
|
|
# One tuning-results record in its JSON-decoded form. The code in this file
# reads its "ep", "validators" and "results" entries.
TuningResults = Dict[str, Any]

# Key of the model metadata property that holds the JSON-encoded tuning results.
_TUNING_RESULTS_KEY = "tuning_results"
|
|
|
|
|
|
def _find_tuning_results_in_props(metadata_props):
    """Return the index of the tuning-results entry in ``metadata_props``.

    Entries are scanned in order and matched on their ``key`` attribute
    against ``_TUNING_RESULTS_KEY``; the first match wins. ``-1`` is returned
    when no entry matches.
    """
    matching_positions = (
        position for position, entry in enumerate(metadata_props) if entry.key == _TUNING_RESULTS_KEY
    )
    return next(matching_positions, -1)
|
|
|
|
|
|
def extract(model: onnx.ModelProto):
    """Return the tuning results embedded in ``model``, decoded from JSON.

    ``None`` is returned when the model's metadata properties contain no
    tuning-results entry.
    """
    position = _find_tuning_results_in_props(model.metadata_props)
    if position >= 0:
        embedded_json = model.metadata_props[position].value
        return json.loads(embedded_json)
    return None
|
|
|
|
|
|
def embed(model: onnx.ModelProto, tuning_results: List[TuningResults], overwrite=False):
    """Store ``tuning_results`` as JSON in the metadata properties of ``model``.

    Args:
        model: The model to modify; it is updated in place.
        tuning_results: The tuning results, serialized with ``json.dumps``.
        overwrite: When true, an already embedded tuning-results entry is
            removed and replaced by the new one.

    Returns:
        The same ``model`` object, with a tuning-results entry appended to
        its metadata properties.

    Raises:
        AssertionError: If the model already has tuning results embedded and
            ``overwrite`` is false.
    """
    idx = _find_tuning_results_in_props(model.metadata_props)
    already_embedded = idx >= 0

    # The lookup returns -1 only when nothing was found. Index 0 is a real hit
    # (the entry is the first metadata property) and must be refused as well.
    # Raised explicitly instead of via ``assert`` so the check is not stripped
    # when Python runs with -O.
    if already_embedded and not overwrite:
        raise AssertionError("the supplied onnx file already have tuning results embedded!")

    if already_embedded:
        model.metadata_props.pop(idx)

    entry = model.metadata_props.add()
    entry.key = _TUNING_RESULTS_KEY
    entry.value = json.dumps(tuning_results)
    return model
|
|
|
|
|
|
class Merger:
    """Accumulate tuning results and merge those sharing ``ep`` and validators.

    Records are grouped by their ``ep`` value together with their validators.
    Within a group, the first kernel id seen for an (op signature, params
    signature) pair is kept; later values for the same pair are ignored.
    """

    class EpAndValidators:
        """Hashable grouping key made of an ``ep`` string and its validators."""

        def __init__(self, ep: str, validators: Dict[str, str]):
            self.ep = ep
            # Private copy, so later changes to the caller's dict are not seen here.
            self.validators = copy.deepcopy(validators)
            # Sorting the items makes the key independent of dict insertion order.
            ordered_validators = tuple(sorted(validators.items()))
            self.key = (ep, ordered_validators)

        def __hash__(self):
            return hash(self.key)

        def __eq__(self, other):
            if self.ep != other.ep:
                return False
            return self.key == other.key

    def __init__(self):
        # Maps EpAndValidators -> {(op_sig, params_sig): kernel_id}, in first-seen order.
        self.ev_to_results = OrderedDict()

    def merge(self, tuning_results: List[TuningResults]):
        """Merge every record of ``tuning_results`` into this merger."""
        for record in tuning_results:
            self._merge_one(record)

    def get_merged(self):
        """Return the merged records as a list of ``ep``/``validators``/``results`` dicts."""
        merged = []
        for group, flat_results in self.ev_to_results.items():
            # Rebuild the nested {op_sig: {params_sig: kernel_id}} layout.
            nested = {}
            for (op_sig, params_sig), kernel_id in flat_results.items():
                if op_sig not in nested:
                    nested[op_sig] = {}
                nested[op_sig][params_sig] = kernel_id
            merged.append({"ep": group.ep, "validators": group.validators, "results": nested})
        return merged

    def _merge_one(self, trs: TuningResults):
        """Fold one record into its group without overriding existing entries."""
        group = Merger.EpAndValidators(trs["ep"], trs["validators"])
        if group not in self.ev_to_results:
            self.ev_to_results[group] = {}
        known = self.ev_to_results[group]
        for op_sig, kernel_map in trs["results"].items():
            for params_sig, kernel_id in kernel_map.items():
                # First writer wins: an existing entry is left untouched.
                known.setdefault((op_sig, params_sig), kernel_id)
|
|
|
|
|
|
def parse_args():
    """Parse the command line into an ``argparse.Namespace``.

    The selected sub-command (``extract``, ``embed``, ``merge`` or ``pprint``)
    is stored in ``cmd``; the remaining attributes depend on the sub-command.

    Returns:
        The parsed arguments.

    Raises:
        SystemExit: With code -1, after printing the help text, when no
            sub-command is given. argparse itself also exits on invalid
            arguments.
    """
    parser = argparse.ArgumentParser()
    sub_parsers = parser.add_subparsers(help="Command to execute", dest="cmd")

    extract_parser = sub_parsers.add_parser("extract", help="Extract embedded tuning results from an onnx file.")
    extract_parser.add_argument("input_onnx")
    extract_parser.add_argument("output_json")

    embed_parser = sub_parsers.add_parser("embed", help="Embed the tuning results into an onnx file.")
    embed_parser.add_argument("--force", "-f", action="store_true", help="Overwrite the tuning results if it existed.")
    embed_parser.add_argument("output_onnx", help="Path of the output onnx file.")
    embed_parser.add_argument("input_onnx", help="Path of the input onnx file.")
    embed_parser.add_argument("input_json", nargs="+", help="Path(s) of the tuning results file(s) to be embedded.")

    merge_parser = sub_parsers.add_parser("merge", help="Merge multiple tuning results files as a single one.")
    merge_parser.add_argument("output_json", help="Path of the output tuning results file.")
    merge_parser.add_argument("input_json", nargs="+", help="Paths of the tuning results files to be merged.")

    pprint_parser = sub_parsers.add_parser("pprint", help="Pretty print the tuning results.")
    pprint_parser.add_argument("json_or_onnx", help="A tuning results json file or an onnx file.")

    args = parser.parse_args()
    # With dest="cmd" the namespace always contains "cmd", so it is never
    # empty; a missing sub-command shows up as cmd being None.
    if args.cmd is None:
        parser.print_help()
        sys.exit(-1)
    return args
|
|
|
|
|
|
def _read_json(path):
    """Load and return the JSON document stored in the file at ``path``."""
    with open(path) as f:
        return json.load(f)


def _write_json(obj, path):
    """Serialize ``obj`` as JSON into the file at ``path``."""
    with open(path, "w") as f:
        json.dump(obj, f)


def _merge_json_files(paths):
    """Merge the tuning results stored in each JSON file of ``paths``."""
    merger = Merger()
    for path in paths:
        merger.merge(_read_json(path))
    return merger.get_merged()


def main():
    """Run the sub-command selected on the command line.

    * ``extract``: write the tuning results embedded in an onnx file to a
      JSON file.
    * ``embed``: merge one or more tuning results files and embed the result
      into an onnx file.
    * ``merge``: merge several tuning results files into a single one.
    * ``pprint``: pretty print the tuning results of a JSON or an onnx file.

    Exits with code -1 when the requested tuning results cannot be found or
    the input cannot be read.
    """
    args = parse_args()
    if args.cmd == "extract":
        tuning_results = extract(onnx.load_model(args.input_onnx))
        if tuning_results is None:
            sys.stderr.write(f"{args.input_onnx} does not have tuning results embedded!\n")
            sys.exit(-1)
        _write_json(tuning_results, args.output_json)
    elif args.cmd == "embed":
        model = onnx.load_model(args.input_onnx)
        model = embed(model, _merge_json_files(args.input_json), args.force)
        onnx.save_model(model, args.output_onnx)
    elif args.cmd == "merge":
        _write_json(_merge_json_files(args.input_json), args.output_json)
    elif args.cmd == "pprint":
        tuning_results = None
        try:
            tuning_results = _read_json(args.json_or_onnx)
        except (OSError, ValueError):
            # Not readable as JSON (json.JSONDecodeError and UnicodeDecodeError
            # are both ValueError subclasses); it might be an onnx file
            # instead, try that later.
            pass

        if tuning_results is None:
            model_loaded = False
            try:
                tuning_results = extract(onnx.load_model(args.json_or_onnx))
                model_loaded = True
            except Exception:
                # Best effort: any failure to load or decode is reported
                # below as an invalid input file.
                pass

            # Checked outside the try block so that neither the message nor
            # the exit can be swallowed by the broad handler above.
            if model_loaded and tuning_results is None:
                sys.stderr.write(f"{args.json_or_onnx} does not have tuning results embedded!\n")
                sys.exit(-1)

        if tuning_results is None:
            sys.stderr.write(f"{args.json_or_onnx} is not a valid tuning results file or onnx file!\n")
            sys.exit(-1)

        pprint(tuning_results)
    else:
        # invalid choice will be handled by the parser
        pass
|
|
|
|
|
|
# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|