mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-03 03:58:54 +00:00
[Quant tool] Fix quantized bias's scale dtype to properly handle fp16 bias inputs (#20340)
### Description - Fix quantization tool bug that did not correctly set a quantized bias's scale data type to fp16 if the original bias was fp16. - Enabled fp16 ConvTranspose quantization unit tests that were disabled. ### Motivation and Context Python quantization tests for fp16 ConvTranspose were originally disabled due to a shape inference bug. It turns out that we also have a bug in our quantizer that does not properly handle fp16 bias inputs. Fixing the bug allows us to re-enable these tests with the latest version of ONNX.
This commit is contained in:
parent
0a1902525f
commit
eae7b705ac
2 changed files with 4 additions and 4 deletions
|
|
@ -235,7 +235,9 @@ class BaseQuantizer:
|
|||
bias_np_data = np.asarray(quantized_data, dtype=np.int32).reshape(bias_initializer.dims)
|
||||
packed_bias_initializer = onnx.numpy_helper.from_array(bias_np_data, quantized_bias_name)
|
||||
self.model.initializer_extend([packed_bias_initializer])
|
||||
bias_scale_data = np.asarray(bias_scale, dtype=np.float32).reshape(-1)
|
||||
|
||||
# Bias's scale dtype should match the original bias data's unquantized type (float32 or float16).
|
||||
bias_scale_data = np.asarray(bias_scale, dtype=bias_data.dtype).reshape(-1)
|
||||
node_type = "DequantizeLinear"
|
||||
node_qtype = self.weight_qType
|
||||
|
||||
|
|
|
|||
|
|
@ -149,7 +149,6 @@ class TestOpConvTranspose(unittest.TestCase):
|
|||
def test_quantize_conv_transpose_u8u8(self):
|
||||
self.quantize_conv_transpose_u8u8(TensorProto.FLOAT, 13, 7)
|
||||
|
||||
@unittest.skip(reason="Shape inference bug, see onnx PR #5709")
|
||||
def test_quantize_conv_transpose_u8u8_fp16(self):
|
||||
self.quantize_conv_transpose_u8u8(TensorProto.FLOAT16, 19, 9)
|
||||
|
||||
|
|
@ -160,7 +159,7 @@ class TestOpConvTranspose(unittest.TestCase):
|
|||
|
||||
np.random.seed(1)
|
||||
model_fp32_path = "conv_transpose_fp32.onnx"
|
||||
self.construct_model(model_fp32_path)
|
||||
self.construct_model(model_fp32_path, onnx_type, opset, ir_version)
|
||||
dtype = onnx.helper.tensor_dtype_to_np_dtype(onnx_type)
|
||||
data_reader = self.input_feeds(1, {"input": [1, 1, 7, 7]}, dtype)
|
||||
|
||||
|
|
@ -175,7 +174,6 @@ class TestOpConvTranspose(unittest.TestCase):
|
|||
def test_quantize_conv_transpose_s8s8(self):
|
||||
self.quantize_conv_transpose_s8s8(TensorProto.FLOAT, 13, 7)
|
||||
|
||||
@unittest.skip(reason="Shape inference bug, see onnx PR #5709")
|
||||
def test_quantize_conv_transpose_s8s8_fp16(self):
|
||||
self.quantize_conv_transpose_s8s8(TensorProto.FLOAT16, 19, 9)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue