"""Scan an ONNX file and print nodes worth a look during performance investigation.

Usage: pass the path of the ONNX file as the single positional argument.
"""

import argparse

import onnx

parser = argparse.ArgumentParser(description='ONNX file analyzer for performance investigation.')
parser.add_argument('onnx_file', type=str, help='ONNX file to analyze')
# NOTE: the command line is parsed at import time; `args` is read by main().
args = parser.parse_args()


def _fusion_hint(node, output_to_node):
    """Return a fusion hint for a single-input *node* fed by an Add, else None.

    Args:
        node: Graph node to inspect; the caller passes Dropout or Softmax nodes.
        output_to_node: Map from output arg name to the node producing it.

    Returns:
        The message to print, or None when the node has a number of inputs
        other than one, when its input has no producer node in the map, or
        when the producer is not an Add.
    """
    if len(node.input) != 1:
        return None
    # The input may not be produced by any node in the graph (it is then
    # absent from the map); a plain subscript would raise KeyError here.
    prev = output_to_node.get(node.input[0])
    if prev is None or prev.op_type != "Add":
        return None
    return (
        f"Examine whether {node.name} should be fused with the leading "
        f"{prev.name} op into Bias{node.op_type} node."
    )


def _print_section(title, lines):
    """Print *title*, every entry of *lines*, and a dashed rule; no-op if empty."""
    if not lines:
        return
    print(title)
    for line in lines:
        print(line)
    print(10 * '-')


def process_file(onnx_file):
    """Load *onnx_file* and print a report of notable nodes in its graph.

    The report lists ATen ops, PythonOps, ops whose type contains "Memcpy",
    and Cast ops, followed by hints about Dropout/Softmax nodes that directly
    follow an Add and might be fusable.

    Args:
        onnx_file: Path of the ONNX model, passed to ``onnx.load``.

    Returns:
        None. All results are written to standard output.
    """
    model = onnx.load(onnx_file)
    nodes = model.graph.node

    # Map from output arg to the producer of the output.
    output_to_node = {out: node for node in nodes for out in node.output}

    aten_ops = []
    python_ops = []
    memcpy_ops = []
    cast_ops = []
    msgs = []

    for node in nodes:
        op_type = node.op_type

        if "Memcpy" in op_type:
            memcpy_ops.append(f"{op_type} {node.name}")

        if op_type == "Cast":
            cast_ops.append(f"{node.name}")
        elif op_type == "ATen":
            # The wrapped operator is stored in the "operator" attribute.
            aten_ops.extend(
                f"{node.name}: {attr.s.decode('utf-8')}"
                for attr in node.attribute
                if attr.name == "operator"
            )
        elif op_type == "PythonOp":
            # The wrapped Python callable is stored in the "name" attribute.
            python_ops.extend(
                f"{node.name}: {attr.s.decode('utf-8')}"
                for attr in node.attribute
                if attr.name == "name"
            )
        elif op_type in ("Dropout", "Softmax"):
            # Look for stand-alone Dropout / Softmax nodes in the
            # *_execution_model_.onnx graph and flag those that follow an Add:
            # examine whether they should be fused into a BiasDropout /
            # BiasSoftmax node.
            hint = _fusion_hint(node, output_to_node)
            if hint is not None:
                msgs.append(hint)

    _print_section("ATen op found:", aten_ops)
    _print_section("PythonOp found:", python_ops)
    _print_section("Memcpu ops found:", memcpy_ops)
    _print_section("Cast ops found:", cast_ops)

    for line in msgs:
        print(line)


def main():
    """Analyze the ONNX file named on the command line."""
    process_file(args.onnx_file)


if __name__ == "__main__":
    main()