initialize generated_value_names with graph input (#8085)

* initialize generated_value_names with graph input
* use set for following usage
This commit is contained in:
Yufeng Li 2021-06-22 15:08:54 -07:00 committed by GitHub
parent 839f69d249
commit 4bb0e29d0e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 13 additions and 2 deletions

View file

@ -47,6 +47,9 @@ class ONNXModel:
return tensor
return None
def get_initializer_name_set(self):
return set(initializer.name for initializer in self.model.graph.initializer)
def remove_initializer(self, tensor):
if tensor in self.model.graph.initializer:
self.model.graph.initializer.remove(tensor)
@ -59,6 +62,14 @@ class ONNXModel:
for initializer in init_to_remove:
self.remove_initializer(initializer)
def get_non_initializer_inputs(self):
initializer_names = self.get_initializer_name_set()
non_initializer_inputs = set()
for input in self.model.graph.input:
if input.name not in initializer_names:
non_initializer_inputs.add(input.name)
return non_initializer_inputs
def input_name_to_nodes(self):
input_name_to_nodes = {}
for node in self.model.graph.node:

View file

@ -83,7 +83,7 @@ class ONNXQuantizer:
self.quantized_value_map = {}
# some output from nodes will be quantized, yet itself should be treat as existing so
# no dequantized will be applied when needed later
self.generated_value_names = {}
self.generated_value_names = self.model.get_non_initializer_inputs()
def check_opset_version(self):
ai_onnx_domain = [
@ -208,7 +208,7 @@ class ONNXQuantizer:
op_quantizer.quantize()
for i in range(number_of_existing_new_nodes, len(self.new_nodes)):
for output_name in self.new_nodes[i].output:
self.generated_value_names.update({output_name : 1})
self.generated_value_names.add(output_name)
self._dequantize_outputs()