onnxruntime/tools/python/offline_tuning.py
cloudhan a216c9a3fa
Offline tuning (#14558)
Add the ability to get and set tuning results of an inference session.
Also add tool to manipulate onnx file to embed the results into the
model file and automatically load it on session initialization.
2023-02-15 14:17:34 +08:00

169 lines
6 KiB
Python

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import argparse
import copy
import json
import sys
from collections import OrderedDict
from pprint import pprint
from typing import Any, Dict, List
import onnx
# One tuning results record, as a JSON-like dict. Code in this file reads its
# "ep", "validators" and "results" entries.
TuningResults = Dict[str, Any]
# Key of the model metadata property whose value holds the JSON-serialized tuning results.
_TUNING_RESULTS_KEY = "tuning_results"
def _find_tuning_results_in_props(metadata_props):
    """Locate the tuning results entry among a model's metadata properties.

    Args:
        metadata_props: Iterable of entries exposing a ``key`` attribute.

    Returns:
        Index of the first entry whose key equals ``_TUNING_RESULTS_KEY``,
        or -1 when no such entry exists.
    """
    matching_positions = (pos for pos, entry in enumerate(metadata_props) if entry.key == _TUNING_RESULTS_KEY)
    return next(matching_positions, -1)
def extract(model: onnx.ModelProto):
    """Return the tuning results embedded in *model*, decoded from JSON.

    Returns:
        The decoded tuning results, or ``None`` when the model's metadata
        properties contain no tuning results entry.
    """
    props = model.metadata_props
    position = _find_tuning_results_in_props(props)
    return json.loads(props[position].value) if position >= 0 else None
def embed(model: onnx.ModelProto, tuning_results: List[TuningResults], overwrite=False):
    """Store *tuning_results* as JSON in the metadata properties of *model*.

    Args:
        model: The model to modify; it is changed in place.
        tuning_results: Tuning results to serialize with ``json.dumps``.
        overwrite: When true, an existing tuning results entry is replaced.

    Returns:
        The same *model* object, now carrying the tuning results entry.

    Raises:
        AssertionError: If the model already has tuning results embedded and
            *overwrite* is false.
    """
    idx = _find_tuning_results_in_props(model.metadata_props)
    if idx >= 0:
        # Index 0 is a valid position for an existing entry, so "already embedded"
        # means idx >= 0. Raised explicitly (rather than via ``assert``) so the
        # check is not stripped under ``python -O``; the exception type is kept.
        if not overwrite:
            raise AssertionError("the supplied onnx file already have tuning results embedded!")
        model.metadata_props.pop(idx)

    entry = model.metadata_props.add()
    entry.key = _TUNING_RESULTS_KEY
    entry.value = json.dumps(tuning_results)
    return model
class Merger:
    """Merge several lists of tuning results into a single de-duplicated list.

    Records are grouped by their ``"ep"`` value together with their
    ``"validators"``. Within one group, the first kernel id seen for an
    (op signature, params signature) pair is kept; later ones are ignored.
    Groups are reported in the order in which they were first merged.
    """

    class EpAndValidators:
        """Hashable grouping key built from an ep name and its validators."""

        def __init__(self, ep: str, validators: Dict[str, str]):
            self.ep = ep
            # Keep a private copy so later changes to the caller's dict do not leak in.
            self.validators = copy.deepcopy(validators)
            # Sorting makes the key independent of the validators' insertion order.
            self.key = (ep, tuple(sorted(validators.items())))

        def __hash__(self):
            return hash(self.key)

        def __eq__(self, other):
            # Let Python fall back to its default comparison for unrelated types
            # instead of failing with AttributeError.
            if not isinstance(other, Merger.EpAndValidators):
                return NotImplemented
            # ``key`` already contains ``ep``, so comparing keys is sufficient.
            return self.key == other.key

    def __init__(self):
        # Maps EpAndValidators -> {(op_sig, params_sig): kernel_id}.
        self.ev_to_results = OrderedDict()

    def merge(self, tuning_results: "List[TuningResults]"):
        """Merge every record of *tuning_results* into the accumulated state."""
        for trs in tuning_results:
            self._merge_one(trs)

    def get_merged(self):
        """Return the merged state as a list of tuning results records.

        Each record has the keys ``"ep"``, ``"validators"`` and ``"results"``,
        where ``"results"`` maps op signature -> params signature -> kernel id.
        """
        tuning_results = []
        for ev, flat_results in self.ev_to_results.items():
            results = {}
            for (op_sig, params_sig), kernel_id in flat_results.items():
                results.setdefault(op_sig, {})[params_sig] = kernel_id
            tuning_results.append(
                {
                    "ep": ev.ep,
                    "validators": ev.validators,
                    "results": results,
                }
            )
        return tuning_results

    def _merge_one(self, trs: "TuningResults"):
        """Fold one record into the state; already-known entries are left untouched."""
        ev = Merger.EpAndValidators(trs["ep"], trs["validators"])
        flat_results = self.ev_to_results.setdefault(ev, {})
        for op_sig, kernel_map in trs["results"].items():
            for params_sig, kernel_id in kernel_map.items():
                # First writer wins: setdefault only stores when the pair is new.
                flat_results.setdefault((op_sig, params_sig), kernel_id)
def parse_args(argv=None):
    """Parse the command line of the offline tuning tool.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default),
            argparse reads the arguments from ``sys.argv``.

    Returns:
        The parsed ``argparse.Namespace``; its ``cmd`` attribute names the
        selected sub-command ("extract", "embed", "merge" or "pprint").

    Raises:
        SystemExit: With code -1, after printing the help text, when no
            sub-command is given. argparse itself also exits on invalid input.
    """
    parser = argparse.ArgumentParser()
    sub_parsers = parser.add_subparsers(help="Command to execute", dest="cmd")

    extract_parser = sub_parsers.add_parser("extract", help="Extract embedded tuning results from an onnx file.")
    extract_parser.add_argument("input_onnx")
    extract_parser.add_argument("output_json")

    embed_parser = sub_parsers.add_parser("embed", help="Embed the tuning results into an onnx file.")
    embed_parser.add_argument("--force", "-f", action="store_true", help="Overwrite the tuning results if it existed.")
    embed_parser.add_argument("output_onnx", help="Path of the output onnx file.")
    embed_parser.add_argument("input_onnx", help="Path of the input onnx file.")
    embed_parser.add_argument("input_json", nargs="+", help="Path(s) of the tuning results file(s) to be embedded.")

    merge_parser = sub_parsers.add_parser("merge", help="Merge multiple tuning results files as a single one.")
    merge_parser.add_argument("output_json", help="Path of the output tuning results file.")
    merge_parser.add_argument("input_json", nargs="+", help="Paths of the tuning results files to be merged.")

    pprint_parser = sub_parsers.add_parser("pprint", help="Pretty print the tuning results.")
    pprint_parser.add_argument("json_or_onnx", help="A tuning results json file or an onnx file.")

    args = parser.parse_args(argv)
    # ``dest="cmd"`` always creates the attribute, so the namespace is never
    # empty; a missing sub-command shows up as ``cmd`` being None.
    if args.cmd is None:
        parser.print_help()
        sys.exit(-1)
    return args
def _merge_tuning_results_files(paths):
    """Load every tuning results JSON file in *paths* and return their merged content."""
    merger = Merger()
    for path in paths:
        with open(path) as f:
            merger.merge(json.load(f))
    return merger.get_merged()


def main():
    """Run the sub-command selected on the command line.

    * ``extract``: write the tuning results embedded in an onnx file to a JSON file.
    * ``embed``: merge one or more JSON files and embed the result into an onnx file.
    * ``merge``: merge one or more JSON files into a single JSON file.
    * ``pprint``: pretty print the tuning results of a JSON file or an onnx file.

    Exits with code -1 after writing a message to stderr when the input holds
    no tuning results or cannot be read as either supported format.
    """
    args = parse_args()
    if args.cmd == "extract":
        tuning_results = extract(onnx.load_model(args.input_onnx))
        if tuning_results is None:
            sys.stderr.write(f"{args.input_onnx} does not have tuning results embedded!\n")
            sys.exit(-1)
        # The context manager guarantees the output is flushed and closed.
        with open(args.output_json, "w") as f:
            json.dump(tuning_results, f)
    elif args.cmd == "embed":
        model = onnx.load_model(args.input_onnx)
        model = embed(model, _merge_tuning_results_files(args.input_json), args.force)
        onnx.save_model(model, args.output_onnx)
    elif args.cmd == "merge":
        merged = _merge_tuning_results_files(args.input_json)
        with open(args.output_json, "w") as f:
            json.dump(merged, f)
    elif args.cmd == "pprint":
        tuning_results = None
        try:
            with open(args.json_or_onnx) as f:
                tuning_results = json.load(f)
        except (OSError, ValueError):
            # Unreadable, or not JSON (JSONDecodeError and UnicodeDecodeError are
            # ValueErrors): it might be an onnx file instead, try that later.
            pass
        if tuning_results is None:
            model_loaded = False
            try:
                model = onnx.load_model(args.json_or_onnx)
                tuning_results = extract(model)
                model_loaded = True
            except Exception:
                # Best effort: any failure here means the file is neither valid
                # JSON nor a usable onnx model; reported just below.
                pass
            if not model_loaded:
                sys.stderr.write(f"{args.json_or_onnx} is not a valid tuning results file or onnx file!\n")
                sys.exit(-1)
            if tuning_results is None:
                # The pprint sub-command only defines ``json_or_onnx``.
                sys.stderr.write(f"{args.json_or_onnx} does not have tuning results embedded!\n")
                sys.exit(-1)
        pprint(tuning_results)
    else:
        # invalid choice will be handled by the parser
        pass
# Run the command line tool only when this file is executed as a script.
if __name__ == "__main__":
    main()