From 4bb0e29d0ea9565f2118fa0686806de6a595c98b Mon Sep 17 00:00:00 2001 From: Yufeng Li Date: Tue, 22 Jun 2021 15:08:54 -0700 Subject: [PATCH] initialize generated_value_names with graph input (#8085) * initialize generated_value_names with graph input * use set for following usage --- onnxruntime/python/tools/quantization/onnx_model.py | 11 +++++++++++ .../python/tools/quantization/onnx_quantizer.py | 4 ++-- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/onnxruntime/python/tools/quantization/onnx_model.py b/onnxruntime/python/tools/quantization/onnx_model.py index fefacd1fac..0c888abb76 100644 --- a/onnxruntime/python/tools/quantization/onnx_model.py +++ b/onnxruntime/python/tools/quantization/onnx_model.py @@ -47,6 +47,9 @@ class ONNXModel: return tensor return None + def get_initializer_name_set(self): + return set(initializer.name for initializer in self.model.graph.initializer) + def remove_initializer(self, tensor): if tensor in self.model.graph.initializer: self.model.graph.initializer.remove(tensor) @@ -59,6 +62,14 @@ class ONNXModel: for initializer in init_to_remove: self.remove_initializer(initializer) + def get_non_initializer_inputs(self): + initializer_names = self.get_initializer_name_set() + non_initializer_inputs = set() + for input in self.model.graph.input: + if input.name not in initializer_names: + non_initializer_inputs.add(input.name) + return non_initializer_inputs + def input_name_to_nodes(self): input_name_to_nodes = {} for node in self.model.graph.node: diff --git a/onnxruntime/python/tools/quantization/onnx_quantizer.py b/onnxruntime/python/tools/quantization/onnx_quantizer.py index f0a9def76f..c42398d6c5 100644 --- a/onnxruntime/python/tools/quantization/onnx_quantizer.py +++ b/onnxruntime/python/tools/quantization/onnx_quantizer.py @@ -83,7 +83,7 @@ class ONNXQuantizer: self.quantized_value_map = {} # some output from nodes will be quantized, yet itself should be treat as existing so # no dequantized will be applied when needed later - self.generated_value_names = {} + self.generated_value_names = self.model.get_non_initializer_inputs() def check_opset_version(self): ai_onnx_domain = [ @@ -208,7 +208,7 @@ class ONNXQuantizer: op_quantizer.quantize() for i in range(number_of_existing_new_nodes, len(self.new_nodes)): for output_name in self.new_nodes[i].output: - self.generated_value_names.update({output_name : 1}) + self.generated_value_names.add(output_name) self._dequantize_outputs()