Update the quantization script to support GEMM (transB==1) (#5432)

* Modify onnx_quantizer.py

* Fix topological order issues

* Handle more cases
This commit is contained in:
Peichen Xie 2020-11-18 13:24:48 +08:00 committed by GitHub
parent f964bb94ba
commit e8c0f5d0ff
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -135,8 +135,7 @@ class ONNXQuantizer:
return opset_version
def replace_gemm_with_matmul(self):
nodes_to_remove = []
nodes_to_add = []
new_nodes = []
for node in self.model.nodes():
if node.op_type == 'Gemm':
@ -153,21 +152,48 @@ class ONNXQuantizer:
transA = onnx.helper.get_attribute_value(attr)
elif attr.name == 'transB':
transB = onnx.helper.get_attribute_value(attr)
if alpha == 1.0 and beta == 1.0 and transA == 0 and transB == 0:
matmul_node = onnx.helper.make_node('MatMul', [node.input[0], node.input[1]],
[node.output[0] + '_MatMul'],
name=node.output[0] + '_MatMul')
if alpha == 1.0 and beta == 1.0 and transA == 0:
inputB = node.input[1]
if transB == 1:
B = self.model.get_initializer(node.input[1])
if B:
# NOTE: assumes initializer B is not used by any other node, because it is replaced in place (same name) by its transpose
B_array = onnx.numpy_helper.to_array(B)
B_trans = onnx.numpy_helper.from_array(B_array.T)
B_trans.name = B.name
self.model.remove_initializer(B)
self.model.add_initializer(B_trans)
else:
inputB += '_Transposed'
transpose_node = onnx.helper.make_node('Transpose',
inputs=[node.input[1]],
outputs=[inputB],
name=node.name+'_Transpose')
new_nodes.append(transpose_node)
add_node = onnx.helper.make_node('Add',
inputs=[node.output[0] + '_MatMul', node.input[2]],
outputs=node.output,
name=node.output[0] + '_Add')
matmul_node = onnx.helper.make_node('MatMul',
inputs=[node.input[0], inputB],
outputs=[node.output[0] + ('_MatMul' if len(node.input)>2 else '')],
name=node.name + '_MatMul')
new_nodes.append(matmul_node)
nodes_to_add.extend([matmul_node, add_node])
nodes_to_remove.extend([node])
if len(node.input) > 2:
add_node = onnx.helper.make_node('Add',
inputs=[node.output[0] + '_MatMul', node.input[2]],
outputs=node.output,
name=node.name + '_Add')
new_nodes.append(add_node)
# unsupported Gemm (alpha != 1.0, beta != 1.0, or transA != 0): keep the node unchanged
else:
new_nodes.append(node)
# not a Gemm node: keep it unchanged
else:
new_nodes.append(node)
self.model.add_nodes(nodes_to_add)
self.model.remove_nodes(nodes_to_remove)
self.model.graph().ClearField('node')
self.model.graph().node.extend(new_nodes)
def remove_fake_quantized_nodes(self):
'''