2019-03-08 22:51:30 +00:00
|
|
|
import argparse
|
|
|
|
|
import os
|
|
|
|
|
|
2022-04-26 16:35:16 +00:00
|
|
|
import onnx
|
|
|
|
|
|
2019-03-08 22:51:30 +00:00
|
|
|
|
|
|
|
|
def export_and_recurse(node, attribute, output_dir, level, parent_model=None, node_index=None):
    """Save one graph-valued attribute of *node* as a standalone ONNX model, then recurse into it.

    The file is written into ``output_dir`` as
    ``L<level>_<op_type>_<attribute name>_<node name>.onnx``. Any subgraphs nested
    inside the saved graph are dumped in turn at ``level + 1``.

    Args:
        node: ``onnx.NodeProto`` that owns the subgraph.
        attribute: ``onnx.AttributeProto`` of *node* whose ``g`` field holds the subgraph.
        output_dir: Directory the ``.onnx`` files are written into (must already exist).
        level: Nesting depth of *node*; 0 for nodes of the top-level graph.
        parent_model: Optional ``onnx.ModelProto`` that contains *node*. When given, its
            ``ir_version`` and ``opset_import`` are copied into the saved model. Otherwise
            the installed onnx package's current IR version and default-domain opset are used.
        node_index: Optional position of *node* within its graph. It is used to build a
            distinct file name when the node has no name.
    """
    # Node names may contain path separators (e.g. scope-style "a/b/c"); replace them
    # so the file name cannot point into a (non-existent) subdirectory of output_dir.
    name = node.name.replace("/", "_").replace("\\", "_")
    if not name and node_index is not None:
        # Without this, every unnamed sibling node would map to the same file and
        # silently overwrite the previous one.
        name = "unnamed" + str(node_index)

    sub_model = onnx.ModelProto()
    # A freshly constructed ModelProto has ir_version 0 and no opset_import, which
    # onnx.checker rejects; give the subgraph the same versioning as its parent.
    if parent_model is not None:
        sub_model.ir_version = parent_model.ir_version
        sub_model.opset_import.extend(parent_model.opset_import)
    else:
        sub_model.ir_version = onnx.IR_VERSION
        sub_model.opset_import.append(onnx.helper.make_opsetid("", onnx.defs.onnx_opset_version()))
    sub_model.graph.MergeFrom(attribute.g)

    filename = "L" + str(level) + "_" + node.op_type + "_" + attribute.name + "_" + name + ".onnx"
    onnx.save_model(sub_model, os.path.join(output_dir, filename))

    # Descend into the subgraph just saved; it carries the copied versioning,
    # so its own subgraphs inherit it as well.
    dump_subgraph(sub_model, output_dir, level + 1)


def dump_subgraph(model, output_dir, level=0):
    """Write every subgraph nested in ``model.graph`` to its own ``.onnx`` file.

    Every node attribute that holds a single graph is exported via
    ``export_and_recurse``, which then descends into that subgraph. This covers
    ``Loop``/``Scan`` ``body`` and ``If`` ``then_branch``/``else_branch``, plus any
    other op that carries a graph attribute. Subgraphs at any depth are dumped.

    Args:
        model: ``onnx.ModelProto`` whose main graph is scanned.
        output_dir: Existing directory that receives the files.
        level: Nesting depth of ``model.graph``; 0 for the top-level model.
    """
    for index, node in enumerate(model.graph.node):
        for attribute in node.attribute:
            # Only single-graph attributes (field ``g``) are handled; list-of-graphs
            # attributes (field ``graphs``) are not exported.
            if attribute.HasField("g"):
                export_and_recurse(
                    node,
                    attribute,
                    output_dir,
                    level,
                    parent_model=model,
                    node_index=index,
                )
|
|
|
|
|
|
2020-05-14 21:15:06 +00:00
|
|
|
|
2019-03-08 22:51:30 +00:00
|
|
|
def parse_args():
    """Read this script's command-line options.

    Returns:
        argparse.Namespace with ``model`` (the model file to load) and ``out``
        (the directory that receives the dumped subgraph files).
    """
    cli = argparse.ArgumentParser(
        prog=os.path.basename(__file__),
        description="Dump all subgraphs from an ONNX model into separate onnx files.",
    )
    # Both options are mandatory; argparse exits with a usage error when one is missing.
    option_specs = (
        (("-m", "--model"), "model file"),
        (("-o", "--out"), "output directory"),
    )
    for flags, help_text in option_specs:
        cli.add_argument(*flags, required=True, help=help_text)
    namespace = cli.parse_args()
    return namespace
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main():
    """Command-line entry point.

    Loads the model named by ``--model`` and writes each of its subgraphs into the
    directory given by ``--out``. The directory is created if needed.
    """
    args = parse_args()

    model_path = args.model
    out = os.path.abspath(args.out)

    # Load first so that an unreadable model fails before any directory is created.
    model = onnx.load_model(model_path)

    # exist_ok replaces the separate os.path.exists() check, which could race with
    # another process creating the same directory in between.
    os.makedirs(out, exist_ok=True)

    dump_subgraph(model, out)
|
|
|
|
|
|
2020-05-14 21:15:06 +00:00
|
|
|
|
2022-04-26 16:35:16 +00:00
|
|
|
# Run the command-line tool only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|