diff --git a/onnxruntime/python/tools/quantization/qdq_quantizer.py b/onnxruntime/python/tools/quantization/qdq_quantizer.py index 60bf90c243..b71f332252 100644 --- a/onnxruntime/python/tools/quantization/qdq_quantizer.py +++ b/onnxruntime/python/tools/quantization/qdq_quantizer.py @@ -989,8 +989,7 @@ class QDQQuantizer(BaseQuantizer): per_chan_overrides = self.tensor_quant_overrides.get_per_channel_overrides(tensor_name) axis = per_chan_overrides[0]["axis"] # Prefer axis from user-specified tensor-level overrides if available - weight_nparray = tensor_proto_to_array(weight_initializer) - weight_rank = len(weight_nparray.shape) + weight_rank = len(weight_initializer.dims) axis_valid, axis = normalize_axis(axis, weight_rank) if not axis_valid: logging.warning(f"Axis {axis} is out-of-range for weight '{tensor_name}' with rank {weight_rank}") diff --git a/onnxruntime/python/tools/quantization/tensor_quant_overrides.py b/onnxruntime/python/tools/quantization/tensor_quant_overrides.py index 6050bd2e05..219d929d22 100644 --- a/onnxruntime/python/tools/quantization/tensor_quant_overrides.py +++ b/onnxruntime/python/tools/quantization/tensor_quant_overrides.py @@ -12,7 +12,7 @@ from typing import Any import onnx -from .quant_utils import QuantType, tensor_proto_to_array +from .quant_utils import QuantType @dataclass @@ -235,7 +235,7 @@ class TensorQuantOverridesHelper(MutableMapping): "the first channel dictionary.", ) - weight_shape = tensor_proto_to_array(initializers[tensor_name]).shape + weight_shape = list(initializers[tensor_name].dims) weight_rank = len(weight_shape) norm_axis = axis if norm_axis < 0: diff --git a/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py b/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py index 8691471b04..21a772c5f5 100644 --- a/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py +++ b/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py @@ -5,7 +5,9 @@ # license information. # -------------------------------------------------------------------------- +import os import struct +import tempfile import unittest import numpy as np @@ -1150,6 +1152,48 @@ class TestTensorQuantOverridesOption(unittest.TestCase): self.assertEqual(set(qnn_config.op_types_to_quantize), {"Add"}) self.assertTrue(qnn_config.use_external_data_format) + def test_get_qnn_qdq_config_ext_data_separate_dir(self): + """ + Test that get_qnn_qdq_config() can validate per-channel quantization overrides for a model with external data + that is in a separate directory not in the cwd. + """ + + # Create model with a weight large enough (> 1024 bytes) to be stored externally. + large_weight = onnx.numpy_helper.from_array(np.random.random((1, 2, 32, 32)).astype(np.float32), "weight") + graph = onnx.helper.make_graph( + [onnx.helper.make_node("Conv", ["input", "weight"], ["output"])], + "conv_ext_data", + [onnx.helper.make_tensor_value_info("input", onnx.TensorProto.FLOAT, (1, 2, 64, 64))], + [onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, None)], + initializer=[large_weight], + ) + model = onnx.helper.make_model( + graph, + opset_imports=[onnx.helper.make_opsetid("", 21)], + ) + + # Make a separate directory in which to save model and its external data. + model_dir_path = tempfile.mkdtemp(prefix="model_ext_data") + model_name = "conv_ext_data.onnx" + model_path = os.path.join(model_dir_path, model_name) + + onnx.save_model( + model, + str(model_path), + save_as_external_data=True, + ) + + # Use tensor quantization overrides to quantize Conv's weight input to 4 bits on axis 0. + init_overrides = {"weight": [{"quant_type": QuantType.QInt4, "axis": 0, "symmetric": True}]} + + # get_qnn_qdq_config() should be able to validate the per-channel axis without having to load + # the external weight data. + qnn_config = get_qnn_qdq_config( + str(model_path), DummyDataReader([]), init_overrides=init_overrides # Dummy data reader does nothing + ) + self.assertEqual(set(qnn_config.op_types_to_quantize), {"Conv"}) + self.assertTrue(qnn_config.use_external_data_format) + if __name__ == "__main__": t = TestTensorQuantOverridesOption()