onnxruntime/tools/python/dump_subgraphs.py
Scott McKay 5e0928a777
Enable running PEP8 on python scripts using flake8 (#3928)
* Enable running PEP8 checks via flake8 as part of the build if flake8 is installed.
Update scripts in \tools and \onnxruntime\python. Excluding \onnxruntime\python\tools which needs a lot more work to be PEP8 compliant. Also excluding orttraining\tools for the same reason.
Install flake8 as part of the static_analysis build task in the Win-CPU CI so the checks are run in one CI build.
Update coding standards doc.
2020-05-15 07:15:06 +10:00

52 lines
1.8 KiB
Python

import onnx
import argparse
import os
def export_and_recurse(node, attribute, output_dir, level):
    """Save the subgraph held by one attribute of ``node`` as a standalone .onnx file, then recurse into it.

    The file is written to ``output_dir`` and named
    ``L<level>_<op_type>_<attribute name>_<sanitized node name>.onnx``. Subgraphs nested inside the
    exported graph are dumped with ``level + 1``.

    Args:
        node: onnx.NodeProto that owns the subgraph.
        attribute: onnx.AttributeProto of ``node`` whose ``g`` field holds the subgraph.
        output_dir: directory the .onnx file is written to.
        level: nesting depth of the graph containing ``node`` (0 for the main graph); used as a filename prefix.
    """
    # Node names may contain path separators ('/' everywhere, '\\' on Windows) or characters Windows
    # rejects in file names. Replace all of them so the file is created directly inside output_dir
    # rather than in a (non-existent) subdirectory or at an invalid path.
    name = node.name.translate(str.maketrans({ch: '_' for ch in '/\\:*?"<>|'}))

    # Only the graph is copied into the wrapper model.
    # NOTE(review): ir_version/opset_import are left unset, so tools that validate the model may
    # reject the file -- confirm whether consumers need them.
    sub_model = onnx.ModelProto()
    sub_model.graph.MergeFrom(attribute.g)

    # NOTE: nodes with the same (or an empty) name at the same level and op type map to the same file
    # name, so a later subgraph overwrites an earlier one.
    filename = 'L{}_{}_{}_{}.onnx'.format(level, node.op_type, attribute.name, name)
    onnx.save_model(sub_model, os.path.join(output_dir, filename))

    # Walk the exported graph for further nested subgraphs.
    dump_subgraph(sub_model, output_dir, level + 1)
def dump_subgraph(model, output_dir, level=0):
    """Export every subgraph referenced by the nodes of ``model.graph`` into ``output_dir``.

    Each node attribute holding a single graph is passed to export_and_recurse. Examples are the
    ``body`` of Scan/Loop and the ``then_branch``/``else_branch`` of If. export_and_recurse writes
    the subgraph to disk and descends into it.

    Args:
        model: onnx.ModelProto whose top-level graph is scanned.
        output_dir: directory the subgraph files are written to.
        level: nesting depth of ``model.graph``; 0 for the main model.
    """
    for node in model.graph.node:
        for attribute in node.attribute:
            # Select attributes by "holds a graph" instead of a hard-coded list of op types, so
            # subgraphs of ops other than Scan/Loop/If are not silently skipped.
            # Repeated-graph attributes (the ``graphs`` field) are not handled here.
            if attribute.HasField('g'):
                export_and_recurse(node, attribute, output_dir, level)
def parse_args():
    """Build the command-line interface and return the parsed arguments.

    Both options are mandatory:
        -m/--model: path of the ONNX model to read.
        -o/--out:   directory that receives the dumped subgraph files.
    """
    cli = argparse.ArgumentParser(
        os.path.basename(__file__),
        description='Dump all subgraphs from an ONNX model into separate onnx files.',
    )
    option_specs = (
        ('-m', '--model', 'model file'),
        ('-o', '--out', 'output directory'),
    )
    for short_flag, long_flag, help_text in option_specs:
        cli.add_argument(short_flag, long_flag, required=True, help=help_text)
    parsed = cli.parse_args()
    return parsed
def main():
    """Parse the command line, then dump every subgraph of the given model into the output directory."""
    args = parse_args()
    model_path = args.model
    out = os.path.abspath(args.out)
    # The os module has no 'mkdirs'; os.makedirs also creates missing parent directories.
    # exist_ok=True makes a separate existence check unnecessary and avoids a check-then-create race.
    os.makedirs(out, exist_ok=True)
    model = onnx.load_model(model_path)
    dump_subgraph(model, out)
# Run the tool only when this file is executed as a script, not when it is imported as a module.
if __name__ == '__main__':
    main()