mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-21 21:52:11 +00:00
Fix next node access bug in calibration tool (#4863)
* fix bug in calibration tool * fix next node access bugs * rm file in wrong folder * refine * optimize * refine * refine format * refine Co-authored-by: t-yguo <t-yguo@microsoft.com>
This commit is contained in:
parent
3fa73a5b6a
commit
9a6db9b9f4
1 changed files with 55 additions and 10 deletions
|
|
@ -31,7 +31,8 @@ class ONNXCalibrater:
|
|||
calibrate_op_types,
|
||||
black_nodes,
|
||||
white_nodes,
|
||||
augmented_model_path):
|
||||
augmented_model_path,
|
||||
input_name_to_nodes):
|
||||
'''
|
||||
:param model_path: ONNX model to calibrate
|
||||
:param data_reader: user implemented object to read in and preprocess calibration dataset
|
||||
|
|
@ -48,6 +49,7 @@ class ONNXCalibrater:
|
|||
self.black_nodes = black_nodes
|
||||
self.white_nodes = white_nodes
|
||||
self.augmented_model_path = augmented_model_path
|
||||
self.input_name_to_nodes = input_name_to_nodes
|
||||
|
||||
def augment_graph(self):
|
||||
'''
|
||||
|
|
@ -100,6 +102,7 @@ class ONNXCalibrater:
|
|||
|
||||
model.graph.node.extend(added_nodes)
|
||||
model.graph.output.extend(added_outputs)
|
||||
|
||||
return model
|
||||
|
||||
#Using augmented outputs to generate inputs for quantization
|
||||
|
|
@ -150,9 +153,42 @@ class ONNXCalibrater:
|
|||
raise ValueError('Unknown value for calib_mode. Currently only naive mode is supported.')
|
||||
|
||||
final_dict = dict(zip(node_names, pairs))
|
||||
|
||||
return final_dict
|
||||
|
||||
|
||||
def _get_input_name_to_nodes(self, model):
|
||||
'''
|
||||
Helper function to get input_name_to_nodes dictionary
|
||||
'''
|
||||
|
||||
for node in model.graph.node:
|
||||
for input_name in node.input:
|
||||
if input_name not in self.input_name_to_nodes:
|
||||
self.input_name_to_nodes[input_name] = [node]
|
||||
else:
|
||||
self.input_name_to_nodes[input_name].append(node)
|
||||
|
||||
|
||||
def _get_next_nodes(self, model, curr_node):
|
||||
'''
|
||||
Helper function to get child nodes for a given node
|
||||
'''
|
||||
|
||||
if not self.input_name_to_nodes:
|
||||
self._get_input_name_to_nodes(model)
|
||||
|
||||
children = []
|
||||
for output in curr_node.output:
|
||||
if output in self.input_name_to_nodes:
|
||||
for child_node in self.input_name_to_nodes[output]:
|
||||
children.append(child_node)
|
||||
|
||||
return children
|
||||
|
||||
|
||||
def calculate_scale_zeropoint(self, node, next_node, rmin, rmax):
|
||||
|
||||
zp_and_scale = []
|
||||
# adjust rmin and rmax such that 0 is included in the range. This is required
|
||||
# to make sure zero can be uniquely represented.
|
||||
|
|
@ -179,6 +215,7 @@ class ONNXCalibrater:
|
|||
|
||||
zp_and_scale.append(zero_point)
|
||||
zp_and_scale.append(scale)
|
||||
|
||||
return zp_and_scale
|
||||
|
||||
def calculate_quantization_params(self,quantization_thresholds):
|
||||
|
|
@ -206,12 +243,17 @@ class ONNXCalibrater:
|
|||
|
||||
quantization_params = {}
|
||||
model = onnx.load(self.model_path)
|
||||
for index, node in enumerate(model.graph.node):
|
||||
node_output_name = node.output[0]
|
||||
if node_output_name in quantization_thresholds:
|
||||
node_thresholds = quantization_thresholds[node_output_name]
|
||||
node_params = self.calculate_scale_zeropoint(node, model.graph.node[index + 1], node_thresholds[0],node_thresholds[1])
|
||||
quantization_params[node_output_name] = node_params
|
||||
|
||||
self._get_input_name_to_nodes(model)
|
||||
|
||||
for node in model.graph.node:
|
||||
next_nodes = self._get_next_nodes(model,node)
|
||||
for next_node in next_nodes:
|
||||
node_output_name = next_node.output[0]
|
||||
if node_output_name in quantization_thresholds:
|
||||
node_thresholds = quantization_thresholds[node_output_name]
|
||||
node_params = self.calculate_scale_zeropoint(node, next_node, node_thresholds[0], node_thresholds[1])
|
||||
quantization_params[node_output_name] = node_params
|
||||
|
||||
return quantization_params
|
||||
|
||||
|
|
@ -232,15 +274,18 @@ def calibrate(model_path,
|
|||
:param white_nodes: operator names that force to be quantized, default = ''
|
||||
:param augmented_model_path: save augmented_model to this path
|
||||
'''
|
||||
|
||||
input_name_to_nodes = {}
|
||||
|
||||
#1. initialize a calibrater
|
||||
calibrater = ONNXCalibrater(model_path,data_reader,op_types,black_nodes,white_nodes,augmented_model_path)
|
||||
calibrater = ONNXCalibrater(model_path, data_reader, op_types, black_nodes, white_nodes, augmented_model_path, input_name_to_nodes)
|
||||
#2. augment
|
||||
augmented_model = calibrater.augment_graph()
|
||||
onnx.save(augmented_model,augmented_model_path)
|
||||
onnx.save(augmented_model, augmented_model_path)
|
||||
#3. generate quantization thresholds
|
||||
dict_for_quantization = calibrater.get_intermediate_outputs()
|
||||
#4. generate quantization parameters dict
|
||||
quantization_params_dict = calibrater.calculate_quantization_params(dict_for_quantization)
|
||||
|
||||
print("Calibrated,quantized parameters calculated and returned.")
|
||||
return quantization_params_dict
|
||||
return quantization_params_dict
|
||||
Loading…
Reference in a new issue