#!/usr/bin/env python
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Generate a markdown table of the operator kernels registered in ONNX Runtime.

For every execution provider reported by the runtime, the output lists each
kernel grouped by operator domain, together with the operator's schema
parameters, the opset version ranges the kernel is registered for, and the
types allowed for each type constraint.
"""

from collections import defaultdict
import io
import os
import argparse
from typing import Type

import onnxruntime.capi.onnxruntime_pybind11_state as rtpy

# An end version at or above this value (2**31 - 1) is rendered as open-ended ("N+").
_OPEN_ENDED_VERSION = 2147483647

# ONNX identifies its default operator set with the empty domain string; that
# set is 'ai.onnx'. 'ai.onnx.ml' is a separate domain and must not be merged in.
_DEFAULT_DOMAIN = 'ai.onnx'

# Separator between type-constraint entries inside a single markdown table
# cell; a literal newline would terminate the table row.
_CELL_LINE_BREAK = '<br/> '


def _normalize_domain(domain):
    """Return *domain*, replacing the empty default domain with 'ai.onnx'."""
    return domain or _DEFAULT_DOMAIN


def format_version_range(v):
    """Format an opset version range for display.

    Args:
        v: indexable ``(start, end)`` pair of opset versions.

    Returns:
        ``'start+'`` when ``end`` is open-ended, ``'start'`` when the range
        covers a single version, otherwise ``'[start, end]'``.
    """
    start, end = v[0], v[1]
    if end >= _OPEN_ENDED_VERSION:
        return str(start) + '+'
    if start == end:
        return str(start)
    return '[' + str(start) + ', ' + str(end) + ']'


def format_type_constraints(tc):
    """Return the type names in *tc* joined by ', ' (order preserved)."""
    return ', '.join(tc)


def format_param_strings(params):
    """Return the parameter signatures in *params*, sorted and joined by ' or '.

    Returns an empty string when *params* is None or empty.
    """
    if not params:
        return ''
    return ' or '.join(sorted(params))


def _format_schema_params(schema):
    """Return '(*in* name:**type**, ..., *out* name:**type**)' for *schema*."""
    parts = ['*in* {}:**{}**'.format(inp.name, inp.typeStr)
             for inp in (schema.inputs or ())]
    parts += ['*out* {}:**{}**'.format(outp.name, outp.typeStr)
              for outp in (schema.outputs or ())]
    return '(' + ', '.join(parts) + ')'


def _collect_schema_params():
    """Map 'domain.opname' to the set of parameter signatures of its schemas."""
    paramdict = defaultdict(set)
    for schema in rtpy.get_all_operator_schema():
        fullname = _normalize_domain(schema.domain) + '.' + schema.name
        paramdict[fullname].add(_format_schema_params(schema))
    return paramdict


def _index_kernels():
    """Group kernel defs as index[provider][domain][op_name] -> list of kernel defs."""
    index = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
    for op in rtpy.get_all_opkernel_def():
        index[op.provider][_normalize_domain(op.domain)][op.op_name].append(op)
    return index


def _format_types_cell(typemap):
    """Format {type_name: set_of_types} as one table cell, names and types sorted."""
    return _CELL_LINE_BREAK.join(
        '**' + tname + '** = ' + format_type_constraints(sorted(tcset))
        for tname, tcset in sorted(typemap.items()))


def _write_provider_table(fout, provider, domainmap, paramdict):
    """Write the markdown table of all kernels registered for one provider."""
    fout.write('\n\n## Operators implemented by ' + provider + '\n\n')
    fout.write('| Op Name | Parameters | OpSet Version | Types Supported |\n')
    fout.write('|---------|------------|---------------|-----------------|\n')
    for domain, namemap in sorted(domainmap.items()):
        fout.write('|**Operator Domain:** *' + domain + '*||||\n')
        for name, ops in sorted(namemap.items()):
            # Merge the type constraints of every kernel sharing a version range.
            version_type_index = defaultdict(lambda: defaultdict(set))
            for op in ops:
                for tname, tclist in op.type_constraints.items():
                    version_type_index[op.version_range][tname].update(tclist)
            rows = sorted(version_type_index.items(), key=lambda x: x[0], reverse=True)
            for row, (version_range, typemap) in enumerate(rows):
                # Only the first (newest) row carries the op name and parameters.
                if row == 0:
                    params = paramdict.get(domain + '.' + name)
                    fout.write('|' + name + '|' + format_param_strings(params) + '|')
                else:
                    fout.write('|||')
                fout.write(format_version_range(version_range) + '|')
                fout.write(_format_types_cell(typemap))
                fout.write('|\n')
    fout.write('| |\n| |\n')


def main(args):  # type: (Type[Args]) -> None
    """Write the operator kernel documentation to ``args.output``.

    Args:
        args: object exposing an ``output`` attribute, the markdown file path.
    """
    # Query the runtime before opening the output so a failure here does not
    # leave behind a truncated documentation file.
    paramdict = _collect_schema_params()
    index = _index_kernels()

    with io.open(args.output, 'w', newline='', encoding="utf-8") as fout:
        fout.write('## Supported Operators Data Types\n')
        fout.write(
            "*This file is automatically generated from the\n"
            " [def files](/onnxruntime/core/providers/cpu/cpu_execution_provider.cc) via "
            "[this script](/tools/python/gen_opkernel_doc.py).\n"
            " Do not modify directly and instead edit operator definitions.*\n")
        fout.write('\n')
        for provider, domainmap in sorted(index.items()):
            _write_provider_table(fout, provider, domainmap, paramdict)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='ONNX Runtime Operator Kernel Documentation Generator')
    parser.add_argument(
        '--output_path',
        help='output markdown file path',
        default=os.path.join(os.path.dirname(os.path.realpath(__file__)), 'OperatorKernels.md'))
    args = parser.parse_args()

    class Args(object):
        output = args.output_path

    main(Args)