2019-08-15 01:12:24 +00:00
|
|
|
#!/usr/bin/env python
|
|
|
|
|
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
|
|
|
# Licensed under the MIT License.
|
|
|
|
|
|
2021-06-02 07:47:40 +00:00
|
|
|
import argparse
|
2019-08-15 01:12:24 +00:00
|
|
|
import io
|
|
|
|
|
import os
|
2021-06-02 07:47:40 +00:00
|
|
|
import pathlib
|
|
|
|
|
from collections import defaultdict
|
2019-08-15 01:12:24 +00:00
|
|
|
|
2020-05-14 21:15:06 +00:00
|
|
|
import onnxruntime.capi.onnxruntime_pybind11_state as rtpy
|
2019-08-15 01:12:24 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
def format_version_range(v):
    """Render an opset version range as markdown cell text.

    Args:
        v: indexable pair (start, end).

    Returns:
        'start+' when end is 2147483647 (2**31 - 1) or larger, i.e. "start and later";
        'start' when start == end;
        '[start, end]' otherwise.
    """
    start, end = v[0], v[1]
    if end >= 2147483647:
        return str(start) + '+'
    if start == end:
        return str(start)
    return '[{}, {}]'.format(start, end)
|
2020-05-14 21:15:06 +00:00
|
|
|
|
2019-08-15 01:12:24 +00:00
|
|
|
|
|
|
|
|
def format_type_constraints(tc):
    """Join type-constraint strings into one comma-separated string.

    Args:
        tc: iterable of type strings, already in the order they should appear.

    Returns:
        The items joined by ', ', or an empty string when tc is empty.
    """
    # str.join replaces the hand-rolled first-item flag loop (and its unused counter).
    return ', '.join(tc)
|
|
|
|
|
|
2020-05-14 21:15:06 +00:00
|
|
|
|
2019-08-15 01:12:24 +00:00
|
|
|
def format_param_strings(params):
    """Merge alternative parameter signatures into a single markdown table cell.

    Args:
        params: collection of signature strings, or None.

    Returns:
        The signatures in sorted order, separated by an HTML 'or' line.
        Returns an empty string when params is None or empty.
    """
    if not params:
        return ''
    separator = '<br><br>or<br><br>'
    return separator.join(sorted(params))
|
2020-05-14 21:15:06 +00:00
|
|
|
|
|
|
|
|
|
2021-06-02 07:47:40 +00:00
|
|
|
def expand_providers(provider_filter: [str]):
    """Normalize user-supplied provider names into a lowercase lookup set.

    Each name is lowercased, and 'executionprovider' is appended unless the name
    already ends with it (so 'cpu' and 'CPUExecutionProvider' both become
    'cpuexecutionprovider').

    Args:
        provider_filter: provider names, or None/empty for "no filter".

    Returns:
        A set of normalized names; empty when no filter was given.
    """
    suffix = 'executionprovider'
    lowered = (name.lower() for name in (provider_filter or []))
    return {name if name.endswith(suffix) else name + suffix for name in lowered}
|
2020-05-14 21:15:06 +00:00
|
|
|
|
2021-06-02 07:47:40 +00:00
|
|
|
|
|
|
|
|
def _normalize_domain(domain):
    """Return the display name of an operator domain; the empty domain is shown as 'ai.onnx'."""
    if domain == '':
        return 'ai.onnx'
    return domain


def _format_schema_params(schema):
    """Render a schema's inputs then outputs as '*in*/*out* name:**type**' entries joined by '<br> '."""
    entries = ['*in* {}:**{}**'.format(inp.name, inp.typeStr) for inp in (schema.inputs or [])]
    entries += ['*out* {}:**{}**'.format(outp.name, outp.typeStr) for outp in (schema.outputs or [])]
    return '<br> '.join(entries)


def _build_param_dict():
    """Map '<domain>.<op name>' to the set of distinct parameter strings over all registered schemas."""
    paramdict = {}
    for schema in rtpy.get_all_operator_schema():
        fullname = _normalize_domain(schema.domain) + '.' + schema.name
        paramdict.setdefault(fullname, set()).add(_format_schema_params(schema))
    return paramdict


