diff --git a/onnxruntime/python/tools/quantization/qdq_loss_debug.py b/onnxruntime/python/tools/quantization/qdq_loss_debug.py index 4983811c21..a3adf675d8 100644 --- a/onnxruntime/python/tools/quantization/qdq_loss_debug.py +++ b/onnxruntime/python/tools/quantization/qdq_loss_debug.py @@ -167,7 +167,7 @@ def _add_pre_post_qdq_pair( pre_qdq_tensors: Optional[Sequence[numpy.ndarray]], post_qdq_tensors: Optional[Sequence[numpy.ndarray]], ) -> None: - if post_qdq_tensors and pre_qdq_tensors: + if post_qdq_tensors is not None and pre_qdq_tensors is not None: qdq_cmp[activation_name] = {} qdq_cmp[activation_name]["pre_qdq"] = pre_qdq_tensors qdq_cmp[activation_name]["post_qdq"] = post_qdq_tensors @@ -229,7 +229,7 @@ def create_activation_matching( for act_name, act_values in qdq_cmp.items(): float_acts = float_activations.get(act_name) - if float_acts: + if float_acts is not None: act_values["float"] = float_acts return qdq_cmp diff --git a/onnxruntime/test/python/quantization/test_qdq_loss_debug.py b/onnxruntime/test/python/quantization/test_qdq_loss_debug.py index b09d8d7748..5a26cd3611 100644 --- a/onnxruntime/test/python/quantization/test_qdq_loss_debug.py +++ b/onnxruntime/test/python/quantization/test_qdq_loss_debug.py @@ -20,6 +20,7 @@ import onnxruntime from onnxruntime.quantization import QuantFormat, QuantType, quantize_static from onnxruntime.quantization.calibrate import CalibrationDataReader from onnxruntime.quantization.qdq_loss_debug import ( + QUANT_INPUT_SUFFIX, collect_activations, compute_activation_error, compute_weight_error, @@ -318,6 +319,12 @@ class TestSaveActivations(unittest.TestCase): dq_array = matched_weights[weight_name]["dequantized"] self.assertEqual(float_array.shape, dq_array.shape) + def test_none_test(self): + a = np.array([2, 3, 4]) + b = np.array([7, 8, 9]) + c = np.array([1, 2, 3]) + create_activation_matching({"test" + QUANT_INPUT_SUFFIX: a, "test": c}, {"test": b}) + if __name__ == "__main__": unittest.main()