Fix quantization of Conv1D with bias (#5491)

* Fix reshape for Conv with bias
This commit is contained in:
Yufeng Li 2020-10-20 15:27:26 -07:00 committed by GitHub
parent 1038f9cc8b
commit 6c2162e97a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 18 additions and 26 deletions

View file

@ -641,26 +641,27 @@ class ONNXQuantizer:
parameter last_output: output of previous node (input to bias add)
return: the name of output
'''
# Add an Add operation for bias
# Add reshape for correct broadcase
reshape_input = [quantized_bias_name]
# Add tensors for the shape to be reshaped to
weight = find_by_name(node.input[1], self.model.initializer())
if weight is None:
raise ValueError("Expected {} to be an initializer".format(node.input[1]))
# Add reshape for correct broadcase
reshape_input_data = quantized_bias_name
reshape_input_shape = quantized_bias_name + "_reshape_shape"
reshape_input = [reshape_input_data, reshape_input_shape]
reshape_shape = np.ones((len(weight.dims)), dtype=np.int64)
reshape_shape[1] = -1
init_shape = onnx.helper.make_tensor("reshape_shape", onnx_proto.TensorProto.INT64, [len(weight.dims)], reshape_shape)
init_shape = onnx.helper.make_tensor(reshape_input_shape, onnx_proto.TensorProto.INT64, [len(weight.dims)], reshape_shape)
self.model.add_initializer(init_shape)
reshape_input.append('reshape_shape')
reshape_op_output = node.output[0] + "_reshape"
reshape_node = onnx.helper.make_node("Reshape", reshape_input, [reshape_op_output],
quantized_bias_name + "reshape")
nodes.append(reshape_node)
# Add an Add operation for bias
bias_add_input = [last_output]
bias_add_input.append(reshape_op_output)
add_node_output = node.output[0] + "_bias_add"

View file

@ -90,10 +90,10 @@ class TestCalibrate(unittest.TestCase):
augmented_model_outputs = [output.name for output in augmented_model.graph.output]
added_node_names = ['C_ReduceMin', 'C_ReduceMax', 'D_ReduceMin', 'D_ReduceMax', 'F_ReduceMin', 'F_ReduceMax']
added_outputs = ['C_ReduceMin', 'C_ReduceMax', 'D_ReduceMin', 'D_ReduceMax', 'F_ReduceMin', 'F_ReduceMax']
# Original 3 nodes + added ReduceMin/Max nodes * 6 (exlude graph input/output)
self.assertEqual(len(augmented_model_node_names), 9)
# Original 3 nodes + added ReduceMin/Max nodes
self.assertEqual(len(augmented_model_node_names), 15)
# Original 1 graph output + added outputs * 6
self.assertEqual(len(augmented_model_outputs), 7)
self.assertEqual(len(augmented_model_outputs), 13)
for name in added_node_names:
self.assertTrue(name in augmented_model_node_names)
for output in added_outputs:
@ -131,9 +131,9 @@ class TestCalibrate(unittest.TestCase):
added_node_names = ['I_ReduceMin', 'I_ReduceMax', 'K_ReduceMin', 'K_ReduceMax']
added_outputs = ['I_ReduceMin', 'I_ReduceMax', 'K_ReduceMin', 'K_ReduceMax']
# Original 2 nodes + added ReduceMin/Max nodes * 4
self.assertEqual(len(augmented_model_node_names), 6)
self.assertEqual(len(augmented_model_node_names), 12)
# Original 1 graph output + added outputs * 4
self.assertEqual(len(augmented_model_outputs), 5)
self.assertEqual(len(augmented_model_outputs), 11)
for name in added_node_names:
self.assertTrue(name in augmented_model_node_names)
for output in added_outputs:
@ -175,10 +175,10 @@ class TestCalibrate(unittest.TestCase):
augmented_model_outputs = [output.name for output in augmented_model.graph.output]
added_node_names = ['M_ReduceMin', 'M_ReduceMax', 'O_ReduceMin', 'O_ReduceMax', 'P_ReduceMin', 'P_ReduceMax', 'Q_ReduceMin', 'Q_ReduceMax']
added_outputs = ['M_ReduceMin', 'M_ReduceMax', 'O_ReduceMin', 'O_ReduceMax', 'P_ReduceMin', 'P_ReduceMax', 'Q_ReduceMin', 'Q_ReduceMax']
# Original 4 nodes + added ReduceMin/Max nodes * 8
self.assertEqual(len(augmented_model_node_names), 12)
# Original 1 graph output + added outputs * 8
self.assertEqual(len(augmented_model_outputs), 9)
# Original 4 nodes + added ReduceMin/Max nodes
self.assertEqual(len(augmented_model_node_names), 14)
# Original 1 graph output + added outputs
self.assertEqual(len(augmented_model_outputs), 11)
for name in added_node_names:
self.assertTrue(name in augmented_model_node_names)
for output in added_outputs:
@ -242,7 +242,7 @@ class TestCalibrate(unittest.TestCase):
quantization_params_dict = calibrater.calculate_quantization_params(dict_for_quantization)
#check the size of the quantization dictionary
self.assertEqual(len(quantization_params_dict), 5)
self.assertEqual(len(quantization_params_dict), 11)
#check the computation of zp and scale
for key, value in quantization_params_dict.items():

View file

@ -12,8 +12,7 @@ from pathlib import Path
import unittest
import urllib.request
from onnxruntime.quantization.quantize import optimize_model, ONNXQuantizer
from onnxruntime.quantization.onnx_model import ONNXModel
from onnxruntime.quantization.quantize import ONNXQuantizer
from onnxruntime.quantization.quant_utils import QuantizationMode
from onnx import onnx_pb as onnx_proto
@ -83,8 +82,6 @@ def generate_qat_model(model_names):
model_1 = onnx.helper.make_model(graph)
model_1.ir_version = onnx.IR_VERSION
opset = model_1.opset_import.add()
opset.version = 11
onnx.save(model_1, model_names[0])
test_models.extend([model_1])
@ -157,8 +154,6 @@ def generate_qat_model(model_names):
model_2 = onnx.helper.make_model(graph)
model_2.ir_version = onnx.IR_VERSION
opset = model_2.opset_import.add()
opset.version = 11
onnx.save(model_2, model_names[1])
test_models.extend([model_2])
@ -208,8 +203,6 @@ def generate_qat_support_model(model_names, test_initializers):
model_1 = onnx.ModelProto()
model_1.ir_version = onnx.IR_VERSION
opset = model_1.opset_import.add()
opset.version = 11
model_1 = onnx.helper.make_model(graph)
onnx.save(model_1, model_names[0])
@ -251,8 +244,6 @@ def generate_qat_support_model(model_names, test_initializers):
model_2 = onnx.ModelProto()
model_2.ir_version = onnx.IR_VERSION
opset = model_2.opset_import.add()
opset.version = 11
model_2 = onnx.helper.make_model(graph)
onnx.save(model_1, model_names[1])