diff --git a/onnxruntime/python/tools/quantization/onnx_quantizer.py b/onnxruntime/python/tools/quantization/onnx_quantizer.py index 06dbce2b29..5b5778439c 100644 --- a/onnxruntime/python/tools/quantization/onnx_quantizer.py +++ b/onnxruntime/python/tools/quantization/onnx_quantizer.py @@ -641,26 +641,27 @@ class ONNXQuantizer: parameter last_output: output of previous node (input to bias add) return: the name of output ''' - # Add an Add operation for bias - # Add reshape for correct broadcase - reshape_input = [quantized_bias_name] - # Add tensors for the shape to be reshaped to weight = find_by_name(node.input[1], self.model.initializer()) if weight is None: raise ValueError("Expected {} to be an initializer".format(node.input[1])) + # Add reshape for correct broadcast + reshape_input_data = quantized_bias_name + reshape_input_shape = quantized_bias_name + "_reshape_shape" + reshape_input = [reshape_input_data, reshape_input_shape] + reshape_shape = np.ones((len(weight.dims)), dtype=np.int64) reshape_shape[1] = -1 - init_shape = onnx.helper.make_tensor("reshape_shape", onnx_proto.TensorProto.INT64, [len(weight.dims)], reshape_shape) + init_shape = onnx.helper.make_tensor(reshape_input_shape, onnx_proto.TensorProto.INT64, [len(weight.dims)], reshape_shape) self.model.add_initializer(init_shape) - reshape_input.append('reshape_shape') reshape_op_output = node.output[0] + "_reshape" reshape_node = onnx.helper.make_node("Reshape", reshape_input, [reshape_op_output], quantized_bias_name + "reshape") nodes.append(reshape_node) + # Add an Add operation for bias bias_add_input = [last_output] bias_add_input.append(reshape_op_output) add_node_output = node.output[0] + "_bias_add" diff --git a/onnxruntime/python/tools/quantization/test/test_calibrate.py b/onnxruntime/python/tools/quantization/test/test_calibrate.py index 374206e89e..fdba7182d2 100644 --- a/onnxruntime/python/tools/quantization/test/test_calibrate.py +++ 
b/onnxruntime/python/tools/quantization/test/test_calibrate.py @@ -90,10 +90,10 @@ class TestCalibrate(unittest.TestCase): augmented_model_outputs = [output.name for output in augmented_model.graph.output] added_node_names = ['C_ReduceMin', 'C_ReduceMax', 'D_ReduceMin', 'D_ReduceMax', 'F_ReduceMin', 'F_ReduceMax'] added_outputs = ['C_ReduceMin', 'C_ReduceMax', 'D_ReduceMin', 'D_ReduceMax', 'F_ReduceMin', 'F_ReduceMax'] - # Original 3 nodes + added ReduceMin/Max nodes * 6 (exlude graph input/output) - self.assertEqual(len(augmented_model_node_names), 9) + # Original 3 nodes + added ReduceMin/Max nodes + self.assertEqual(len(augmented_model_node_names), 15) # Original 1 graph output + added outputs * 6 - self.assertEqual(len(augmented_model_outputs), 7) + self.assertEqual(len(augmented_model_outputs), 13) for name in added_node_names: self.assertTrue(name in augmented_model_node_names) for output in added_outputs: @@ -131,9 +131,9 @@ class TestCalibrate(unittest.TestCase): added_node_names = ['I_ReduceMin', 'I_ReduceMax', 'K_ReduceMin', 'K_ReduceMax'] added_outputs = ['I_ReduceMin', 'I_ReduceMax', 'K_ReduceMin', 'K_ReduceMax'] # Original 2 nodes + added ReduceMin/Max nodes * 4 - self.assertEqual(len(augmented_model_node_names), 6) + self.assertEqual(len(augmented_model_node_names), 12) # Original 1 graph output + added outputs * 4 - self.assertEqual(len(augmented_model_outputs), 5) + self.assertEqual(len(augmented_model_outputs), 11) for name in added_node_names: self.assertTrue(name in augmented_model_node_names) for output in added_outputs: @@ -175,10 +175,10 @@ class TestCalibrate(unittest.TestCase): augmented_model_outputs = [output.name for output in augmented_model.graph.output] added_node_names = ['M_ReduceMin', 'M_ReduceMax', 'O_ReduceMin', 'O_ReduceMax', 'P_ReduceMin', 'P_ReduceMax', 'Q_ReduceMin', 'Q_ReduceMax'] added_outputs = ['M_ReduceMin', 'M_ReduceMax', 'O_ReduceMin', 'O_ReduceMax', 'P_ReduceMin', 'P_ReduceMax', 'Q_ReduceMin', 'Q_ReduceMax'] - # 
Original 4 nodes + added ReduceMin/Max nodes * 8 - self.assertEqual(len(augmented_model_node_names), 12) - # Original 1 graph output + added outputs * 8 - self.assertEqual(len(augmented_model_outputs), 9) + # Original 4 nodes + added ReduceMin/Max nodes + self.assertEqual(len(augmented_model_node_names), 14) + # Original 1 graph output + added outputs + self.assertEqual(len(augmented_model_outputs), 11) for name in added_node_names: self.assertTrue(name in augmented_model_node_names) for output in added_outputs: @@ -242,7 +242,7 @@ class TestCalibrate(unittest.TestCase): quantization_params_dict = calibrater.calculate_quantization_params(dict_for_quantization) #check the size of the quantization dictionary - self.assertEqual(len(quantization_params_dict), 5) + self.assertEqual(len(quantization_params_dict), 11) #check the computation of zp and scale for key, value in quantization_params_dict.items(): diff --git a/onnxruntime/python/tools/quantization/test/test_qat_support.py b/onnxruntime/python/tools/quantization/test/test_qat_support.py index b190c7a597..e01fd23cd2 100644 --- a/onnxruntime/python/tools/quantization/test/test_qat_support.py +++ b/onnxruntime/python/tools/quantization/test/test_qat_support.py @@ -12,8 +12,7 @@ from pathlib import Path import unittest import urllib.request -from onnxruntime.quantization.quantize import optimize_model, ONNXQuantizer -from onnxruntime.quantization.onnx_model import ONNXModel +from onnxruntime.quantization.quantize import ONNXQuantizer from onnxruntime.quantization.quant_utils import QuantizationMode from onnx import onnx_pb as onnx_proto @@ -83,8 +82,6 @@ def generate_qat_model(model_names): model_1 = onnx.helper.make_model(graph) model_1.ir_version = onnx.IR_VERSION - opset = model_1.opset_import.add() - opset.version = 11 onnx.save(model_1, model_names[0]) test_models.extend([model_1]) @@ -157,8 +154,6 @@ def generate_qat_model(model_names): model_2 = onnx.helper.make_model(graph) model_2.ir_version = 
onnx.IR_VERSION - opset = model_2.opset_import.add() - opset.version = 11 onnx.save(model_2, model_names[1]) test_models.extend([model_2]) @@ -208,8 +203,6 @@ def generate_qat_support_model(model_names, test_initializers): model_1 = onnx.ModelProto() model_1.ir_version = onnx.IR_VERSION - opset = model_1.opset_import.add() - opset.version = 11 model_1 = onnx.helper.make_model(graph) onnx.save(model_1, model_names[0]) @@ -251,8 +244,6 @@ def generate_qat_support_model(model_names, test_initializers): model_2 = onnx.ModelProto() model_2.ir_version = onnx.IR_VERSION - opset = model_2.opset_import.add() - opset.version = 11 model_2 = onnx.helper.make_model(graph) onnx.save(model_1, model_names[1])