mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-01 03:45:06 +00:00
fix topo sort in quant tool (#7833)
* fix topo sort in quant tool * add unit test and make the topo sort stable
This commit is contained in:
parent
fc472a04be
commit
94bb09bf47
2 changed files with 55 additions and 14 deletions
|
|
@ -260,38 +260,49 @@ class ONNXModel:
|
|||
def topological_sort(self):
|
||||
deps_count = [0]*len(self.nodes()) # dependency count of each node
|
||||
deps_to_nodes = {} # input to node indice
|
||||
sorted_nodes = [] # initialize sorted_nodes
|
||||
for node_idx, node in enumerate(self.nodes()):
|
||||
# CANNOT use len(node.input) directly because input can be optional
|
||||
deps_count[node_idx] = sum(1 for _ in node.input if _ )
|
||||
if deps_count[node_idx] == 0: # Constant doesn't depend on any inputs
|
||||
sorted_nodes.append(self.nodes()[node_idx])
|
||||
continue
|
||||
|
||||
for input_name in node.input:
|
||||
if input_name not in deps_to_nodes:
|
||||
deps_to_nodes[input_name] = [node_idx]
|
||||
else:
|
||||
deps_to_nodes[input_name].append(node_idx)
|
||||
|
||||
# initialize sorted_nodes
|
||||
sorted_nodes = []
|
||||
for input in itertools.chain(self.initializer(), self.model.graph.input):
|
||||
if input.name in deps_to_nodes:
|
||||
for node_idx in deps_to_nodes[input.name]:
|
||||
initializer_names = [init.name for init in self.initializer()]
|
||||
graph_input_names = [input.name for input in self.model.graph.input]
|
||||
input_names = initializer_names + graph_input_names
|
||||
input_names.sort()
|
||||
prev_input_name = None
|
||||
for input_name in input_names:
|
||||
if prev_input_name == input_name:
|
||||
continue
|
||||
|
||||
prev_input_name = input_name
|
||||
if input_name in deps_to_nodes:
|
||||
for node_idx in deps_to_nodes[input_name]:
|
||||
deps_count[node_idx] = deps_count[node_idx] - 1
|
||||
if deps_count[node_idx] == 0:
|
||||
sorted_nodes.append(self.nodes()[node_idx])
|
||||
|
||||
s = 0
|
||||
e = len(sorted_nodes)
|
||||
start = 0
|
||||
end = len(sorted_nodes)
|
||||
|
||||
while s < e:
|
||||
for output in sorted_nodes[s].output:
|
||||
while start < end:
|
||||
for output in sorted_nodes[start].output:
|
||||
if output in deps_to_nodes:
|
||||
for node_idx in deps_to_nodes[output]:
|
||||
deps_count[node_idx] = deps_count[node_idx] - 1
|
||||
if deps_count[node_idx] == 0:
|
||||
sorted_nodes.append(self.nodes()[node_idx])
|
||||
e = e + 1
|
||||
s = s + 1
|
||||
end = end + 1
|
||||
start = start + 1
|
||||
|
||||
assert(e == len(self.graph().node)), "Graph is not a DAG"
|
||||
assert(end == len(self.graph().node)), "Graph is not a DAG"
|
||||
self.graph().ClearField('node')
|
||||
self.graph().node.extend(sorted_nodes)
|
||||
|
||||
self.graph().node.extend(sorted_nodes)
|
||||
|
|
@ -65,6 +65,28 @@ class TestONNXModel(unittest.TestCase):
|
|||
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||
onnx.save(model, model_path)
|
||||
|
||||
def construct_model_Constant(self, model_path):
|
||||
# (input) Constant
|
||||
# \ /
|
||||
# \ /
|
||||
# \ /
|
||||
# \ /
|
||||
# Add
|
||||
# |
|
||||
# (output)
|
||||
|
||||
initializers = []
|
||||
input = helper.make_tensor_value_info('input', TensorProto.FLOAT, [4, 8, 12])
|
||||
output = helper.make_tensor_value_info('output', TensorProto.FLOAT, [4, 8, 12])
|
||||
|
||||
# make nodes
|
||||
constant_node = onnx.helper.make_node('Constant', [], ['const_output'], value_float=42.0)
|
||||
add_node = onnx.helper.make_node('Add', ['input', 'const_output'], ['output'], name='Add')
|
||||
graph = helper.make_graph([add_node, constant_node],
|
||||
'onnx_model_test', [input], [output], initializer=initializers)
|
||||
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||
onnx.save(model, model_path)
|
||||
|
||||
def test_topo_sort(self):
|
||||
test_model_path = 'onnx_model_topo_sort.onnx'
|
||||
self.construct_model(test_model_path)
|
||||
|
|
@ -73,5 +95,13 @@ class TestONNXModel(unittest.TestCase):
|
|||
onnx_model.topological_sort()
|
||||
check_op_type_order(self, onnx_model.model, ['GRU', 'Conv', 'Conv', 'Relu', 'Add'])
|
||||
|
||||
def test_topo_sort_constant(self):
|
||||
test_model_path = 'onnx_model_topo_sort_constant.onnx'
|
||||
self.construct_model_Constant(test_model_path)
|
||||
onnx_model = ONNXModel(onnx.load(test_model_path))
|
||||
check_op_type_order(self, onnx_model.model, ['Add', 'Constant'])
|
||||
onnx_model.topological_sort()
|
||||
check_op_type_order(self, onnx_model.model, ['Constant', 'Add'])
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
|||
Loading…
Reference in a new issue