mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-06 00:03:22 +00:00
Update the quantization script to support GEMM (transB==1) (#5432)
* Modify onnx_quantizer.py * Fix topological order issues * Handle more cases
This commit is contained in:
parent
f964bb94ba
commit
e8c0f5d0ff
1 changed file with 40 additions and 14 deletions
|
|
@ -135,8 +135,7 @@ class ONNXQuantizer:
|
|||
return opset_version
|
||||
|
||||
def replace_gemm_with_matmul(self):
|
||||
nodes_to_remove = []
|
||||
nodes_to_add = []
|
||||
new_nodes = []
|
||||
|
||||
for node in self.model.nodes():
|
||||
if node.op_type == 'Gemm':
|
||||
|
|
@ -153,21 +152,48 @@ class ONNXQuantizer:
|
|||
transA = onnx.helper.get_attribute_value(attr)
|
||||
elif attr.name == 'transB':
|
||||
transB = onnx.helper.get_attribute_value(attr)
|
||||
if alpha == 1.0 and beta == 1.0 and transA == 0 and transB == 0:
|
||||
matmul_node = onnx.helper.make_node('MatMul', [node.input[0], node.input[1]],
|
||||
[node.output[0] + '_MatMul'],
|
||||
name=node.output[0] + '_MatMul')
|
||||
if alpha == 1.0 and beta == 1.0 and transA == 0:
|
||||
inputB = node.input[1]
|
||||
if transB == 1:
|
||||
B = self.model.get_initializer(node.input[1])
|
||||
if B:
|
||||
# assume B is not used by any other node
|
||||
B_array = onnx.numpy_helper.to_array(B)
|
||||
B_trans = onnx.numpy_helper.from_array(B_array.T)
|
||||
B_trans.name = B.name
|
||||
self.model.remove_initializer(B)
|
||||
self.model.add_initializer(B_trans)
|
||||
else:
|
||||
inputB += '_Transposed'
|
||||
transpose_node = onnx.helper.make_node('Transpose',
|
||||
inputs=[node.input[1]],
|
||||
outputs=[inputB],
|
||||
name=node.name+'_Transpose')
|
||||
new_nodes.append(transpose_node)
|
||||
|
||||
add_node = onnx.helper.make_node('Add',
|
||||
inputs=[node.output[0] + '_MatMul', node.input[2]],
|
||||
outputs=node.output,
|
||||
name=node.output[0] + '_Add')
|
||||
matmul_node = onnx.helper.make_node('MatMul',
|
||||
inputs=[node.input[0], inputB],
|
||||
outputs=[node.output[0] + ('_MatMul' if len(node.input)>2 else '')],
|
||||
name=node.name + '_MatMul')
|
||||
new_nodes.append(matmul_node)
|
||||
|
||||
nodes_to_add.extend([matmul_node, add_node])
|
||||
nodes_to_remove.extend([node])
|
||||
if len(node.input) > 2:
|
||||
add_node = onnx.helper.make_node('Add',
|
||||
inputs=[node.output[0] + '_MatMul', node.input[2]],
|
||||
outputs=node.output,
|
||||
name=node.name + '_Add')
|
||||
new_nodes.append(add_node)
|
||||
|
||||
# unsupported
|
||||
else:
|
||||
new_nodes.append(node)
|
||||
|
||||
# not GEMM
|
||||
else:
|
||||
new_nodes.append(node)
|
||||
|
||||
self.model.add_nodes(nodes_to_add)
|
||||
self.model.remove_nodes(nodes_to_remove)
|
||||
self.model.graph().ClearField('node')
|
||||
self.model.graph().node.extend(new_nodes)
|
||||
|
||||
def remove_fake_quantized_nodes(self):
|
||||
'''
|
||||
|
|
|
|||
Loading…
Reference in a new issue