def _build_kernel_index():
    """Group registered kernel defs as index[provider][domain][op_name] -> list of kernel defs."""
    index = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
    for op in rtpy.get_all_opkernel_def():
        index[op.provider][_normalize_domain(op.domain)][op.op_name].append(op)
    return index


def _write_op_rows(fout, name, ops, params):
    """Write the table rows for one operator: one row per distinct version range, newest range first.

    Only the first row carries the op name and its parameter strings; later rows leave those cells empty.
    """
    version_type_index = defaultdict(lambda: defaultdict(set))
    for op in ops:
        # Create the version-range entry up front so a kernel without type constraints
        # still gets a row instead of being silently omitted from the document.
        typemap = version_type_index[op.version_range]
        for tname, tclist in op.type_constraints.items():
            typemap[tname].update(tclist)

    first_row = True
    for version_range, typemap in sorted(version_type_index.items(), key=lambda x: x[0], reverse=True):
        if first_row:
            fout.write('|' + name + '|' + format_param_strings(params) + '|')
            first_row = False
        else:
            fout.write('|||')
        fout.write(format_version_range(version_range) + '|')
        fout.write('<br/> '.join(
            '**' + tname + '** = ' + format_type_constraints(sorted(tcset))
            for tname, tcset in sorted(typemap.items())))
        fout.write('|\n')


def main(output_path: pathlib.Path, provider_filter: [str]):
    """Generate the markdown table of registered operator kernels.

    Args:
        output_path: markdown file to (over)write.
        provider_filter: provider names to include (normalized by expand_providers);
            when None/empty, every provider is included.
    """
    providers = expand_providers(provider_filter)

    def selected(provider):
        return not providers or provider.lower() in providers

    # Query the runtime before opening the output so a failure here does not
    # truncate an existing document.
    paramdict = _build_param_dict()
    index = _build_kernel_index()

    with io.open(output_path, 'w', newline='', encoding="utf-8") as fout:
        fout.write('## Supported Operators and Data Types\n')
        fout.write(
            "*This file is automatically generated from the registered kernels by "
            "[this script](https://github.com/microsoft/onnxruntime/blob/master/tools/python/gen_opkernel_doc.py).\n"
            "Do not modify directly.*\n\n")

        # TOC
        fout.write('## Execution Providers\n\n')
        for provider in sorted(index):
            if selected(provider):
                fout.write('- [{}](#{})\n'.format(provider, provider.lower()))
        fout.write('\n---------------')

        for provider, domainmap in sorted(index.items()):
            if not selected(provider):
                continue

            fout.write('\n\n<a name="{}"/>\n\n'.format(provider.lower()))
            fout.write('## Operators implemented by {}\n\n'.format(provider))
            fout.write('| Op Name | Parameters | OpSet Version | Types Supported |\n')
            fout.write('|---------|------------|---------------|-----------------|\n')
            for domain, namemap in sorted(domainmap.items()):
                fout.write('|**Operator Domain:** *' + domain + '*||||\n')
                for name, ops in sorted(namemap.items()):
                    _write_op_rows(fout, name, ops, paramdict.get(domain + '.' + name, None))
                fout.write('| |\n| |\n')
|
2020-05-14 21:15:06 +00:00
|
|
|
|
2019-08-15 01:12:24 +00:00
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='ONNX Runtime Operator Kernel Documentation Generator')
    parser.add_argument('--providers', nargs='+',
                        help="Filter to specified execution providers. Case-insensitive. "
                             "Matches provider names from <ORT>/include/onnxruntime/core/graph/constants.h. "
                             "'ExecutionProvider' is automatically appended as needed. "
                             "e.g. `--providers cpu cuda` will match CPUExecutionProvider and CUDAExecutionProvider.")
    # Optional: when omitted, OperatorKernels.md is written next to this script.
    # (Combining required=True with a default made the default unreachable.)
    parser.add_argument('--output_path', help='output markdown file path', type=pathlib.Path,
                        default=os.path.join(os.path.dirname(os.path.realpath(__file__)), 'OperatorKernels.md'))
    args = parser.parse_args()

    main(args.output_path, args.providers)
|