#!/usr/bin/env python
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
"""Generate a Markdown summary of the operator kernels registered in ONNX Runtime.

The script queries the ONNX Runtime Python bindings for every operator schema
and every registered kernel definition, then writes one Markdown table per
execution provider listing each operator's parameters, opset versions and
supported type constraints.
"""

import argparse
import io  # noqa: F401
import os
import pathlib
from collections import defaultdict
from typing import List, Optional

import onnxruntime.capi.onnxruntime_pybind11_state as rtpy

# Upper bound (2**31 - 1) at or above which a version range is rendered as open-ended ("N+").
_OPEN_ENDED_VERSION = 2147483647

# Domain name shown for operators whose schema/kernel reports an empty domain.
_DEFAULT_DOMAIN = "ai.onnx"


def format_version_range(v):
    """Render a ``(first, last)`` opset version pair for the table.

    Returns ``"N+"`` when the upper bound is at or above ``_OPEN_ENDED_VERSION``,
    ``"N"`` when both bounds are equal, and ``"[first, last]"`` otherwise.
    """
    first, last = v[0], v[1]
    if last >= _OPEN_ENDED_VERSION:
        return str(first) + "+"
    if first == last:
        return str(first)
    return "[" + str(first) + ", " + str(last) + "]"


def format_type_constraints(tc):
    """Join the type names in *tc* into a single comma-separated string."""
    return ", ".join(tc)


def format_param_strings(params):
    """Join alternative parameter signatures, sorted, for one table cell.

    Returns an empty string when *params* is ``None`` or empty. Alternatives
    are separated with HTML line breaks because a Markdown table cell cannot
    contain a literal newline.
    """
    if not params:
        return ""
    return "<br><br>or<br><br>".join(sorted(params))


def expand_providers(provider_filter: Optional[List[str]]):
    """Normalize user-supplied provider names for case-insensitive matching.

    Each name is lower-cased and ``"executionprovider"`` is appended unless the
    name already ends with it. Returns a set, which is empty when no filter
    was given.
    """
    providers = set()
    for provider in provider_filter or ():
        normalized = provider.lower()
        if not normalized.endswith("executionprovider"):
            normalized += "executionprovider"
        providers.add(normalized)
    return providers


def _collect_schema_params():
    """Map ``"<domain>.<op name>"`` to the set of its formatted parameter lists."""
    params_by_op = {}
    for schema in rtpy.get_all_operator_schema():
        domain = schema.domain or _DEFAULT_DOMAIN
        entries = [f"*in* {inp.name}:**{inp.typeStr}**" for inp in schema.inputs or ()]
        entries += [f"*out* {outp.name}:**{outp.typeStr}**" for outp in schema.outputs or ()]
        # Inputs and outputs share one cell, separated by HTML line breaks.
        params_by_op.setdefault(domain + "." + schema.name, set()).add("<br> ".join(entries))
    return params_by_op


def _index_kernels():
    """Group the registered kernel definitions as provider -> domain -> op name -> [kernel defs]."""
    index = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
    for op in rtpy.get_all_opkernel_def():
        index[op.provider][op.domain or _DEFAULT_DOMAIN][op.op_name].append(op)
    return index


def _write_provider_section(fout, provider, domainmap, paramdict):
    """Write the anchor, heading and operator table for a single execution provider."""
    # Explicit anchor so the "#<provider>" links in the table of contents resolve.
    fout.write(f'\n\n<a name="{provider.lower()}"/>\n\n')
    fout.write(f"## Operators implemented by {provider}\n\n")
    fout.write("| Op Name | Parameters | OpSet Version | Types Supported |\n")
    fout.write("|---------|------------|---------------|-----------------|\n")

    for domain, namemap in sorted(domainmap.items()):
        fout.write("|**Operator Domain:** *" + domain + "*||||\n")
        for name, ops in sorted(namemap.items()):
            # version range -> type parameter name -> set of allowed types
            version_type_index = defaultdict(lambda: defaultdict(set))
            for op in ops:
                for tname, tclist in op.type_constraints.items():
                    version_type_index[op.version_range][tname].update(tclist)

            # Only the first row of an operator carries its name and parameters;
            # the following rows (older version ranges) leave those cells empty.
            lead = "|" + name + "|" + format_param_strings(paramdict.get(domain + "." + name)) + "|"
            for version_range, typemap in sorted(
                version_type_index.items(), key=lambda item: item[0], reverse=True
            ):
                types_cell = "<br/> ".join(
                    "**" + tname + "** = " + format_type_constraints(sorted(tcset))
                    for tname, tcset in sorted(typemap.items())
                )
                fout.write(lead + format_version_range(version_range) + "|" + types_cell + "|\n")
                lead = "|||"

        # Spacer rows after each domain.
        fout.write("| |\n| |\n")


def main(output_path: pathlib.Path, provider_filter: Optional[List[str]]):
    """Write the operator kernel documentation to *output_path*.

    Args:
        output_path: Markdown file to create or overwrite.
        provider_filter: Optional provider names (case-insensitive, with or
            without the ``ExecutionProvider`` suffix). When empty or ``None``,
            every provider that has registered kernels is documented.
    """
    providers = expand_providers(provider_filter)

    # Gather everything from the runtime before touching the output file so a
    # failure while querying does not leave a truncated document behind.
    paramdict = _collect_schema_params()
    index = _index_kernels()
    selected = [
        (provider, domainmap)
        for provider, domainmap in sorted(index.items())
        if not providers or provider.lower() in providers
    ]

    with open(output_path, "w", newline="", encoding="utf-8") as fout:
        fout.write("## Supported Operators and Data Types\n")
        fout.write(
            "*This file is automatically generated from the registered kernels by "
            "[this script](https://github.com/microsoft/onnxruntime/blob/main/tools/python/gen_opkernel_doc.py).\n"
            "Do not modify directly.*\n\n"
        )

        # TOC
        fout.write("## Execution Providers\n\n")
        for provider, _ in selected:
            fout.write(f"- [{provider}](#{provider.lower()})\n")
        fout.write("\n---------------")

        for provider, domainmap in selected:
            _write_provider_section(fout, provider, domainmap, paramdict)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="ONNX Runtime Operator Kernel Documentation Generator")
    parser.add_argument(
        "--providers",
        nargs="+",
        help="Filter to specified execution providers. Case-insensitive. "
        "Matches provider names from /include/onnxruntime/core/graph/constants.h'. "
        "'ExecutionProvider' is automatically appended as needed. "
        "e.g. `--providers cpu cuda` will match CPUExecutionProvider and CUDAExecutionProvider.",
    )
    parser.add_argument(
        "--output_path",
        help="output markdown file path",
        type=pathlib.Path,
        # Optional: falls back to OperatorKernels.md next to this script.
        required=False,
        default=os.path.join(os.path.dirname(os.path.realpath(__file__)), "OperatorKernels.md"),
    )
    args = parser.parse_args()
    main(args.output_path, args.providers)