diff --git a/onnxruntime/python/tools/quantization/onnx_model.py b/onnxruntime/python/tools/quantization/onnx_model.py index 0cda0a4a59..8e6d70c4bb 100644 --- a/onnxruntime/python/tools/quantization/onnx_model.py +++ b/onnxruntime/python/tools/quantization/onnx_model.py @@ -260,38 +260,49 @@ class ONNXModel: def topological_sort(self): deps_count = [0]*len(self.nodes()) # dependency count of each node deps_to_nodes = {} # input to node indice + sorted_nodes = [] # initialize sorted_nodes for node_idx, node in enumerate(self.nodes()): # CANNOT use len(node.input) directly because input can be optional deps_count[node_idx] = sum(1 for _ in node.input if _ ) + if deps_count[node_idx] == 0: # Constant doesn't depend on any inputs + sorted_nodes.append(self.nodes()[node_idx]) + continue + for input_name in node.input: if input_name not in deps_to_nodes: deps_to_nodes[input_name] = [node_idx] else: deps_to_nodes[input_name].append(node_idx) - # initialize sorted_nodes - sorted_nodes = [] - for input in itertools.chain(self.initializer(), self.model.graph.input): - if input.name in deps_to_nodes: - for node_idx in deps_to_nodes[input.name]: + initializer_names = [init.name for init in self.initializer()] + graph_input_names = [input.name for input in self.model.graph.input] + input_names = initializer_names + graph_input_names + input_names.sort() + prev_input_name = None + for input_name in input_names: + if prev_input_name == input_name: + continue + + prev_input_name = input_name + if input_name in deps_to_nodes: + for node_idx in deps_to_nodes[input_name]: deps_count[node_idx] = deps_count[node_idx] - 1 if deps_count[node_idx] == 0: sorted_nodes.append(self.nodes()[node_idx]) - s = 0 - e = len(sorted_nodes) + start = 0 + end = len(sorted_nodes) - while s < e: - for output in sorted_nodes[s].output: + while start < end: + for output in sorted_nodes[start].output: if output in deps_to_nodes: for node_idx in deps_to_nodes[output]: deps_count[node_idx] = deps_count[node_idx] - 1 if deps_count[node_idx] == 0: sorted_nodes.append(self.nodes()[node_idx]) - e = e + 1 - s = s + 1 + end = end + 1 + start = start + 1 - assert(e == len(self.graph().node)), "Graph is not a DAG" + assert(end == len(self.graph().node)), "Graph is not a DAG" self.graph().ClearField('node') - self.graph().node.extend(sorted_nodes) - + self.graph().node.extend(sorted_nodes) \ No newline at end of file diff --git a/onnxruntime/test/python/quantization/test_onnx_model.py b/onnxruntime/test/python/quantization/test_onnx_model.py index 7d98b53b2e..b1d1736639 100644 --- a/onnxruntime/test/python/quantization/test_onnx_model.py +++ b/onnxruntime/test/python/quantization/test_onnx_model.py @@ -65,6 +65,28 @@ class TestONNXModel(unittest.TestCase): model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) onnx.save(model, model_path) + def construct_model_Constant(self, model_path): + # (input) Constant + # \ / + # \ / + # \ / + # \ / + # Add + # | + # (output) + + initializers = [] + input = helper.make_tensor_value_info('input', TensorProto.FLOAT, [4, 8, 12]) + output = helper.make_tensor_value_info('output', TensorProto.FLOAT, [4, 8, 12]) + + # make nodes + constant_node = onnx.helper.make_node('Constant', [], ['const_output'], value_float=42.0) + add_node = onnx.helper.make_node('Add', ['input', 'const_output'], ['output'], name='Add') + graph = helper.make_graph([add_node, constant_node], + 'onnx_model_test', [input], [output], initializer=initializers) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + onnx.save(model, model_path) + def test_topo_sort(self): test_model_path = 'onnx_model_topo_sort.onnx' self.construct_model(test_model_path) @@ -73,5 +95,13 @@ class TestONNXModel(unittest.TestCase): onnx_model.topological_sort() check_op_type_order(self, onnx_model.model, ['GRU', 'Conv', 'Conv', 'Relu', 'Add']) + def test_topo_sort_constant(self): + test_model_path = 'onnx_model_topo_sort_constant.onnx' + self.construct_model_Constant(test_model_path) + onnx_model = ONNXModel(onnx.load(test_model_path)) + check_op_type_order(self, onnx_model.model, ['Add', 'Constant']) + onnx_model.topological_sort() + check_op_type_order(self, onnx_model.model, ['Constant', 'Add']) + if __name__ == '__main__': unittest.main()