From a4fdb4dbd9ce1336d4d35b7ae784df04e93c1cf1 Mon Sep 17 00:00:00 2001 From: Zhang Lei Date: Thu, 8 Apr 2021 18:00:35 -0700 Subject: [PATCH] Support transpose by merge Reshape etc into direct xint8 operators. (#7265) * Suppose transpose by merge Reshape etc into direct xint8 operators. * Add resize operator quantization support * Add QDQ tests for resize, reshape, maxpool, transpose. --- .../operators/{reshape.py => direct_q8.py} | 26 ++-- .../tools/quantization/operators/maxpool.py | 41 +++--- .../tools/quantization/operators/resize.py | 34 +++++ .../python/tools/quantization/registry.py | 19 ++- .../test/python/quantization/op_test_utils.py | 5 + .../python/quantization/test_op_maxpool.py | 93 ++++++++++++++ .../python/quantization/test_op_reshape.py | 38 ++++-- .../python/quantization/test_op_resize.py | 119 ++++++++++++++++++ .../python/quantization/test_op_transpose.py | 93 ++++++++++++++ 9 files changed, 419 insertions(+), 49 deletions(-) rename onnxruntime/python/tools/quantization/operators/{reshape.py => direct_q8.py} (57%) create mode 100644 onnxruntime/python/tools/quantization/operators/resize.py create mode 100644 onnxruntime/test/python/quantization/test_op_maxpool.py create mode 100644 onnxruntime/test/python/quantization/test_op_resize.py create mode 100644 onnxruntime/test/python/quantization/test_op_transpose.py diff --git a/onnxruntime/python/tools/quantization/operators/reshape.py b/onnxruntime/python/tools/quantization/operators/direct_q8.py similarity index 57% rename from onnxruntime/python/tools/quantization/operators/reshape.py rename to onnxruntime/python/tools/quantization/operators/direct_q8.py index f856cb6837..62835966d4 100644 --- a/onnxruntime/python/tools/quantization/operators/reshape.py +++ b/onnxruntime/python/tools/quantization/operators/direct_q8.py @@ -1,31 +1,37 @@ -import onnx from .base_operator import QuantOperatorBase -from ..quant_utils import QuantizedValue, QuantizedValueType -from onnx import onnx_pb as onnx_proto +from .qdq_base_operator import QDQOperatorBase +from ..quant_utils import QuantizedValue - -class ReshapeQuant(QuantOperatorBase): +# For operators that support 8bits operations directly, and output could +# reuse input[0]'s type, zeropoint, scale; For example,Transpose, Reshape, etc. +class Direct8BitOp(QuantOperatorBase): def __init__(self, onnx_quantizer, onnx_node): super().__init__(onnx_quantizer, onnx_node) def quantize(self): node = self.node - assert (node.op_type == "Reshape") - # If input to this node is not quantized then keep this node + # Quantize when input[0] is quantized already. Otherwise keep it. if node.input[0] not in self.quantizer.quantized_value_map: self.quantizer.new_nodes += [node] return - # Reshape is a no-op in terms of quantization + # Create an entry for output quantized value quantized_input_value = self.quantizer.quantized_value_map[node.input[0]] quantized_output_value = QuantizedValue(node.output[0], node.output[0] + "_quantized", quantized_input_value.scale_name, quantized_input_value.zp_name, - QuantizedValueType.Input) - # Create an entry for output quantized value + quantized_input_value.value_type) self.quantizer.quantized_value_map[node.output[0]] = quantized_output_value node.input[0] = quantized_input_value.q_name node.output[0] = quantized_output_value.q_name self.quantizer.new_nodes += [node] + +class QDQDirect8BitOp(QDQOperatorBase): + def __init__(self, onnx_quantizer, onnx_node): + self.quantizer = onnx_quantizer + self.node = onnx_node + + def quantize(self): + self.quantizer.quantize_tensor(self.node.input[0]) diff --git a/onnxruntime/python/tools/quantization/operators/maxpool.py b/onnxruntime/python/tools/quantization/operators/maxpool.py index 8cc1b7da7a..1eb2ce5565 100644 --- a/onnxruntime/python/tools/quantization/operators/maxpool.py +++ b/onnxruntime/python/tools/quantization/operators/maxpool.py @@ -1,10 +1,7 @@ -import onnx -from .base_operator import QuantOperatorBase -from ..quant_utils import QuantizedValue, QuantizedValueType -from onnx import onnx_pb as onnx_proto +from .direct_q8 import Direct8BitOp, QDQDirect8BitOp -class QMaxPool(QuantOperatorBase): +class QMaxPool(Direct8BitOp): def __init__(self, onnx_quantizer, onnx_node): super().__init__(onnx_quantizer, onnx_node) @@ -12,24 +9,26 @@ class QMaxPool(QuantOperatorBase): node = self.node assert (node.op_type == "MaxPool") + # if version is less than 12, go to normal quantize. if self.quantizer.opset_version < 12: - super().quantize() + super(Direct8BitOp, self).quantize() return - # When mode is QLinearOps, the output quantization params are calculated based on outputs from - # activation nodes, therefore these nodes can be removed from the graph if they follow a quantized op. - # If input to this node is not quantized then keep this node - if node.input[0] not in self.quantizer.quantized_value_map: - self.quantizer.new_nodes += [node] + # Direct 8bits op + return super().quantize() + + +class QDQMaxPool(QDQDirect8BitOp): + def __init__(self, onnx_quantizer, onnx_node): + super().__init__(onnx_quantizer, onnx_node) + + def quantize(self): + node = self.node + assert (node.op_type == "MaxPool") + + # if version is less than 12, just no change + if self.quantizer.opset_version < 12: return - # Create an entry for output quantized value - quantized_input_value = self.quantizer.quantized_value_map[node.input[0]] - quantized_output_value = QuantizedValue(node.output[0], node.output[0] + "_quantized", - quantized_input_value.scale_name, quantized_input_value.zp_name, - QuantizedValueType.Input) - self.quantizer.quantized_value_map[node.output[0]] = quantized_output_value - - node.input[0] = quantized_input_value.q_name - node.output[0] = quantized_output_value.q_name - self.quantizer.new_nodes += [node] + # Direct 8bits op + return super().quantize() diff --git a/onnxruntime/python/tools/quantization/operators/resize.py b/onnxruntime/python/tools/quantization/operators/resize.py new file mode 100644 index 0000000000..c07cd99068 --- /dev/null +++ b/onnxruntime/python/tools/quantization/operators/resize.py @@ -0,0 +1,34 @@ +from .direct_q8 import Direct8BitOp, QDQDirect8BitOp + + +class QResize(Direct8BitOp): + def __init__(self, onnx_quantizer, onnx_node): + super().__init__(onnx_quantizer, onnx_node) + + def quantize(self): + node = self.node + assert (node.op_type == "Resize") + + # if version is less than 11, go to normal quantize. + if self.quantizer.opset_version < 11: + super(Direct8BitOp, self).quantize() + return + + # Direct 8bits op + return super().quantize() + + +class QDQResize(QDQDirect8BitOp): + def __init__(self, onnx_quantizer, onnx_node): + super().__init__(onnx_quantizer, onnx_node) + + def quantize(self): + node = self.node + assert (node.op_type == "Resize") + + # if version is less than 11, just keep this node + if self.quantizer.opset_version < 11: + return + + # Direct 8bits op + return super().quantize() diff --git a/onnxruntime/python/tools/quantization/registry.py b/onnxruntime/python/tools/quantization/registry.py index 26c7daf299..c58ebd3588 100644 --- a/onnxruntime/python/tools/quantization/registry.py +++ b/onnxruntime/python/tools/quantization/registry.py @@ -8,16 +8,18 @@ from .operators.gather import GatherQuant from .operators.conv import QLinearConv, ConvInteger, QDQConv from .operators.activation import QLinearActivation, QDQRemovableActivation from .operators.binary_op import QLinearBinaryOp -from .operators.maxpool import QMaxPool +from .operators.maxpool import QDQMaxPool, QMaxPool from .operators.gavgpool import QGlobalAveragePool from .operators.lstm import LSTMQuant from .operators.split import QSplit from .operators.pad import QPad -from .operators.reshape import ReshapeQuant +from .operators.direct_q8 import Direct8BitOp, QDQDirect8BitOp +from .operators.resize import QResize, QDQResize -CommonOpsRegistry = {"Gather": GatherQuant, - "EmbedLayerNormalization": EmbedLayerNormalizationQuant, - "Reshape": ReshapeQuant} +CommonOpsRegistry = { + "Gather": GatherQuant, + "EmbedLayerNormalization": EmbedLayerNormalizationQuant, +} IntegerOpsRegistry = { "Conv": ConvInteger, @@ -40,6 +42,9 @@ QLinearOpsRegistry = { "GlobalAveragePool": QGlobalAveragePool, "Split": QSplit, "Pad": QPad, + "Reshape": Direct8BitOp, + "Transpose" : Direct8BitOp, + "Resize": QResize, } QLinearOpsRegistry.update(CommonOpsRegistry) @@ -47,6 +52,10 @@ QDQRegistry = { "Conv": QDQConv, "Clip": QDQRemovableActivation, "Relu": QDQRemovableActivation, + "Reshape": QDQDirect8BitOp, + "Transpose" : QDQDirect8BitOp, + "Resize": QDQResize, + "MaxPool": QDQMaxPool, } diff --git a/onnxruntime/test/python/quantization/op_test_utils.py b/onnxruntime/test/python/quantization/op_test_utils.py index cbb1a2b971..58a07330ee 100644 --- a/onnxruntime/test/python/quantization/op_test_utils.py +++ b/onnxruntime/test/python/quantization/op_test_utils.py @@ -57,3 +57,8 @@ def check_model_correctness(testcase, model_path_origin, model_path_to_check, in for idx, ref_output in enumerate(origin_results): output = target_results[idx] np.testing.assert_allclose(ref_output, output, rtol=rtol, atol=atol) + +def check_op_nodes(testcase, model_path, node_checker): + model = onnx.load(Path(model_path)) + for node in model.graph.node: + testcase.assertTrue(node_checker(node)) diff --git a/onnxruntime/test/python/quantization/test_op_maxpool.py b/onnxruntime/test/python/quantization/test_op_maxpool.py new file mode 100644 index 0000000000..70fe9f2b67 --- /dev/null +++ b/onnxruntime/test/python/quantization/test_op_maxpool.py @@ -0,0 +1,93 @@ +#!/usr/bin/env python +# coding: utf-8 +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import unittest +import onnx +import numpy as np +from onnx import helper, TensorProto +from onnxruntime.quantization import quantize_static, QuantFormat +from op_test_utils import TestDataFeeds, check_model_correctness, check_op_type_count, check_op_nodes + + +class TestOpMaxPool(unittest.TestCase): + def input_feeds(self, n, name2shape): + input_data_list = [] + for i in range(n): + inputs = {} + for name, shape in name2shape.items(): + inputs.update({name: np.random.randint(-1, 2, shape).astype(np.float32)}) + input_data_list.extend([inputs]) + dr = TestDataFeeds(input_data_list) + return dr + + def construct_model_conv_maxpool(self, output_model_path, + conv_input_shape, conv_weight_shape, + maxpool_input_shape, maxpool_attributes, + output_shape, + ): + # (input) + # \ + # Conv + # / \ + # Identity MaxPool + # / \ + # (identity_out) (output) + input_tensor = helper.make_tensor_value_info('input', TensorProto.FLOAT, conv_input_shape) + + conv_weight_arr = np.random.randint(-1, 2, conv_weight_shape).astype(np.float32) + conv_weight_initializer = onnx.numpy_helper.from_array(conv_weight_arr, name='conv1_weight') + conv_node = onnx.helper.make_node('Conv', ['input', 'conv1_weight'], ['conv_output'], name='conv_node') + + identity_out = helper.make_tensor_value_info('identity_out', TensorProto.FLOAT, maxpool_input_shape) + identity_node = helper.make_node('Identity', ['conv_output'], ['identity_out'], name='IdentityNode') + + initializers = [conv_weight_initializer] + + output_tensor = helper.make_tensor_value_info('output', TensorProto.FLOAT, output_shape) + maxpool_node = helper.make_node('MaxPool', ['conv_output'], ['output'], name='maxpool_node', **maxpool_attributes) + + graph = helper.make_graph([conv_node, identity_node, maxpool_node], 'TestOpQuantizerMaxPool_test_model', + [input_tensor], [identity_out, output_tensor], initializer=initializers) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 12)]) + model.ir_version = onnx.IR_VERSION + onnx.save(model, output_model_path) + + def test_quantize_maxpool(self): + np.random.seed(1) + + model_fp32_path = 'maxpool_fp32.onnx' + model_uint8_path = 'maxpool_uint8.onnx' + model_uint8_qdq_path = 'maxpool_uint8_qdq.onnx' + + self.construct_model_conv_maxpool(model_fp32_path, + [1, 2, 26, 42], [3, 2, 3, 3], + [1, 3, 24, 40], {'kernel_shape': [3, 3]}, + [1, 3, 22, 38]) + + # Verify QOperator mode + data_reader = self.input_feeds(1, {'input': [1, 2, 26, 42]}) + quantize_static(model_fp32_path, model_uint8_path, data_reader) + + # make sure maxpool become xint8 operator, its input name could tell that + check_op_nodes(self, model_uint8_path, lambda node: (node.name != "maxpool_node" or node.input[0] != 'conv_output')) + qnode_counts = {'QLinearConv': 1, 'QuantizeLinear': 1, 'DequantizeLinear': 2, 'MaxPool': 1} + check_op_type_count(self, model_uint8_path, **qnode_counts) + data_reader.rewind() + check_model_correctness(self, model_fp32_path, model_uint8_path, data_reader.get_next()) + + # Verify QDQ mode + data_reader.rewind() + quantize_static(model_fp32_path, model_uint8_qdq_path, data_reader, quant_format=QuantFormat.QDQ) + qdqnode_counts = {'Conv': 1, 'QuantizeLinear': 2, 'DequantizeLinear': 3, 'MaxPool': 1} + check_op_type_count(self, model_uint8_qdq_path, **qdqnode_counts) + data_reader.rewind() + check_model_correctness(self, model_fp32_path, model_uint8_qdq_path, data_reader.get_next()) + + +if __name__ == '__main__': + unittest.main() diff --git a/onnxruntime/test/python/quantization/test_op_reshape.py b/onnxruntime/test/python/quantization/test_op_reshape.py index db5c49dc49..c7df10da9f 100644 --- a/onnxruntime/test/python/quantization/test_op_reshape.py +++ b/onnxruntime/test/python/quantization/test_op_reshape.py @@ -10,8 +10,9 @@ import unittest import onnx import numpy as np from onnx import helper, TensorProto -from onnxruntime.quantization import quantize_static -from op_test_utils import TestDataFeeds, check_model_correctness, check_op_type_count +from onnxruntime.quantization import quantize_static, QuantFormat +from op_test_utils import TestDataFeeds, check_model_correctness, check_op_type_count, check_op_nodes + class TestOpReshape(unittest.TestCase): def input_feeds(self, n, name2shape): @@ -55,7 +56,6 @@ class TestOpReshape(unittest.TestCase): initializers.append(onnx.numpy_helper.from_array(np.array(output_shape, dtype=np.int64), name=reshape_shape)) reshape_node = onnx.helper.make_node('Reshape', reshape_inputs, reshape_output, name=reshape_name) - # make graph input_tensor = helper.make_tensor_value_info(input_name, TensorProto.FLOAT, input_shape) output_tensor = helper.make_tensor_value_info(output_name, TensorProto.FLOAT, output_shape) @@ -71,19 +71,31 @@ class TestOpReshape(unittest.TestCase): np.random.seed(1) model_fp32_path = 'reshape_fp32.onnx' model_uint8_path = 'reshape_uint8.onnx' - data_reader = self.input_feeds(1, {'input': [3, 7]}) + model_uint8_qdq_path = 'reshape_uint8_qdq.onnx' + self.construct_model_matmul_reshape(model_fp32_path, - [3, 7], - [7, 3], - [1, 9]) - quantize_static(model_fp32_path, - model_uint8_path, - data_reader - ) + [3, 7], + [7, 3], + [1, 9]) + + # Verify QOperator mode + data_reader = self.input_feeds(1, {'input': [3, 7]}) + quantize_static(model_fp32_path, model_uint8_path, data_reader) + # make sure transpose become xint8 operator, its input name could tell that + check_op_nodes(self, model_uint8_path, lambda node: (node.name != "reshape_node" or node.input[0] != 'matmul_output')) + qnode_counts = {'QLinearMatMul': 1, 'QuantizeLinear': 1, 'DequantizeLinear': 1, 'Reshape': 1} + check_op_type_count(self, model_uint8_path, **qnode_counts) data_reader.rewind() - qdq_nodes = {'QLinearMatMul': 1, 'QuantizeLinear': 1, 'DequantizeLinear': 1, 'Reshape': 1} - check_op_type_count(self, model_uint8_path, **qdq_nodes) check_model_correctness(self, model_fp32_path, model_uint8_path, data_reader.get_next()) + # Verify QDQ mode + data_reader.rewind() + quantize_static(model_fp32_path, model_uint8_qdq_path, data_reader, quant_format=QuantFormat.QDQ) + qdqnode_counts = {'MatMul': 1, 'QuantizeLinear': 2, 'DequantizeLinear': 3, 'Reshape': 1} + check_op_type_count(self, model_uint8_qdq_path, **qdqnode_counts) + data_reader.rewind() + check_model_correctness(self, model_fp32_path, model_uint8_qdq_path, data_reader.get_next()) + + if __name__ == '__main__': unittest.main() diff --git a/onnxruntime/test/python/quantization/test_op_resize.py b/onnxruntime/test/python/quantization/test_op_resize.py new file mode 100644 index 0000000000..250e8e62cf --- /dev/null +++ b/onnxruntime/test/python/quantization/test_op_resize.py @@ -0,0 +1,119 @@ +#!/usr/bin/env python +# coding: utf-8 +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import unittest +import onnx +import numpy as np +from onnx import helper, TensorProto +from onnxruntime.quantization import quantize_static, QuantFormat +from op_test_utils import TestDataFeeds, check_model_correctness, check_op_type_count, check_op_nodes + + +class TestOpResize(unittest.TestCase): + def input_feeds(self, n, name2shape): + input_data_list = [] + for i in range(n): + inputs = {} + for name, shape in name2shape.items(): + inputs.update({name: np.random.randint(-1, 2, shape).astype(np.float32)}) + input_data_list.extend([inputs]) + dr = TestDataFeeds(input_data_list) + return dr + + def construct_model_conv_resize(self, output_model_path, + conv_input_shape, conv_weight_shape, + resize_input_shape, resize_output_shape, + resize_attrs, + resize_roi, resize_scales, resize_sizes): + # (input) + # \ + # Conv + # / \ + # Identity Resize + # / \ + # (identity_out) (output) + input_tensor = helper.make_tensor_value_info('input', TensorProto.FLOAT, conv_input_shape) + + conv_weight_arr = np.random.randint(-1, 2, conv_weight_shape).astype(np.float32) + conv_weight_initializer = onnx.numpy_helper.from_array(conv_weight_arr, name='conv1_weight') + conv_node = onnx.helper.make_node('Conv', ['input', 'conv1_weight'], ['conv_output'], name='conv_node') + + identity_out = helper.make_tensor_value_info('identity_out', TensorProto.FLOAT, resize_input_shape) + identity_node = helper.make_node('Identity', ['conv_output'], ['identity_out'], name='IdentityNode') + + initializers = [conv_weight_initializer] + + output_tensor = helper.make_tensor_value_info('output', TensorProto.FLOAT, resize_output_shape) + resize_inputs = ['conv_output'] # resize_roi_name, resize_scales_name, resize_sizes_name] + resize_node = helper.make_node('Resize', resize_inputs, ['output'], name='resize_node', **resize_attrs) + + if (resize_roi is not None): + resize_roi_name = 'resize_roi' + resize_roi_initializer = helper.make_tensor(resize_roi_name, TensorProto.FLOAT, [len(resize_roi)], resize_roi) + initializers.extend([resize_roi_initializer]) + resize_node.input.extend([resize_roi_name]) + else: + resize_node.input.extend(['']) + + if (resize_scales is not None): + resize_scales_name = 'resize_scales' + resize_scales_initializer = helper.make_tensor(resize_scales_name, TensorProto.FLOAT, [ + len(resize_scales)], resize_scales) + initializers.extend([resize_scales_initializer]) + resize_node.input.extend([resize_scales_name]) + else: + resize_node.input.extend(['']) + + if (resize_sizes is not None): + resize_sizes_name = 'resize_sizes' + resize_sizes_initializer = helper.make_tensor(resize_sizes_name, TensorProto.INT64, [len(resize_sizes)], resize_sizes) + initializers.extend([resize_sizes_initializer]) + resize_node.input.extend([resize_sizes_name]) + + graph = helper.make_graph([conv_node, identity_node, resize_node], 'TestOpQuantizerResize_test_model', + [input_tensor], [identity_out, output_tensor], initializer=initializers) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + model.ir_version = onnx.IR_VERSION + onnx.save(model, output_model_path) + + def test_quantize_resize(self): + np.random.seed(1) + + model_fp32_path = 'resize_fp32.onnx' + model_uint8_path = 'resize_uint8.onnx' + model_uint8_qdq_path = 'resize_uint8_qdq.onnx' + + kwargs = {'coordinate_transformation_mode': 'asymmetric', 'mode': 'nearest', 'nearest_mode': 'floor'} + self.construct_model_conv_resize(model_fp32_path, + [1, 2, 26, 42], [3, 2, 3, 3], + [1, 3, 24, 40], [1, 3, 48, 80], + kwargs, + [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 2.0, 2.0], None) + + # Verify QOperator mode + data_reader = self.input_feeds(1, {'input': [1, 2, 26, 42]}) + quantize_static(model_fp32_path, model_uint8_path, data_reader) + + # make sure resize become xint8 operator, its input name could tell that + check_op_nodes(self, model_uint8_path, lambda node: (node.name != "resize_node" or node.input[0] != 'conv_output')) + qnode_counts = {'QLinearConv': 1, 'QuantizeLinear': 1, 'DequantizeLinear': 2, 'Resize': 1} + check_op_type_count(self, model_uint8_path, **qnode_counts) + data_reader.rewind() + check_model_correctness(self, model_fp32_path, model_uint8_path, data_reader.get_next()) + + # Verify QDQ mode + data_reader.rewind() + quantize_static(model_fp32_path, model_uint8_qdq_path, data_reader, quant_format=QuantFormat.QDQ) + qdqnode_counts = {'Conv': 1, 'QuantizeLinear': 2, 'DequantizeLinear': 3, 'Resize': 1} + check_op_type_count(self, model_uint8_qdq_path, **qdqnode_counts) + data_reader.rewind() + check_model_correctness(self, model_fp32_path, model_uint8_qdq_path, data_reader.get_next()) + + +if __name__ == '__main__': + unittest.main() diff --git a/onnxruntime/test/python/quantization/test_op_transpose.py b/onnxruntime/test/python/quantization/test_op_transpose.py new file mode 100644 index 0000000000..d4f5316dda --- /dev/null +++ b/onnxruntime/test/python/quantization/test_op_transpose.py @@ -0,0 +1,93 @@ +#!/usr/bin/env python +# coding: utf-8 +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import unittest +import onnx +import numpy as np +from onnx import helper, TensorProto +from onnxruntime.quantization import quantize_static, QuantFormat +from op_test_utils import TestDataFeeds, check_model_correctness, check_op_type_count, check_op_nodes + + +class TestOpTranspose(unittest.TestCase): + def input_feeds(self, n, name2shape): + input_data_list = [] + for i in range(n): + inputs = {} + for name, shape in name2shape.items(): + inputs.update({name: np.random.randint(-1, 2, shape).astype(np.float32)}) + input_data_list.extend([inputs]) + dr = TestDataFeeds(input_data_list) + return dr + + def construct_model_matmul_transpose(self, output_model_path, input_shape, weight_shape, output_shape): + # (input) + # | + # MatMul + # | + # Transpose + # | + # (output) + input_name = 'input' + output_name = 'output' + initializers = [] + + # make MatMul node + weight_name = 'matmul_weight' + matmul_output_name = 'matmul_output' + matmul_inputs = [input_name, weight_name] + matmul_outputs = [matmul_output_name] + matmul_name = 'matmul_node' + matmul_weight_data = np.random.normal(0, 0.1, weight_shape).astype(np.float32) + initializers.append(onnx.numpy_helper.from_array(matmul_weight_data, name=weight_name)) + matmul_node = onnx.helper.make_node('MatMul', matmul_inputs, matmul_outputs, name=matmul_name) + + # make Transpose node + kwargs = {'perm': (1, 0)} + transpose_node = onnx.helper.make_node('Transpose', [matmul_output_name], [output_name], name="transpose_node", **kwargs) + + # make graph + input_tensor = helper.make_tensor_value_info(input_name, TensorProto.FLOAT, input_shape) + output_tensor = helper.make_tensor_value_info(output_name, TensorProto.FLOAT, output_shape) + graph_name = 'Transpose_Quant_Test' + graph = helper.make_graph([matmul_node, transpose_node], graph_name, + [input_tensor], [output_tensor], initializer=initializers) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 11)]) + model.ir_version = onnx.IR_VERSION + + onnx.save(model, output_model_path) + + def test_quantize_transpose(self): + np.random.seed(1) + model_fp32_path = 'transpose_fp32.onnx' + model_uint8_path = 'transpose_uint8.onnx' + model_uint8_qdq_path = 'transpose_uint8_qdq.onnx' + + self.construct_model_matmul_transpose(model_fp32_path, [3, 7], [7, 5], [5, 3]) + + # Verify QOperator model + data_reader = self.input_feeds(1, {'input': [3, 7]}) + quantize_static(model_fp32_path, model_uint8_path, data_reader) + # make sure transpose become xint8 operator, its input name could tell that + check_op_nodes(self, model_uint8_path, lambda node: (node.name != "transpose_node" or node.input[0] != 'matmul_output')) + qnode_counts = {'QLinearMatMul': 1, 'QuantizeLinear': 1, 'DequantizeLinear': 1, 'Transpose': 1} + check_op_type_count(self, model_uint8_path, **qnode_counts) + data_reader.rewind() + check_model_correctness(self, model_fp32_path, model_uint8_path, data_reader.get_next()) + + # Verify QDQ model + data_reader.rewind() + quantize_static(model_fp32_path, model_uint8_qdq_path, data_reader, quant_format=QuantFormat.QDQ) + qdqnode_counts = {'MatMul': 1, 'QuantizeLinear': 2, 'DequantizeLinear': 3, 'Transpose': 1} + check_op_type_count(self, model_uint8_qdq_path, **qdqnode_counts) + data_reader.rewind() + check_model_correctness(self, model_fp32_path, model_uint8_qdq_path, data_reader.get_next()) + + +if __name__ == '__main__': + unittest.main()