mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-15 01:23:42 +00:00
* Support per-channel quantization of weight tensor * rename util functions * fix bugs in calibrate * add support of reduce_range * refine opset check
127 lines
3.9 KiB
Python
127 lines
3.9 KiB
Python
import onnx
|
|
from .quant_utils import find_by_name
|
|
|
|
|
|
class ONNXModel:
    '''
    Convenience wrapper around a model object that exposes graph-editing helpers.

    The wrapped ``model`` is accessed through ``model.graph.node``,
    ``model.graph.initializer``, ``model.ir_version`` and ``model.opset_import``.
    Presumably it is an ``onnx.ModelProto`` (TODO confirm against callers).
    All mutating helpers edit the wrapped model in place.
    '''

    def __init__(self, model):
        '''
        :param model: the model to wrap; it is stored by reference, not copied.
        '''
        self.model = model
        # Not read by any method of this class; presumably used by callers to
        # generate unique node names (verify before relying on it).
        self.node_name_counter = {}

    def nodes(self):
        '''Return the graph's node container (the live object, not a copy).'''
        return self.model.graph.node

    def initializer(self):
        '''Return the graph's initializer container (the live object, not a copy).'''
        return self.model.graph.initializer

    def graph(self):
        '''Return the wrapped model's graph.'''
        return self.model.graph

    def ir_version(self):
        '''Return the wrapped model's ``ir_version`` field.'''
        return self.model.ir_version

    def opset_import(self):
        '''Return the wrapped model's ``opset_import`` field.'''
        return self.model.opset_import

    def remove_node(self, node):
        '''Remove ``node`` from the graph; do nothing if it is not present.'''
        if node in self.model.graph.node:
            self.model.graph.node.remove(node)

    def remove_nodes(self, nodes_to_remove):
        '''
        Remove every node in ``nodes_to_remove`` from the graph.

        The input is snapshotted first, so passing the graph's own node
        container (e.g. ``self.nodes()``) removes all nodes instead of
        skipping every other one while the container shrinks.
        '''
        for node in list(nodes_to_remove):
            self.remove_node(node)

    def add_node(self, node):
        '''Append a single node to the end of the graph's node list.'''
        self.model.graph.node.extend([node])

    def add_nodes(self, nodes_to_add):
        '''Append all of ``nodes_to_add`` to the end of the graph's node list.'''
        self.model.graph.node.extend(nodes_to_add)

    def add_initializer(self, tensor):
        '''
        Add ``tensor`` as a graph initializer.

        This is a no-op if an initializer with the same ``name`` already exists.
        '''
        if find_by_name(tensor.name, self.model.graph.initializer) is None:
            self.model.graph.initializer.extend([tensor])

    def get_initializer(self, name):
        '''Return the first initializer whose ``name`` equals ``name``, or None.'''
        return next(
            (tensor for tensor in self.model.graph.initializer if tensor.name == name),
            None,
        )

    def remove_initializer(self, tensor):
        '''Remove ``tensor`` from the graph initializers; do nothing if it is absent.'''
        if tensor in self.model.graph.initializer:
            self.model.graph.initializer.remove(tensor)

    def remove_initializers(self, init_to_remove):
        '''
        Remove every tensor in ``init_to_remove`` from the graph initializers.

        The input is snapshotted first, so passing the live initializer
        container itself is safe.
        '''
        for initializer in list(init_to_remove):
            self.remove_initializer(initializer)

    def input_name_to_nodes(self):
        '''
        Map each input name to the list of nodes that consume it.

        A node is listed once per input slot that uses the name. A node
        that reads the same tensor twice therefore appears twice. Nodes
        appear in graph order.
        '''
        input_name_to_nodes = {}
        for node in self.model.graph.node:
            for input_name in node.input:
                input_name_to_nodes.setdefault(input_name, []).append(node)
        return input_name_to_nodes

    def output_name_to_node(self):
        '''
        Map each output name to the node that produces it.

        If several nodes list the same output name, the last one in graph
        order wins.
        '''
        return {
            output_name: node
            for node in self.model.graph.node
            for output_name in node.output
        }

    def get_children(self, node, input_name_to_nodes=None):
        '''
        Return the nodes that consume any output of ``node``.

        A consumer appears once per matching (output, input slot) pair, so
        duplicates are possible.

        :param node: the node whose consumers are wanted.
        :param input_name_to_nodes: optional precomputed result of
            ``input_name_to_nodes()``; it is rebuilt when None. Pass it in
            when calling in a loop to avoid a full graph scan per call.
        '''
        if input_name_to_nodes is None:
            input_name_to_nodes = self.input_name_to_nodes()

        children = []
        for output_name in node.output:
            children.extend(input_name_to_nodes.get(output_name, []))
        return children

    def get_parents(self, node, output_name_to_node=None):
        '''
        Return the nodes that produce the inputs of ``node``, in input order.

        Inputs not produced by any node (e.g. graph inputs or initializers)
        are skipped. List positions therefore do not necessarily match
        input indices; use ``get_parent`` for a specific slot.

        :param output_name_to_node: optional precomputed result of
            ``output_name_to_node()``; it is rebuilt when None.
        '''
        if output_name_to_node is None:
            output_name_to_node = self.output_name_to_node()

        return [
            output_name_to_node[input_name]
            for input_name in node.input
            if input_name in output_name_to_node
        ]

    def get_parent(self, node, idx, output_name_to_node=None):
        '''
        Return the node producing input ``idx`` of ``node``, or None.

        None is returned when ``idx`` is past the end of ``node.input`` or
        when that input is not produced by any node. ``idx`` is expected to
        be non-negative; negative values fall through to Python's
        negative indexing.

        :param output_name_to_node: optional precomputed result of
            ``output_name_to_node()``; it is rebuilt when None.
        '''
        if output_name_to_node is None:
            output_name_to_node = self.output_name_to_node()

        if len(node.input) <= idx:
            return None

        return output_name_to_node.get(node.input[idx])

    def find_node_by_name(self, node_name, new_nodes_list, graph):
        '''
        Find a node by name in ``graph`` or among ``new_nodes_list``.

        ``new_nodes_list`` holds nodes created during quantization that are
        not yet part of ``graph``. The lookup is delegated to ``find_by_name``,
        which returns None when nothing matches.
        '''
        # Shallow copy, so extending it does not mutate graph.node itself.
        candidates = list(graph.node)
        candidates.extend(new_nodes_list)
        return find_by_name(node_name, candidates)

    def find_nodes_by_initializer(self, graph, initializer):
        '''
        Find all nodes in ``graph`` that take ``initializer`` as an input.

        Each node is reported once, in graph order, even if it uses the
        initializer in several input slots.
        '''
        initializer_name = initializer.name
        return [node for node in graph.node if initializer_name in node.input]
|