diff --git a/onnxruntime/python/tools/quantization/calibrate.py b/onnxruntime/python/tools/quantization/calibrate.py index 6a632d1c52..79426e7891 100644 --- a/onnxruntime/python/tools/quantization/calibrate.py +++ b/onnxruntime/python/tools/quantization/calibrate.py @@ -10,7 +10,7 @@ import itertools import uuid from enum import Enum from pathlib import Path -from typing import Dict, List, Optional, Sequence +from typing import Optional, Sequence import numpy as np import onnx diff --git a/onnxruntime/python/tools/quantization/save_activations.py b/onnxruntime/python/tools/quantization/qdq_loss_debug.py similarity index 62% rename from onnxruntime/python/tools/quantization/save_activations.py rename to onnxruntime/python/tools/quantization/qdq_loss_debug.py index 79e1ee2302..cda97d3d34 100644 --- a/onnxruntime/python/tools/quantization/save_activations.py +++ b/onnxruntime/python/tools/quantization/qdq_loss_debug.py @@ -50,7 +50,7 @@ from onnx import ModelProto, TensorProto, helper, numpy_helper import onnxruntime from .calibrate import CalibraterBase, CalibrationDataReader -from .quant_utils import clone_model_with_shape_infer +from .quant_utils import DEQUANT_OUTPUT_SUFFIX, QUANT_INPUT_SUFFIX, clone_model_with_shape_infer _TENSOR_SAVE_POSTFIX = "_ReshapedSavedOutput" _TENSOR_SAVE_POSTFIX_LEN = len(_TENSOR_SAVE_POSTFIX) @@ -145,3 +145,80 @@ def collect_activations( output_dict.setdefault(output_name, []).append(output_data) return output_dict + + +_POST_QDQ_POSTFIX1 = DEQUANT_OUTPUT_SUFFIX + "_1" + + +def _add_pre_post_qdq_pair( + qdq_cmp: Dict[str, Dict[str, Sequence[numpy.ndarray]]], + activation_name: str, + pre_qdq_tensors: Optional[Sequence[numpy.ndarray]], + post_qdq_tensors: Optional[Sequence[numpy.ndarray]], +) -> None: + if post_qdq_tensors and pre_qdq_tensors: + qdq_cmp[activation_name] = {} + qdq_cmp[activation_name]["pre_qdq"] = pre_qdq_tensors + qdq_cmp[activation_name]["post_qdq"] = post_qdq_tensors + + +def create_activation_matching( + 
qdq_activations: Dict[str, Sequence[numpy.ndarray]], + float_activations: Optional[Dict[str, Sequence[numpy.ndarray]]] = None, +) -> Dict[str, Dict[str, Sequence[numpy.ndarray]]]: + """Comparing activation values to help debug accuracy loss due to quantization. + + This function takes saved activations from the QDQ model and (optionally) the + floating point model, and provides a data structure for comparing: + * from the qdq model, activation values before and after QDQ operation + * across both models, activations from the original model vs the corresponding + activations in the QDQ model + + Args: + qdq_activations: Output of `collect_activations`. This must be from a quantized + model with QDQ format. + float_activations: Output of `collect_activations`. This must be from the floating + point model. + + Returns: + Dict for comparing pre and post quantized activation tensors. E.g. + ``` + qdq_cmp = create_activation_matching(qdq_activations) + print(qdq_cmp['activation1']['pre_qdq'][0]) + print(qdq_cmp['activation1']['post_qdq'][0]) + + + qdq_cmp = create_activation_matching(qdq_activations, float_activations) + print(qdq_cmp['activation1']['float'][0]) + print(qdq_cmp['activation1']['pre_qdq'][0]) + print(qdq_cmp['activation1']['post_qdq'][0]) + ``` + """ + + qdq_cmp: Dict[str, Dict[str, Sequence[numpy.ndarray]]] = {} + for tensor_name, tensors in qdq_activations.items(): + if tensor_name.endswith(QUANT_INPUT_SUFFIX): + pre_name = tensor_name[: -len(QUANT_INPUT_SUFFIX)] + post_qdq_tensors = qdq_activations.get(pre_name) + pre_qdq_tensors = tensors + _add_pre_post_qdq_pair(qdq_cmp, pre_name, pre_qdq_tensors, post_qdq_tensors) + elif tensor_name.endswith(DEQUANT_OUTPUT_SUFFIX): + pre_name = tensor_name[: -len(DEQUANT_OUTPUT_SUFFIX)] + pre_qdq_tensors = qdq_activations.get(pre_name) + post_qdq_tensors = tensors + _add_pre_post_qdq_pair(qdq_cmp, pre_name, pre_qdq_tensors, post_qdq_tensors) + elif tensor_name.endswith(_POST_QDQ_POSTFIX1): + pre_name = tensor_name[: 
-len(_POST_QDQ_POSTFIX1)] + pre_qdq_tensors = qdq_activations.get(pre_name) + post_qdq_tensors = tensors + _add_pre_post_qdq_pair(qdq_cmp, pre_name, pre_qdq_tensors, post_qdq_tensors) + + if not float_activations: + return qdq_cmp + + for act_name, act_values in qdq_cmp.items(): + float_acts = float_activations.get(act_name) + if float_acts: + act_values["float"] = float_acts + + return qdq_cmp diff --git a/onnxruntime/python/tools/quantization/quant_utils.py b/onnxruntime/python/tools/quantization/quant_utils.py index d7f287f3d7..ccade81bd2 100644 --- a/onnxruntime/python/tools/quantization/quant_utils.py +++ b/onnxruntime/python/tools/quantization/quant_utils.py @@ -15,7 +15,10 @@ __version__ = "0.1.0" onnx_domain = "ai.onnx" ms_domain = "com.microsoft" QUANT_OP_NAME = "QuantizeLinear" +QUANT_INPUT_SUFFIX = "_QuantizeLinear_Input" DEQUANT_OP_NAME = "DequantizeLinear" +DEQUANT_OUTPUT_SUFFIX = "_DequantizeLinear_Output" + type_to_name = { 1: "FLOAT", @@ -573,7 +576,7 @@ def add_quant_suffix(tensor_name): def add_quant_input_suffix(tensor_name): - return tensor_name + "_QuantizeLinear_Input" + return tensor_name + QUANT_INPUT_SUFFIX def add_quant_output_suffix(tensor_name): @@ -589,4 +592,4 @@ def add_dequant_input_suffix(tensor_name): def add_dequant_output_suffix(tensor_name): - return tensor_name + "_DequantizeLinear_Output" + return tensor_name + DEQUANT_OUTPUT_SUFFIX diff --git a/onnxruntime/test/python/quantization/test_save_activations.py b/onnxruntime/test/python/quantization/test_qdq_loss_debug.py similarity index 52% rename from onnxruntime/test/python/quantization/test_save_activations.py rename to onnxruntime/test/python/quantization/test_qdq_loss_debug.py index 9b635b3dfb..b50653a3d0 100644 --- a/onnxruntime/test/python/quantization/test_save_activations.py +++ b/onnxruntime/test/python/quantization/test_qdq_loss_debug.py @@ -9,14 +9,20 @@ import tempfile import unittest from pathlib import Path +from typing import Dict, List import numpy as np import 
onnx from onnx import TensorProto, helper, numpy_helper import onnxruntime +from onnxruntime.quantization import QuantFormat, QuantType, quantize_static from onnxruntime.quantization.calibrate import CalibrationDataReader -from onnxruntime.quantization.save_activations import collect_activations, modify_model_output_intermediate_tensors +from onnxruntime.quantization.qdq_loss_debug import ( + collect_activations, + create_activation_matching, + modify_model_output_intermediate_tensors, +) def generate_input_initializer(tensor_shape, tensor_dtype, input_name): @@ -28,53 +34,53 @@ def generate_input_initializer(tensor_shape, tensor_dtype, input_name): return init -def construct_test_model1(test_model_path): +def construct_test_model1(test_model_path: str, activations_as_outputs=False): """ Create an ONNX model shaped as: ``` (input) | - Relu - / \ - Conv \ - | \ - Relu Conv - | | - Conv | - \ / + Relu1 + / \ + Conv1 \ + | \ + Relu2 Conv3 + | | + Conv2 | + \ / Add | - (X6) + (AddOut) ``` We are keeping all intermediate tensors as output, just for test verification purposes """ input_vi = helper.make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 1, 3]) - x1_output = helper.make_tensor_value_info("X1", TensorProto.FLOAT, [1, 3, 1, 3]) - x2_output = helper.make_tensor_value_info("X2", TensorProto.FLOAT, [1, 3, 1, 3]) - x3_output = helper.make_tensor_value_info("X3", TensorProto.FLOAT, [1, 3, 1, 3]) - x4_output = helper.make_tensor_value_info("X4", TensorProto.FLOAT, [1, 3, 1, 3]) - x5_output = helper.make_tensor_value_info("X5", TensorProto.FLOAT, [1, 3, 1, 3]) - x6_output = helper.make_tensor_value_info("X6", TensorProto.FLOAT, [1, 3, 1, 3]) + x1_output = helper.make_tensor_value_info("Relu1Out", TensorProto.FLOAT, [1, 3, 1, 3]) + x2_output = helper.make_tensor_value_info("Conv1Out", TensorProto.FLOAT, [1, 3, 1, 3]) + x3_output = helper.make_tensor_value_info("Relu2Out", TensorProto.FLOAT, [1, 3, 1, 3]) + x4_output = helper.make_tensor_value_info("Conv2Out", 
TensorProto.FLOAT, [1, 3, 1, 3]) + x5_output = helper.make_tensor_value_info("Conv3Out", TensorProto.FLOAT, [1, 3, 1, 3]) + x6_output = helper.make_tensor_value_info("AddOut", TensorProto.FLOAT, [1, 3, 1, 3]) w1 = generate_input_initializer([3, 3, 1, 1], np.float32, "W1") b1 = generate_input_initializer([3], np.float32, "B1") w3 = generate_input_initializer([3, 3, 1, 1], np.float32, "W3") b3 = generate_input_initializer([3], np.float32, "B3") w5 = generate_input_initializer([3, 3, 1, 1], np.float32, "W5") b5 = generate_input_initializer([3], np.float32, "B5") - relu_node_1 = helper.make_node("Relu", ["input"], ["X1"], name="Relu1") - conv_node_1 = helper.make_node("Conv", ["X1", "W1", "B1"], ["X2"], name="Conv1") - relu_node_2 = helper.make_node("Relu", ["X2"], ["X3"], name="Relu2") - conv_node_2 = helper.make_node("Conv", ["X3", "W3", "B3"], ["X4"], name="Conv2") - conv_node_3 = helper.make_node("Conv", ["X1", "W5", "B5"], ["X5"], name="Conv3") - add_node = helper.make_node("Add", ["X4", "X5"], ["X6"], name="Add") + relu_node_1 = helper.make_node("Relu", ["input"], ["Relu1Out"], name="Relu1") + conv_node_1 = helper.make_node("Conv", ["Relu1Out", "W1", "B1"], ["Conv1Out"], name="Conv1") + relu_node_2 = helper.make_node("Relu", ["Conv1Out"], ["Relu2Out"], name="Relu2") + conv_node_2 = helper.make_node("Conv", ["Relu2Out", "W3", "B3"], ["Conv2Out"], name="Conv2") + conv_node_3 = helper.make_node("Conv", ["Relu1Out", "W5", "B5"], ["Conv3Out"], name="Conv3") + add_node = helper.make_node("Add", ["Conv2Out", "Conv3Out"], ["AddOut"], name="Add") # we are keeping all tensors in the output anyway for verification purpose + outputs = [x6_output] + if activations_as_outputs: + outputs.extend([x1_output, x2_output, x3_output, x4_output, x5_output]) graph = helper.make_graph( - [relu_node_1, conv_node_1, relu_node_2, conv_node_2, conv_node_3, add_node], - "test_graph_4", - [input_vi], - [x1_output, x2_output, x3_output, x4_output, x5_output, x6_output], + [relu_node_1, 
conv_node_1, relu_node_2, conv_node_2, conv_node_3, add_node], "test_graph_4", [input_vi], outputs ) graph.initializer.add().CopyFrom(w1) graph.initializer.add().CopyFrom(b1) @@ -108,6 +114,21 @@ class TestDataReader(CalibrationDataReader): self.preprocess_flag = True +def augment_model_collect_activations( + model_path: str, augmented_model_path: str, data_reader: TestDataReader +) -> Dict[str, List[np.ndarray]]: + aug_model = modify_model_output_intermediate_tensors(model_path) + + onnx.save( + aug_model, + augmented_model_path, + save_as_external_data=False, + ) + + tensor_dict = collect_activations(augmented_model_path, data_reader) + return tensor_dict + + class TestSaveActivations(unittest.TestCase): @classmethod def setUpClass(cls): @@ -118,20 +139,12 @@ class TestSaveActivations(unittest.TestCase): cls._tmp_model_dir.cleanup() def test_saved_tensors_match_internal_tensors(self): - test_model_path = str(Path(self._tmp_model_dir.name) / "augmented_model.onnx") - construct_test_model1(test_model_path) + test_model_path = str(Path(self._tmp_model_dir.name) / "test_model1.onnx") + construct_test_model1(test_model_path, activations_as_outputs=True) data_reader = TestDataReader() - aug_model = modify_model_output_intermediate_tensors(test_model_path) augmented_model_path = str(Path(self._tmp_model_dir.name).joinpath("augmented_test_model_1.onnx")) - - onnx.save( - aug_model, - augmented_model_path, - save_as_external_data=False, - ) - - tensor_dict = collect_activations(augmented_model_path, data_reader) + tensor_dict = augment_model_collect_activations(test_model_path, augmented_model_path, data_reader) # run original model and compare the tensors sess_options = onnxruntime.SessionOptions() @@ -160,6 +173,48 @@ class TestSaveActivations(unittest.TestCase): act = actual.reshape(-1) np.testing.assert_equal(exp, act) + def test_create_activation_matching_present(self): + float_model_path = str(Path(self._tmp_model_dir.name) / "float_model2.onnx") + 
construct_test_model1(float_model_path, activations_as_outputs=False) + data_reader = TestDataReader() + + qdq_model_path = str(Path(self._tmp_model_dir.name) / "qdq_model2.onnx") + quantize_static( + float_model_path, + qdq_model_path, + data_reader, + quant_format=QuantFormat.QDQ, + per_channel=False, + reduce_range=False, + activation_type=QuantType.QInt8, + weight_type=QuantType.QInt8, + ) + + data_reader.rewind() + augmented_float_model_path = str(Path(self._tmp_model_dir.name).joinpath("augmented_float_model2.onnx")) + float_activations = augment_model_collect_activations(float_model_path, augmented_float_model_path, data_reader) + + data_reader.rewind() + augmented_qdq_model_path = str(Path(self._tmp_model_dir.name).joinpath("augmented_qdq_model2.onnx")) + qdq_activations = augment_model_collect_activations(qdq_model_path, augmented_qdq_model_path, data_reader) + + compare_dict = create_activation_matching(qdq_activations, float_activations) + + # 'Conv1Out' is combined with 'Relu2Out' + tensor_names = [ + "Relu1Out", + "Relu2Out", + "Conv2Out", + "Conv3Out", + "AddOut", + ] + for tensor_name in tensor_names: + self.assertTrue(compare_dict[tensor_name]["float"]) + self.assertTrue(compare_dict[tensor_name]["pre_qdq"]) + self.assertTrue(compare_dict[tensor_name]["post_qdq"]) + + self.assertFalse(compare_dict.get("Conv1Out")) + if __name__ == "__main__": unittest.main()