From e8c0f5d0ffd2e7a27ee6c6d768ef29f77df49661 Mon Sep 17 00:00:00 2001 From: Peichen Xie Date: Wed, 18 Nov 2020 13:24:48 +0800 Subject: [PATCH] Update the quantization script to support GEMM (transB==1) (#5432) * Modify onnx_quantizer.py * Fix topology order issues * Handle more cases --- .../tools/quantization/onnx_quantizer.py | 54 ++++++++++++++----- 1 file changed, 40 insertions(+), 14 deletions(-) diff --git a/onnxruntime/python/tools/quantization/onnx_quantizer.py b/onnxruntime/python/tools/quantization/onnx_quantizer.py index dab4ebeda6..2e3f3c4fa2 100644 --- a/onnxruntime/python/tools/quantization/onnx_quantizer.py +++ b/onnxruntime/python/tools/quantization/onnx_quantizer.py @@ -135,8 +135,7 @@ class ONNXQuantizer: return opset_version def replace_gemm_with_matmul(self): - nodes_to_remove = [] - nodes_to_add = [] + new_nodes = [] for node in self.model.nodes(): if node.op_type == 'Gemm': @@ -153,21 +152,48 @@ class ONNXQuantizer: transA = onnx.helper.get_attribute_value(attr) elif attr.name == 'transB': transB = onnx.helper.get_attribute_value(attr) - if alpha == 1.0 and beta == 1.0 and transA == 0 and transB == 0: - matmul_node = onnx.helper.make_node('MatMul', [node.input[0], node.input[1]], - [node.output[0] + '_MatMul'], - name=node.output[0] + '_MatMul') + if alpha == 1.0 and beta == 1.0 and transA == 0: + inputB = node.input[1] + if transB == 1: + B = self.model.get_initializer(node.input[1]) + if B: + # assume B is not used by any other node + B_array = onnx.numpy_helper.to_array(B) + B_trans = onnx.numpy_helper.from_array(B_array.T) + B_trans.name = B.name + self.model.remove_initializer(B) + self.model.add_initializer(B_trans) + else: + inputB += '_Transposed' + transpose_node = onnx.helper.make_node('Transpose', + inputs=[node.input[1]], + outputs=[inputB], + name=node.name+'_Transpose') + new_nodes.append(transpose_node) - add_node = onnx.helper.make_node('Add', - inputs=[node.output[0] + '_MatMul', node.input[2]], - outputs=node.output, - name=node.output[0] + '_Add') + matmul_node = onnx.helper.make_node('MatMul', + inputs=[node.input[0], inputB], + outputs=[node.output[0] + ('_MatMul' if len(node.input)>2 else '')], + name=node.name + '_MatMul') + new_nodes.append(matmul_node) - nodes_to_add.extend([matmul_node, add_node]) - nodes_to_remove.extend([node]) + if len(node.input) > 2: + add_node = onnx.helper.make_node('Add', + inputs=[node.output[0] + '_MatMul', node.input[2]], + outputs=node.output, + name=node.name + '_Add') + new_nodes.append(add_node) + + # unsupported + else: + new_nodes.append(node) + + # not GEMM + else: + new_nodes.append(node) - self.model.add_nodes(nodes_to_add) - self.model.remove_nodes(nodes_to_remove) + self.model.graph().ClearField('node') + self.model.graph().node.extend(new_nodes) def remove_fake_quantized_nodes(self): '''