mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-25 22:26:24 +00:00
add documentation for custom ops (#708)
* added tools for doc gen, added doc * doc updated * some fixes * hooked up with build.py * hooked up with build.py and fail on nonupdated doc * update
This commit is contained in:
parent
77b981824a
commit
83ae641425
5 changed files with 1699 additions and 0 deletions
|
|
@ -29,6 +29,7 @@ if(NOT NUMPY_INCLUDE_DIR)
|
|||
endif(${NUMPY_NOT_FOUND})
|
||||
endif(NOT NUMPY_INCLUDE_DIR)
|
||||
|
||||
|
||||
# ---[ Python + Numpy
|
||||
set(onnxruntime_pybind_srcs_pattern
|
||||
"${ONNXRUNTIME_ROOT}/python/*.cc"
|
||||
|
|
@ -42,6 +43,11 @@ add_library(onnxruntime_pybind11_state MODULE ${onnxruntime_pybind_srcs})
|
|||
if(HAS_CAST_FUNCTION_TYPE)
|
||||
target_compile_options(onnxruntime_pybind11_state PRIVATE "-Wno-cast-function-type")
|
||||
endif()
|
||||
|
||||
if(onnxruntime_PYBIND_EXPORT_OPSCHEMA)
|
||||
target_compile_definitions(onnxruntime_pybind11_state PRIVATE onnxruntime_PYBIND_EXPORT_OPSCHEMA)
|
||||
endif()
|
||||
|
||||
target_include_directories(onnxruntime_pybind11_state PRIVATE ${ONNXRUNTIME_ROOT} ${PYTHON_INCLUDE_DIR} ${NUMPY_INCLUDE_DIR})
|
||||
target_include_directories(onnxruntime_pybind11_state PRIVATE ${pybind11_INCLUDE_DIRS})
|
||||
onnxruntime_add_include_to_target(onnxruntime_pybind11_state gsl)
|
||||
|
|
|
|||
1182
docs/ContribOperators.md
Normal file
1182
docs/ContribOperators.md
Normal file
File diff suppressed because it is too large
Load diff
|
|
@ -232,8 +232,108 @@ void addGlobalMethods(py::module& m) {
|
|||
m.def(
|
||||
"get_device", []() -> std::string { return BACKEND_DEVICE; },
|
||||
"Return the device used to compute the prediction (CPU, MKL, ...)");
|
||||
|
||||
#ifdef onnxruntime_PYBIND_EXPORT_OPSCHEMA
|
||||
m.def(
|
||||
"get_all_operator_schema",
|
||||
[]() -> const std::vector<ONNX_NAMESPACE::OpSchema> {
|
||||
return ONNX_NAMESPACE::OpSchemaRegistry::get_all_schemas_with_history();
|
||||
},
|
||||
"Return a vector of OpSchema all registed operators"
|
||||
);
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
#ifdef onnxruntime_PYBIND_EXPORT_OPSCHEMA
|
||||
|
||||
void addOpSchemaSubmodule(py::module& m){
|
||||
auto schemadef = m.def_submodule("schemadef");
|
||||
schemadef.doc() = "Schema submodule";
|
||||
|
||||
py::class_<ONNX_NAMESPACE::OpSchema> op_schema(schemadef, "OpSchema");
|
||||
op_schema.def_property_readonly("file", &ONNX_NAMESPACE::OpSchema::file)
|
||||
.def_property_readonly("line", &ONNX_NAMESPACE::OpSchema::line)
|
||||
.def_property_readonly("support_level", &ONNX_NAMESPACE::OpSchema::support_level)
|
||||
.def_property_readonly(
|
||||
"doc", &ONNX_NAMESPACE::OpSchema::doc, py::return_value_policy::reference)
|
||||
.def_property_readonly("since_version", &ONNX_NAMESPACE::OpSchema::since_version)
|
||||
.def_property_readonly("deprecated", &ONNX_NAMESPACE::OpSchema::deprecated)
|
||||
.def_property_readonly("domain", &ONNX_NAMESPACE::OpSchema::domain)
|
||||
.def_property_readonly("name", &ONNX_NAMESPACE::OpSchema::Name)
|
||||
.def_property_readonly("min_input", &ONNX_NAMESPACE::OpSchema::min_input)
|
||||
.def_property_readonly("max_input", &ONNX_NAMESPACE::OpSchema::max_input)
|
||||
.def_property_readonly("min_output", &ONNX_NAMESPACE::OpSchema::min_output)
|
||||
.def_property_readonly("max_output", &ONNX_NAMESPACE::OpSchema::max_output)
|
||||
.def_property_readonly("attributes", &ONNX_NAMESPACE::OpSchema::attributes)
|
||||
.def_property_readonly("inputs", &ONNX_NAMESPACE::OpSchema::inputs)
|
||||
.def_property_readonly("outputs", &ONNX_NAMESPACE::OpSchema::outputs)
|
||||
.def_property_readonly(
|
||||
"has_type_and_shape_inference_function",
|
||||
&ONNX_NAMESPACE::OpSchema::has_type_and_shape_inference_function)
|
||||
.def_property_readonly(
|
||||
"type_constraints", &ONNX_NAMESPACE::OpSchema::typeConstraintParams)
|
||||
.def_static("is_infinite", [](int v) {
|
||||
return v == std::numeric_limits<int>::max();
|
||||
});
|
||||
|
||||
py::class_<ONNX_NAMESPACE::OpSchema::Attribute>(op_schema, "Attribute")
|
||||
.def_readonly("name", &ONNX_NAMESPACE::OpSchema::Attribute::name)
|
||||
.def_readonly("description", &ONNX_NAMESPACE::OpSchema::Attribute::description)
|
||||
.def_readonly("type", &ONNX_NAMESPACE::OpSchema::Attribute::type)
|
||||
.def_property_readonly(
|
||||
"_default_value",
|
||||
[](ONNX_NAMESPACE::OpSchema::Attribute* attr) -> py::bytes {
|
||||
std::string out;
|
||||
attr->default_value.SerializeToString(&out);
|
||||
return out;
|
||||
})
|
||||
.def_readonly("required", &ONNX_NAMESPACE::OpSchema::Attribute::required);
|
||||
|
||||
py::class_<ONNX_NAMESPACE::OpSchema::TypeConstraintParam>(op_schema, "TypeConstraintParam")
|
||||
.def_readonly(
|
||||
"type_param_str", &ONNX_NAMESPACE::OpSchema::TypeConstraintParam::type_param_str)
|
||||
.def_readonly("description", &ONNX_NAMESPACE::OpSchema::TypeConstraintParam::description)
|
||||
.def_readonly(
|
||||
"allowed_type_strs",
|
||||
&ONNX_NAMESPACE::OpSchema::TypeConstraintParam::allowed_type_strs);
|
||||
|
||||
py::enum_<ONNX_NAMESPACE::OpSchema::FormalParameterOption>(op_schema, "FormalParameterOption")
|
||||
.value("Single", ONNX_NAMESPACE::OpSchema::Single)
|
||||
.value("Optional", ONNX_NAMESPACE::OpSchema::Optional)
|
||||
.value("Variadic", ONNX_NAMESPACE::OpSchema::Variadic);
|
||||
|
||||
py::class_<ONNX_NAMESPACE::OpSchema::FormalParameter>(op_schema, "FormalParameter")
|
||||
.def_property_readonly("name", &ONNX_NAMESPACE::OpSchema::FormalParameter::GetName)
|
||||
.def_property_readonly("types", &ONNX_NAMESPACE::OpSchema::FormalParameter::GetTypes)
|
||||
.def_property_readonly("typeStr", &ONNX_NAMESPACE::OpSchema::FormalParameter::GetTypeStr)
|
||||
.def_property_readonly(
|
||||
"description", &ONNX_NAMESPACE::OpSchema::FormalParameter::GetDescription)
|
||||
.def_property_readonly("option", &ONNX_NAMESPACE::OpSchema::FormalParameter::GetOption)
|
||||
.def_property_readonly(
|
||||
"isHomogeneous", &ONNX_NAMESPACE::OpSchema::FormalParameter::GetIsHomogeneous);
|
||||
|
||||
py::enum_<ONNX_NAMESPACE::AttributeProto::AttributeType>(op_schema, "AttrType")
|
||||
.value("FLOAT", ONNX_NAMESPACE::AttributeProto::FLOAT)
|
||||
.value("INT", ONNX_NAMESPACE::AttributeProto::INT)
|
||||
.value("STRING", ONNX_NAMESPACE::AttributeProto::STRING)
|
||||
.value("TENSOR", ONNX_NAMESPACE::AttributeProto::TENSOR)
|
||||
.value("GRAPH", ONNX_NAMESPACE::AttributeProto::GRAPH)
|
||||
.value("FLOATS", ONNX_NAMESPACE::AttributeProto::FLOATS)
|
||||
.value("INTS", ONNX_NAMESPACE::AttributeProto::INTS)
|
||||
.value("STRINGS", ONNX_NAMESPACE::AttributeProto::STRINGS)
|
||||
.value("TENSORS", ONNX_NAMESPACE::AttributeProto::TENSORS)
|
||||
.value("GRAPHS", ONNX_NAMESPACE::AttributeProto::GRAPHS);
|
||||
|
||||
py::enum_<ONNX_NAMESPACE::OpSchema::SupportType>(op_schema, "SupportType")
|
||||
.value("COMMON", ONNX_NAMESPACE::OpSchema::SupportType::COMMON)
|
||||
.value("EXPERIMENTAL", ONNX_NAMESPACE::OpSchema::SupportType::EXPERIMENTAL);
|
||||
|
||||
|
||||
}
|
||||
|
||||
#endif //onnxruntime_PYBIND_EXPORT_OPSCHEMA
|
||||
|
||||
void addObjectMethods(py::module& m) {
|
||||
// allow unit tests to redirect std::cout and std::cerr to sys.stdout and sys.stderr
|
||||
py::add_ostream_redirect(m, "onnxruntime_ostream_redirect");
|
||||
|
|
@ -455,6 +555,11 @@ PYBIND11_MODULE(onnxruntime_pybind11_state, m) {
|
|||
|
||||
addGlobalMethods(m);
|
||||
addObjectMethods(m);
|
||||
|
||||
#ifdef onnxruntime_PYBIND_EXPORT_OPSCHEMA
|
||||
addOpSchemaSubmodule(m);
|
||||
#endif
|
||||
|
||||
}
|
||||
|
||||
} // namespace python
|
||||
|
|
|
|||
379
onnxruntime/python/tools/gen_doc.py
Normal file
379
onnxruntime/python/tools/gen_doc.py
Normal file
|
|
@ -0,0 +1,379 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
# This file is copied and adapted from https://github.com/onnx/onnx repository.
|
||||
# There was no copyright statement on the file at the time of copying.
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
from collections import defaultdict
|
||||
import io
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
|
||||
import numpy as np # type: ignore
|
||||
|
||||
import onnxruntime as rt
|
||||
import onnxruntime.capi.onnxruntime_pybind11_state as rtpy
|
||||
from onnxruntime.capi.onnxruntime_pybind11_state import schemadef
|
||||
from onnxruntime.capi.onnxruntime_pybind11_state.schemadef import OpSchema #, ONNX_DOMAIN, ONNX_ML_DOMAIN
|
||||
from typing import Any, Text, Sequence, Dict, List, Type, Set, Tuple
|
||||
|
||||
|
||||
ONNX_ML = not bool(os.getenv('ONNX_ML') == '0')
|
||||
ONNX_DOMAIN = "onnx"
|
||||
ONNX_ML_DOMAIN = "onnx-ml"
|
||||
|
||||
if ONNX_ML:
|
||||
ext = '-ml.md'
|
||||
else:
|
||||
ext = '.md'
|
||||
|
||||
|
||||
def display_number(v): # type: (int) -> Text
|
||||
if OpSchema.is_infinite(v):
|
||||
return '∞'
|
||||
return Text(v)
|
||||
|
||||
|
||||
def should_render_domain(domain): # type: (Text) -> bool
|
||||
if domain == ONNX_DOMAIN or domain == '' or domain == ONNX_ML_DOMAIN or domain == 'ai.onnx.ml':
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def format_name_with_domain(domain, schema_name): # type: (Text, Text) -> Text
|
||||
if domain:
|
||||
return '{}.{}'.format(domain, schema_name)
|
||||
else:
|
||||
return schema_name
|
||||
|
||||
|
||||
def display_attr_type(v): # type: (OpSchema.AttrType) -> Text
|
||||
assert isinstance(v, OpSchema.AttrType)
|
||||
s = Text(v)
|
||||
s = s[s.rfind('.') + 1:].lower()
|
||||
if s[-1] == 's':
|
||||
s = 'list of ' + s
|
||||
return s
|
||||
|
||||
|
||||
def display_domain(domain): # type: (Text) -> Text
|
||||
if domain:
|
||||
return "the '{}' operator set".format(domain)
|
||||
else:
|
||||
return "the default ONNX operator set"
|
||||
|
||||
|
||||
def display_domain_short(domain): # type: (Text) -> Text
|
||||
if domain:
|
||||
return domain
|
||||
else:
|
||||
return 'ai.onnx (default)'
|
||||
|
||||
|
||||
def display_version_link(name, version): # type: (Text, int) -> Text
|
||||
changelog_md = 'Changelog' + ext
|
||||
name_with_ver = '{}-{}'.format(name, version)
|
||||
return '<a href="{}#{}">{}</a>'.format(changelog_md, name_with_ver, name_with_ver)
|
||||
|
||||
|
||||
def display_function_version_link(name, version): # type: (Text, int) -> Text
|
||||
changelog_md = 'FunctionsChangelog' + ext
|
||||
name_with_ver = '{}-{}'.format(name, version)
|
||||
return '<a href="{}#{}">{}</a>'.format(changelog_md, name_with_ver, name_with_ver)
|
||||
|
||||
|
||||
def get_attribute_value(attr): # type: (AttributeProto) -> Any
|
||||
if attr.HasField('f'):
|
||||
return attr.f
|
||||
elif attr.HasField('i'):
|
||||
return attr.i
|
||||
elif attr.HasField('s'):
|
||||
return attr.s
|
||||
elif attr.HasField('t'):
|
||||
return attr.t
|
||||
elif attr.HasField('g'):
|
||||
return attr.g
|
||||
elif len(attr.floats):
|
||||
return list(attr.floats)
|
||||
elif len(attr.ints):
|
||||
return list(attr.ints)
|
||||
elif len(attr.strings):
|
||||
return list(attr.strings)
|
||||
elif len(attr.tensors):
|
||||
return list(attr.tensors)
|
||||
elif len(attr.graphs):
|
||||
return list(attr.graphs)
|
||||
else:
|
||||
raise ValueError("Unsupported ONNX attribute: {}".format(attr))
|
||||
|
||||
def display_schema(schema, versions): # type: (OpSchema, Sequence[OpSchema]) -> Text
|
||||
s = ''
|
||||
|
||||
# doc
|
||||
if schema.doc:
|
||||
s += '\n'
|
||||
s += '\n'.join(' ' + line
|
||||
for line in schema.doc.lstrip().splitlines())
|
||||
s += '\n'
|
||||
|
||||
# since version
|
||||
s += '\n#### Version\n'
|
||||
if schema.support_level == OpSchema.SupportType.EXPERIMENTAL:
|
||||
s += '\nNo versioning maintained for experimental ops.'
|
||||
else:
|
||||
s += '\nThis version of the operator has been ' + ('deprecated' if schema.deprecated else 'available') + ' since version {}'.format(schema.since_version)
|
||||
s += ' of {}.\n'.format(display_domain(schema.domain))
|
||||
if len(versions) > 1:
|
||||
# TODO: link to the Changelog.md
|
||||
s += '\nOther versions of this operator: {}\n'.format(
|
||||
', '.join(display_version_link(format_name_with_domain(v.domain, v.name),
|
||||
v.since_version) for v in versions[:-1]))
|
||||
|
||||
# If this schema is deprecated, don't display any of the following sections
|
||||
if schema.deprecated:
|
||||
return s
|
||||
|
||||
# attributes
|
||||
if schema.attributes:
|
||||
s += '\n#### Attributes\n\n'
|
||||
s += '<dl>\n'
|
||||
for _, attr in sorted(schema.attributes.items()):
|
||||
# option holds either required or default value
|
||||
opt = ''
|
||||
if attr.required:
|
||||
opt = 'required'
|
||||
elif hasattr(attr, 'default_value') and attr.default_value.name:
|
||||
default_value = get_attribute_value(attr.default_value)
|
||||
|
||||
def format_value(value): # type: (Any) -> Text
|
||||
if isinstance(value, float):
|
||||
value = np.round(value, 5)
|
||||
if isinstance(value, (bytes, bytearray)) and sys.version_info[0] == 3:
|
||||
value = value.decode('utf-8')
|
||||
return str(value)
|
||||
|
||||
if isinstance(default_value, list):
|
||||
default_value = [format_value(val) for val in default_value]
|
||||
else:
|
||||
default_value = format_value(default_value)
|
||||
opt = 'default is {}'.format(default_value)
|
||||
|
||||
s += '<dt><tt>{}</tt> : {}{}</dt>\n'.format(
|
||||
attr.name,
|
||||
display_attr_type(attr.type),
|
||||
' ({})'.format(opt) if opt else '')
|
||||
s += '<dd>{}</dd>\n'.format(attr.description)
|
||||
s += '</dl>\n'
|
||||
|
||||
# inputs
|
||||
s += '\n#### Inputs'
|
||||
if schema.min_input != schema.max_input:
|
||||
s += ' ({} - {})'.format(display_number(schema.min_input),
|
||||
display_number(schema.max_input))
|
||||
s += '\n\n'
|
||||
if schema.inputs:
|
||||
s += '<dl>\n'
|
||||
for input in schema.inputs:
|
||||
option_str = ""
|
||||
if OpSchema.FormalParameterOption.Optional == input.option:
|
||||
option_str = " (optional)"
|
||||
elif OpSchema.FormalParameterOption.Variadic == input.option:
|
||||
if input.isHomogeneous:
|
||||
option_str = " (variadic)"
|
||||
else:
|
||||
option_str = " (variadic, heterogeneous)"
|
||||
s += '<dt><tt>{}</tt>{} : {}</dt>\n'.format(input.name, option_str, input.typeStr)
|
||||
s += '<dd>{}</dd>\n'.format(input.description)
|
||||
s += '</dl>\n'
|
||||
|
||||
# outputs
|
||||
s += '\n#### Outputs'
|
||||
if schema.min_output != schema.max_output:
|
||||
s += ' ({} - {})'.format(display_number(schema.min_output),
|
||||
display_number(schema.max_output))
|
||||
s += '\n\n'
|
||||
|
||||
if schema.outputs:
|
||||
s += '<dl>\n'
|
||||
for output in schema.outputs:
|
||||
option_str = ""
|
||||
if OpSchema.FormalParameterOption.Optional == output.option:
|
||||
option_str = " (optional)"
|
||||
elif OpSchema.FormalParameterOption.Variadic == output.option:
|
||||
if output.isHomogeneous:
|
||||
option_str = " (variadic)"
|
||||
else:
|
||||
option_str = " (variadic, heterogeneous)"
|
||||
s += '<dt><tt>{}</tt>{} : {}</dt>\n'.format(output.name, option_str, output.typeStr)
|
||||
s += '<dd>{}</dd>\n'.format(output.description)
|
||||
s += '</dl>\n'
|
||||
|
||||
# type constraints
|
||||
s += '\n#### Type Constraints'
|
||||
s += '\n\n'
|
||||
if schema.type_constraints:
|
||||
s += '<dl>\n'
|
||||
for type_constraint in schema.type_constraints:
|
||||
allowedTypes = type_constraint.allowed_type_strs
|
||||
allowedTypeStr = ''
|
||||
if (len(allowedTypes) > 0):
|
||||
allowedTypeStr = allowedTypes[0]
|
||||
for allowedType in allowedTypes[1:]:
|
||||
allowedTypeStr += ', ' + allowedType
|
||||
s += '<dt><tt>{}</tt> : {}</dt>\n'.format(
|
||||
type_constraint.type_param_str, allowedTypeStr)
|
||||
s += '<dd>{}</dd>\n'.format(type_constraint.description)
|
||||
s += '</dl>\n'
|
||||
|
||||
return s
|
||||
|
||||
|
||||
def display_function(function, versions, domain=ONNX_DOMAIN): # type: (FunctionProto, List[int], Text) -> Text
|
||||
s = ''
|
||||
|
||||
if domain:
|
||||
domain_prefix = '{}.'.format(ONNX_ML_DOMAIN)
|
||||
else:
|
||||
domain_prefix = ''
|
||||
|
||||
# doc
|
||||
if function.doc_string:
|
||||
s += '\n'
|
||||
s += '\n'.join(' ' + line
|
||||
for line in function.doc_string.lstrip().splitlines())
|
||||
s += '\n'
|
||||
|
||||
# since version
|
||||
s += '\n#### Version\n'
|
||||
s += '\nThis version of the function has been available since version {}'.format(function.since_version)
|
||||
s += ' of {}.\n'.format(display_domain(domain_prefix))
|
||||
if len(versions) > 1:
|
||||
s += '\nOther versions of this function: {}\n'.format(
|
||||
', '.join(display_function_version_link(domain_prefix + function.name, v) for v in versions if v != function.since_version))
|
||||
|
||||
# inputs
|
||||
s += '\n#### Inputs'
|
||||
s += '\n\n'
|
||||
if function.input:
|
||||
s += '<dl>\n'
|
||||
for input in function.input:
|
||||
s += '<dt>{}; </dt>\n'.format(input)
|
||||
s += '<br/></dl>\n'
|
||||
|
||||
# outputs
|
||||
s += '\n#### Outputs'
|
||||
s += '\n\n'
|
||||
if function.output:
|
||||
s += '<dl>\n'
|
||||
for output in function.output:
|
||||
s += '<dt>{}; </dt>\n'.format(output)
|
||||
s += '<br/></dl>\n'
|
||||
|
||||
# attributes
|
||||
if function.attribute:
|
||||
s += '\n#### Attributes\n\n'
|
||||
s += '<dl>\n'
|
||||
for attr in function.attribute:
|
||||
s += '<dt>{};<br/></dt>\n'.format(attr)
|
||||
s += '</dl>\n'
|
||||
|
||||
return s
|
||||
|
||||
|
||||
def support_level_str(level): # type: (OpSchema.SupportType) -> Text
|
||||
return \
|
||||
"<sub>experimental</sub> " if level == OpSchema.SupportType.EXPERIMENTAL else ""
|
||||
|
||||
|
||||
# def function_status_str(status=OperatorStatus.Value("EXPERIMENTAL")): # type: ignore
|
||||
# return \
|
||||
# "<sub>experimental</sub> " if status == OperatorStatus.Value('EXPERIMENTAL') else "" # type: ignore
|
||||
|
||||
|
||||
def main(args): # type: (Type[Args]) -> None
|
||||
|
||||
with io.open(args.output, 'w', newline='', encoding="utf-8") as fout:
|
||||
fout.write('## Contrib Operator Schemas\n')
|
||||
fout.write(
|
||||
"*This file is automatically generated from the\n"
|
||||
" [def files](/onnxruntime/core/graph/contrib_ops/contrib_defs.cc) via [this script](/onnxruntime/python/tools/gen_doc.py).\n"
|
||||
" Do not modify directly and instead edit operator definitions.*\n")
|
||||
|
||||
# domain -> support level -> name -> [schema]
|
||||
index = defaultdict(lambda: defaultdict(lambda: defaultdict(list))) # type: Dict[Text, Dict[int, Dict[Text, List[OpSchema]]]]
|
||||
for schema in rtpy.get_all_operator_schema():
|
||||
index[schema.domain][int(schema.support_level)][schema.name].append(schema)
|
||||
|
||||
fout.write('\n')
|
||||
|
||||
# Preprocess the Operator Schemas
|
||||
# [(domain, [(support_level, [(schema name, current schema, all versions schemas)])])]
|
||||
operator_schemas = list() # type: List[Tuple[Text, List[Tuple[int, List[Tuple[Text, OpSchema, List[OpSchema]]]]]]]
|
||||
exsting_ops = set() # type: Set[Text]
|
||||
for domain, _supportmap in sorted(index.items()):
|
||||
if not should_render_domain(domain):
|
||||
continue
|
||||
|
||||
processed_supportmap = list()
|
||||
for _support, _namemap in sorted(_supportmap.items()):
|
||||
processed_namemap = list()
|
||||
for n, unsorted_versions in sorted(_namemap.items()):
|
||||
versions = sorted(unsorted_versions, key=lambda s: s.since_version)
|
||||
schema = versions[-1]
|
||||
if schema.name in exsting_ops:
|
||||
continue
|
||||
exsting_ops.add(schema.name)
|
||||
processed_namemap.append((n, schema, versions))
|
||||
processed_supportmap.append((_support, processed_namemap))
|
||||
operator_schemas.append((domain, processed_supportmap))
|
||||
|
||||
# Table of contents
|
||||
for domain, supportmap in operator_schemas:
|
||||
s = '* {}\n'.format(display_domain_short(domain))
|
||||
fout.write(s)
|
||||
|
||||
for _, namemap in supportmap:
|
||||
for n, schema, versions in namemap:
|
||||
s = ' * {}<a href="#{}">{}</a>\n'.format(
|
||||
support_level_str(schema.support_level),
|
||||
format_name_with_domain(domain, n),
|
||||
format_name_with_domain(domain, n))
|
||||
fout.write(s)
|
||||
|
||||
fout.write('\n')
|
||||
|
||||
for domain, supportmap in operator_schemas:
|
||||
s = '## {}\n'.format(display_domain_short(domain))
|
||||
fout.write(s)
|
||||
|
||||
for _, namemap in supportmap:
|
||||
for op_type, schema, versions in namemap:
|
||||
# op_type
|
||||
s = ('### {}<a name="{}"></a><a name="{}">**{}**' + (' (deprecated)' if schema.deprecated else '') + '</a>\n').format(
|
||||
support_level_str(schema.support_level),
|
||||
format_name_with_domain(domain, op_type),
|
||||
format_name_with_domain(domain, op_type.lower()),
|
||||
format_name_with_domain(domain, op_type))
|
||||
|
||||
s += display_schema(schema, versions)
|
||||
|
||||
s += '\n\n'
|
||||
|
||||
fout.write(s)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='ONNX Runtime Operator Documentation Generator')
|
||||
parser.add_argument('--output_path', help='output markdown file path',
|
||||
default=os.path.join(os.path.dirname(os.path.realpath(__file__)), 'ContribOperators.md')
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
class Args(object):
|
||||
output = args.output_path
|
||||
main(Args)
|
||||
|
|
@ -69,6 +69,10 @@ Use the individual flags to only run the specified stages.
|
|||
help='''Downloads test data without running the tests''')
|
||||
parser.add_argument("--test_data_url", help="Test data URL.")
|
||||
parser.add_argument("--test_data_checksum", help="Test data checksum (MD5 digest).")
|
||||
|
||||
# generate documentaiton
|
||||
parser.add_argument("--gen_doc", action='store_true', help="Generate documentation on contrib ops")
|
||||
|
||||
# CUDA related
|
||||
parser.add_argument("--use_cuda", action='store_true', help="Enable CUDA.")
|
||||
parser.add_argument("--cuda_version", help="The version of CUDA toolkit to use. Auto-detect if not specified. e.g. 9.0")
|
||||
|
|
@ -341,6 +345,9 @@ def generate_build_tree(cmake_path, source_dir, build_dir, cuda_home, cudnn_home
|
|||
if path_to_protoc_exe:
|
||||
cmake_args += ["-DONNX_CUSTOM_PROTOC_EXECUTABLE=%s" % path_to_protoc_exe]
|
||||
|
||||
if args.gen_doc:
|
||||
cmake_args += ["-Donnxruntime_PYBIND_EXPORT_OPSCHEMA=ON"]
|
||||
|
||||
cmake_args += ["-D{}".format(define) for define in cmake_extra_defines]
|
||||
|
||||
if is_windows():
|
||||
|
|
@ -588,6 +595,23 @@ def build_protoc_for_windows_host(cmake_path, source_dir, build_dir):
|
|||
if not os.path.exists(os.path.join(build_dir, 'host_protoc', 'Release', 'protoc.exe')):
|
||||
raise BuildError("Couldn't build protoc.exe for host. Failing build.")
|
||||
|
||||
def generate_documentation(source_dir, build_dir, configs):
|
||||
operator_doc_path = os.path.join(source_dir, 'docs', 'ContribOperators.md')
|
||||
for config in configs:
|
||||
#copy the gen_doc.py
|
||||
shutil.copy(os.path.join(source_dir,'onnxruntime','python','tools','gen_doc.py'),
|
||||
os.path.join(build_dir,config, config))
|
||||
run_subprocess([
|
||||
sys.executable,
|
||||
'gen_doc.py',
|
||||
'--output_path', operator_doc_path
|
||||
],
|
||||
cwd = os.path.join(build_dir,config, config))
|
||||
|
||||
docdiff = run_subprocess(['git', 'diff', operator_doc_path], capture=True).stdout
|
||||
if len(docdiff) > 0:
|
||||
raise BuildError("The updated operator document file "+operator_doc_path+" must be checked in")
|
||||
|
||||
def main():
|
||||
args = parse_arguments()
|
||||
|
||||
|
|
@ -715,6 +739,9 @@ def main():
|
|||
if args.build:
|
||||
if args.build_wheel:
|
||||
build_python_wheel(source_dir, build_dir, configs, args.use_cuda, args.use_tensorrt)
|
||||
|
||||
if args.gen_doc:
|
||||
generate_documentation(source_dir, build_dir, configs)
|
||||
|
||||
log.info("Build complete")
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue