fix topo sort in quant tool (#7833)

* fix topo sort in quant tool

* add unit test and make the topo sort stable
This commit is contained in:
Yufeng Li 2021-05-26 17:53:35 -07:00 committed by GitHub
parent fc472a04be
commit 94bb09bf47
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 55 additions and 14 deletions

View file

@ -260,38 +260,49 @@ class ONNXModel:
def topological_sort(self):
deps_count = [0]*len(self.nodes()) # dependency count of each node
deps_to_nodes = {} # input to node indice
sorted_nodes = [] # initialize sorted_nodes
for node_idx, node in enumerate(self.nodes()):
# CANNOT use len(node.input) directly because input can be optional
deps_count[node_idx] = sum(1 for _ in node.input if _ )
if deps_count[node_idx] == 0: # Constant doesn't depend on any inputs
sorted_nodes.append(self.nodes()[node_idx])
continue
for input_name in node.input:
if input_name not in deps_to_nodes:
deps_to_nodes[input_name] = [node_idx]
else:
deps_to_nodes[input_name].append(node_idx)
# initialize sorted_nodes
sorted_nodes = []
for input in itertools.chain(self.initializer(), self.model.graph.input):
if input.name in deps_to_nodes:
for node_idx in deps_to_nodes[input.name]:
initializer_names = [init.name for init in self.initializer()]
graph_input_names = [input.name for input in self.model.graph.input]
input_names = initializer_names + graph_input_names
input_names.sort()
prev_input_name = None
for input_name in input_names:
if prev_input_name == input_name:
continue
prev_input_name = input_name
if input_name in deps_to_nodes:
for node_idx in deps_to_nodes[input_name]:
deps_count[node_idx] = deps_count[node_idx] - 1
if deps_count[node_idx] == 0:
sorted_nodes.append(self.nodes()[node_idx])
s = 0
e = len(sorted_nodes)
start = 0
end = len(sorted_nodes)
while s < e:
for output in sorted_nodes[s].output:
while start < end:
for output in sorted_nodes[start].output:
if output in deps_to_nodes:
for node_idx in deps_to_nodes[output]:
deps_count[node_idx] = deps_count[node_idx] - 1
if deps_count[node_idx] == 0:
sorted_nodes.append(self.nodes()[node_idx])
e = e + 1
s = s + 1
end = end + 1
start = start + 1
assert(e == len(self.graph().node)), "Graph is not a DAG"
assert(end == len(self.graph().node)), "Graph is not a DAG"
self.graph().ClearField('node')
self.graph().node.extend(sorted_nodes)
self.graph().node.extend(sorted_nodes)

View file

@ -65,6 +65,28 @@ class TestONNXModel(unittest.TestCase):
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
onnx.save(model, model_path)
def construct_model_Constant(self, model_path):
# (input) Constant
# \ /
# \ /
# \ /
# \ /
# Add
# |
# (output)
initializers = []
input = helper.make_tensor_value_info('input', TensorProto.FLOAT, [4, 8, 12])
output = helper.make_tensor_value_info('output', TensorProto.FLOAT, [4, 8, 12])
# make nodes
constant_node = onnx.helper.make_node('Constant', [], ['const_output'], value_float=42.0)
add_node = onnx.helper.make_node('Add', ['input', 'const_output'], ['output'], name='Add')
graph = helper.make_graph([add_node, constant_node],
'onnx_model_test', [input], [output], initializer=initializers)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
onnx.save(model, model_path)
def test_topo_sort(self):
test_model_path = 'onnx_model_topo_sort.onnx'
self.construct_model(test_model_path)
@ -73,5 +95,13 @@ class TestONNXModel(unittest.TestCase):
onnx_model.topological_sort()
check_op_type_order(self, onnx_model.model, ['GRU', 'Conv', 'Conv', 'Relu', 'Add'])
def test_topo_sort_constant(self):
test_model_path = 'onnx_model_topo_sort_constant.onnx'
self.construct_model_Constant(test_model_path)
onnx_model = ONNXModel(onnx.load(test_model_path))
check_op_type_order(self, onnx_model.model, ['Add', 'Constant'])
onnx_model.topological_sort()
check_op_type_order(self, onnx_model.model, ['Constant', 'Add'])
if __name__ == '__main__':
unittest.main()