Fix next node access bug in calibration tool (#4863)

* fix bug in calibration tool

* fix next node access bugs

* rm file in wrong folder

* refine

* optimize

* refine

* refine format

* refine

Co-authored-by: t-yguo <t-yguo@microsoft.com>
This commit is contained in:
RRRachelllll555 2020-08-21 20:48:54 -07:00 committed by GitHub
parent 3fa73a5b6a
commit 9a6db9b9f4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -31,7 +31,8 @@ class ONNXCalibrater:
calibrate_op_types,
black_nodes,
white_nodes,
augmented_model_path):
augmented_model_path,
input_name_to_nodes):
'''
:param model_path: ONNX model to calibrate
:param data_reader: user implemented object to read in and preprocess calibration dataset
@ -48,6 +49,7 @@ class ONNXCalibrater:
self.black_nodes = black_nodes
self.white_nodes = white_nodes
self.augmented_model_path = augmented_model_path
self.input_name_to_nodes = input_name_to_nodes
def augment_graph(self):
'''
@ -100,6 +102,7 @@ class ONNXCalibrater:
model.graph.node.extend(added_nodes)
model.graph.output.extend(added_outputs)
return model
#Using augmented outputs to generate inputs for quantization
@ -150,9 +153,42 @@ class ONNXCalibrater:
raise ValueError('Unknown value for calib_mode. Currently only naive mode is supported.')
final_dict = dict(zip(node_names, pairs))
return final_dict
def _get_input_name_to_nodes(self, model):
'''
Helper function to get input_name_to_nodes dictionary
'''
for node in model.graph.node:
for input_name in node.input:
if input_name not in self.input_name_to_nodes:
self.input_name_to_nodes[input_name] = [node]
else:
self.input_name_to_nodes[input_name].append(node)
def _get_next_nodes(self, model, curr_node):
'''
Helper function to get child nodes for a given node
'''
if not self.input_name_to_nodes:
self._get_input_name_to_nodes(model)
children = []
for output in curr_node.output:
if output in self.input_name_to_nodes:
for child_node in self.input_name_to_nodes[output]:
children.append(child_node)
return children
def calculate_scale_zeropoint(self, node, next_node, rmin, rmax):
zp_and_scale = []
# adjust rmin and rmax such that 0 is included in the range. This is required
# to make sure zero can be uniquely represented.
@ -179,6 +215,7 @@ class ONNXCalibrater:
zp_and_scale.append(zero_point)
zp_and_scale.append(scale)
return zp_and_scale
def calculate_quantization_params(self,quantization_thresholds):
@ -206,12 +243,17 @@ class ONNXCalibrater:
quantization_params = {}
model = onnx.load(self.model_path)
for index, node in enumerate(model.graph.node):
node_output_name = node.output[0]
if node_output_name in quantization_thresholds:
node_thresholds = quantization_thresholds[node_output_name]
node_params = self.calculate_scale_zeropoint(node, model.graph.node[index + 1], node_thresholds[0],node_thresholds[1])
quantization_params[node_output_name] = node_params
self._get_input_name_to_nodes(model)
for node in model.graph.node:
next_nodes = self._get_next_nodes(model,node)
for next_node in next_nodes:
node_output_name = next_node.output[0]
if node_output_name in quantization_thresholds:
node_thresholds = quantization_thresholds[node_output_name]
node_params = self.calculate_scale_zeropoint(node, next_node, node_thresholds[0], node_thresholds[1])
quantization_params[node_output_name] = node_params
return quantization_params
@ -232,15 +274,18 @@ def calibrate(model_path,
:param white_nodes: operator names that force to be quantized, default = ''
:param augmented_model_path: save augmented_model to this path
'''
input_name_to_nodes = {}
#1. initialize a calibrater
calibrater = ONNXCalibrater(model_path,data_reader,op_types,black_nodes,white_nodes,augmented_model_path)
calibrater = ONNXCalibrater(model_path, data_reader, op_types, black_nodes, white_nodes, augmented_model_path, input_name_to_nodes)
#2. augment
augmented_model = calibrater.augment_graph()
onnx.save(augmented_model,augmented_model_path)
onnx.save(augmented_model, augmented_model_path)
#3. generate quantization thresholds
dict_for_quantization = calibrater.get_intermediate_outputs()
#4. generate quantization parameters dict
quantization_params_dict = calibrater.calculate_quantization_params(dict_for_quantization)
print("Calibrated,quantized parameters calculated and returned.")
return quantization_params_dict
return quantization_params_dict