mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-03 23:49:44 +00:00
initialize generated_value_names with graph input (#8085)
* initialize generated_value_names with graph input * use set for following usage
This commit is contained in:
parent
839f69d249
commit
4bb0e29d0e
2 changed files with 13 additions and 2 deletions
|
|
@ -47,6 +47,9 @@ class ONNXModel:
|
|||
return tensor
|
||||
return None
|
||||
|
||||
def get_initializer_name_set(self):
|
||||
return set(initializer.name for initializer in self.model.graph.initializer)
|
||||
|
||||
def remove_initializer(self, tensor):
|
||||
if tensor in self.model.graph.initializer:
|
||||
self.model.graph.initializer.remove(tensor)
|
||||
|
|
@ -59,6 +62,14 @@ class ONNXModel:
|
|||
for initializer in init_to_remove:
|
||||
self.remove_initializer(initializer)
|
||||
|
||||
def get_non_initializer_inputs(self):
|
||||
initializer_names = self.get_initializer_name_set()
|
||||
non_initializer_inputs = set()
|
||||
for input in self.model.graph.input:
|
||||
if input.name not in initializer_names:
|
||||
non_initializer_inputs.add(input.name)
|
||||
return non_initializer_inputs
|
||||
|
||||
def input_name_to_nodes(self):
|
||||
input_name_to_nodes = {}
|
||||
for node in self.model.graph.node:
|
||||
|
|
|
|||
|
|
@ -83,7 +83,7 @@ class ONNXQuantizer:
|
|||
self.quantized_value_map = {}
|
||||
# some output from nodes will be quantized, yet itself should be treat as existing so
|
||||
# no dequantized will be applied when needed later
|
||||
self.generated_value_names = {}
|
||||
self.generated_value_names = self.model.get_non_initializer_inputs()
|
||||
|
||||
def check_opset_version(self):
|
||||
ai_onnx_domain = [
|
||||
|
|
@ -208,7 +208,7 @@ class ONNXQuantizer:
|
|||
op_quantizer.quantize()
|
||||
for i in range(number_of_existing_new_nodes, len(self.new_nodes)):
|
||||
for output_name in self.new_nodes[i].output:
|
||||
self.generated_value_names.update({output_name : 1})
|
||||
self.generated_value_names.add(output_name)
|
||||
|
||||
self._dequantize_outputs()
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue