diff --git a/onnxruntime/python/tools/quantization/calibrate.py b/onnxruntime/python/tools/quantization/calibrate.py index 1aef610c28..4da4cb75ef 100644 --- a/onnxruntime/python/tools/quantization/calibrate.py +++ b/onnxruntime/python/tools/quantization/calibrate.py @@ -31,7 +31,8 @@ class ONNXCalibrater: calibrate_op_types, black_nodes, white_nodes, - augmented_model_path): + augmented_model_path, + input_name_to_nodes): ''' :param model_path: ONNX model to calibrate :param data_reader: user implemented object to read in and preprocess calibration dataset @@ -48,6 +49,7 @@ class ONNXCalibrater: self.black_nodes = black_nodes self.white_nodes = white_nodes self.augmented_model_path = augmented_model_path + self.input_name_to_nodes = input_name_to_nodes def augment_graph(self): ''' @@ -100,6 +102,7 @@ class ONNXCalibrater: model.graph.node.extend(added_nodes) model.graph.output.extend(added_outputs) + return model #Using augmented outputs to generate inputs for quantization @@ -150,9 +153,42 @@ class ONNXCalibrater: raise ValueError('Unknown value for calib_mode. Currently only naive mode is supported.') final_dict = dict(zip(node_names, pairs)) + return final_dict + + + def _get_input_name_to_nodes(self, model): + ''' + Helper function to get input_name_to_nodes dictionary + ''' + + for node in model.graph.node: + for input_name in node.input: + if input_name not in self.input_name_to_nodes: + self.input_name_to_nodes[input_name] = [node] + else: + self.input_name_to_nodes[input_name].append(node) + + + def _get_next_nodes(self, model, curr_node): + ''' + Helper function to get child nodes for a given node + ''' + + if not self.input_name_to_nodes: + self._get_input_name_to_nodes(model) + + children = [] + for output in curr_node.output: + if output in self.input_name_to_nodes: + for child_node in self.input_name_to_nodes[output]: + children.append(child_node) + + return children + def calculate_scale_zeropoint(self, node, next_node, rmin, rmax): + zp_and_scale = [] # adjust rmin and rmax such that 0 is included in the range. This is required # to make sure zero can be uniquely represented. @@ -179,6 +215,7 @@ class ONNXCalibrater: zp_and_scale.append(zero_point) zp_and_scale.append(scale) + return zp_and_scale def calculate_quantization_params(self,quantization_thresholds): @@ -206,12 +243,17 @@ class ONNXCalibrater: quantization_params = {} model = onnx.load(self.model_path) - for index, node in enumerate(model.graph.node): - node_output_name = node.output[0] - if node_output_name in quantization_thresholds: - node_thresholds = quantization_thresholds[node_output_name] - node_params = self.calculate_scale_zeropoint(node, model.graph.node[index + 1], node_thresholds[0],node_thresholds[1]) - quantization_params[node_output_name] = node_params + + self._get_input_name_to_nodes(model) + + for node in model.graph.node: + next_nodes = self._get_next_nodes(model,node) + for next_node in next_nodes: + node_output_name = next_node.output[0] + if node_output_name in quantization_thresholds: + node_thresholds = quantization_thresholds[node_output_name] + node_params = self.calculate_scale_zeropoint(node, next_node, node_thresholds[0], node_thresholds[1]) + quantization_params[node_output_name] = node_params return quantization_params @@ -232,15 +274,18 @@ def calibrate(model_path, :param white_nodes: operator names that force to be quantized, default = '' :param augmented_model_path: save augmented_model to this path ''' + + input_name_to_nodes = {} + #1. initialize a calibrater - calibrater = ONNXCalibrater(model_path,data_reader,op_types,black_nodes,white_nodes,augmented_model_path) + calibrater = ONNXCalibrater(model_path, data_reader, op_types, black_nodes, white_nodes, augmented_model_path, input_name_to_nodes) #2. augment augmented_model = calibrater.augment_graph() - onnx.save(augmented_model,augmented_model_path) + onnx.save(augmented_model, augmented_model_path) #3. generate quantization thresholds dict_for_quantization = calibrater.get_intermediate_outputs() #4. generate quantization parameters dict quantization_params_dict = calibrater.calculate_quantization_params(dict_for_quantization) print("Calibrated,quantized parameters calculated and returned.") - return quantization_params_dict + return quantization_params_dict \ No newline at end of file