mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-04 04:07:22 +00:00
Fix quantization of Conv1D with bias (#5491)
* Fix reshape for Conv with bias
This commit is contained in:
parent
1038f9cc8b
commit
6c2162e97a
3 changed files with 18 additions and 26 deletions
|
|
@ -641,26 +641,27 @@ class ONNXQuantizer:
|
|||
parameter last_output: output of previous node (input to bias add)
|
||||
return: the name of output
|
||||
'''
|
||||
# Add an Add operation for bias
|
||||
# Add reshape for correct broadcase
|
||||
reshape_input = [quantized_bias_name]
|
||||
|
||||
# Add tensors for the shape to be reshaped to
|
||||
weight = find_by_name(node.input[1], self.model.initializer())
|
||||
if weight is None:
|
||||
raise ValueError("Expected {} to be an initializer".format(node.input[1]))
|
||||
|
||||
# Add reshape for correct broadcase
|
||||
reshape_input_data = quantized_bias_name
|
||||
reshape_input_shape = quantized_bias_name + "_reshape_shape"
|
||||
reshape_input = [reshape_input_data, reshape_input_shape]
|
||||
|
||||
reshape_shape = np.ones((len(weight.dims)), dtype=np.int64)
|
||||
reshape_shape[1] = -1
|
||||
init_shape = onnx.helper.make_tensor("reshape_shape", onnx_proto.TensorProto.INT64, [len(weight.dims)], reshape_shape)
|
||||
init_shape = onnx.helper.make_tensor(reshape_input_shape, onnx_proto.TensorProto.INT64, [len(weight.dims)], reshape_shape)
|
||||
self.model.add_initializer(init_shape)
|
||||
|
||||
reshape_input.append('reshape_shape')
|
||||
reshape_op_output = node.output[0] + "_reshape"
|
||||
reshape_node = onnx.helper.make_node("Reshape", reshape_input, [reshape_op_output],
|
||||
quantized_bias_name + "reshape")
|
||||
nodes.append(reshape_node)
|
||||
|
||||
# Add an Add operation for bias
|
||||
bias_add_input = [last_output]
|
||||
bias_add_input.append(reshape_op_output)
|
||||
add_node_output = node.output[0] + "_bias_add"
|
||||
|
|
|
|||
|
|
@ -90,10 +90,10 @@ class TestCalibrate(unittest.TestCase):
|
|||
augmented_model_outputs = [output.name for output in augmented_model.graph.output]
|
||||
added_node_names = ['C_ReduceMin', 'C_ReduceMax', 'D_ReduceMin', 'D_ReduceMax', 'F_ReduceMin', 'F_ReduceMax']
|
||||
added_outputs = ['C_ReduceMin', 'C_ReduceMax', 'D_ReduceMin', 'D_ReduceMax', 'F_ReduceMin', 'F_ReduceMax']
|
||||
# Original 3 nodes + added ReduceMin/Max nodes * 6 (exlude graph input/output)
|
||||
self.assertEqual(len(augmented_model_node_names), 9)
|
||||
# Original 3 nodes + added ReduceMin/Max nodes
|
||||
self.assertEqual(len(augmented_model_node_names), 15)
|
||||
# Original 1 graph output + added outputs * 6
|
||||
self.assertEqual(len(augmented_model_outputs), 7)
|
||||
self.assertEqual(len(augmented_model_outputs), 13)
|
||||
for name in added_node_names:
|
||||
self.assertTrue(name in augmented_model_node_names)
|
||||
for output in added_outputs:
|
||||
|
|
@ -131,9 +131,9 @@ class TestCalibrate(unittest.TestCase):
|
|||
added_node_names = ['I_ReduceMin', 'I_ReduceMax', 'K_ReduceMin', 'K_ReduceMax']
|
||||
added_outputs = ['I_ReduceMin', 'I_ReduceMax', 'K_ReduceMin', 'K_ReduceMax']
|
||||
# Original 2 nodes + added ReduceMin/Max nodes * 4
|
||||
self.assertEqual(len(augmented_model_node_names), 6)
|
||||
self.assertEqual(len(augmented_model_node_names), 12)
|
||||
# Original 1 graph output + added outputs * 4
|
||||
self.assertEqual(len(augmented_model_outputs), 5)
|
||||
self.assertEqual(len(augmented_model_outputs), 11)
|
||||
for name in added_node_names:
|
||||
self.assertTrue(name in augmented_model_node_names)
|
||||
for output in added_outputs:
|
||||
|
|
@ -175,10 +175,10 @@ class TestCalibrate(unittest.TestCase):
|
|||
augmented_model_outputs = [output.name for output in augmented_model.graph.output]
|
||||
added_node_names = ['M_ReduceMin', 'M_ReduceMax', 'O_ReduceMin', 'O_ReduceMax', 'P_ReduceMin', 'P_ReduceMax', 'Q_ReduceMin', 'Q_ReduceMax']
|
||||
added_outputs = ['M_ReduceMin', 'M_ReduceMax', 'O_ReduceMin', 'O_ReduceMax', 'P_ReduceMin', 'P_ReduceMax', 'Q_ReduceMin', 'Q_ReduceMax']
|
||||
# Original 4 nodes + added ReduceMin/Max nodes * 8
|
||||
self.assertEqual(len(augmented_model_node_names), 12)
|
||||
# Original 1 graph output + added outputs * 8
|
||||
self.assertEqual(len(augmented_model_outputs), 9)
|
||||
# Original 4 nodes + added ReduceMin/Max nodes
|
||||
self.assertEqual(len(augmented_model_node_names), 14)
|
||||
# Original 1 graph output + added outputs
|
||||
self.assertEqual(len(augmented_model_outputs), 11)
|
||||
for name in added_node_names:
|
||||
self.assertTrue(name in augmented_model_node_names)
|
||||
for output in added_outputs:
|
||||
|
|
@ -242,7 +242,7 @@ class TestCalibrate(unittest.TestCase):
|
|||
quantization_params_dict = calibrater.calculate_quantization_params(dict_for_quantization)
|
||||
|
||||
#check the size of the quantization dictionary
|
||||
self.assertEqual(len(quantization_params_dict), 5)
|
||||
self.assertEqual(len(quantization_params_dict), 11)
|
||||
|
||||
#check the computation of zp and scale
|
||||
for key, value in quantization_params_dict.items():
|
||||
|
|
|
|||
|
|
@ -12,8 +12,7 @@ from pathlib import Path
|
|||
import unittest
|
||||
import urllib.request
|
||||
|
||||
from onnxruntime.quantization.quantize import optimize_model, ONNXQuantizer
|
||||
from onnxruntime.quantization.onnx_model import ONNXModel
|
||||
from onnxruntime.quantization.quantize import ONNXQuantizer
|
||||
|
||||
from onnxruntime.quantization.quant_utils import QuantizationMode
|
||||
from onnx import onnx_pb as onnx_proto
|
||||
|
|
@ -83,8 +82,6 @@ def generate_qat_model(model_names):
|
|||
|
||||
model_1 = onnx.helper.make_model(graph)
|
||||
model_1.ir_version = onnx.IR_VERSION
|
||||
opset = model_1.opset_import.add()
|
||||
opset.version = 11
|
||||
onnx.save(model_1, model_names[0])
|
||||
|
||||
test_models.extend([model_1])
|
||||
|
|
@ -157,8 +154,6 @@ def generate_qat_model(model_names):
|
|||
|
||||
model_2 = onnx.helper.make_model(graph)
|
||||
model_2.ir_version = onnx.IR_VERSION
|
||||
opset = model_2.opset_import.add()
|
||||
opset.version = 11
|
||||
onnx.save(model_2, model_names[1])
|
||||
|
||||
test_models.extend([model_2])
|
||||
|
|
@ -208,8 +203,6 @@ def generate_qat_support_model(model_names, test_initializers):
|
|||
|
||||
model_1 = onnx.ModelProto()
|
||||
model_1.ir_version = onnx.IR_VERSION
|
||||
opset = model_1.opset_import.add()
|
||||
opset.version = 11
|
||||
model_1 = onnx.helper.make_model(graph)
|
||||
onnx.save(model_1, model_names[0])
|
||||
|
||||
|
|
@ -251,8 +244,6 @@ def generate_qat_support_model(model_names, test_initializers):
|
|||
|
||||
model_2 = onnx.ModelProto()
|
||||
model_2.ir_version = onnx.IR_VERSION
|
||||
opset = model_2.opset_import.add()
|
||||
opset.version = 11
|
||||
model_2 = onnx.helper.make_model(graph)
|
||||
onnx.save(model_1, model_names[1])
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue