onnxruntime/onnxruntime/test/testdata/matmul_integer_to_float.py
Yufeng Li 67a7d93b49
Fuse MatMulInteger and scale followed (#4350)
* Fuse MatMulInteger and scale followed

* Add bias
2020-07-02 13:08:21 -07:00

55 lines
No EOL
2.1 KiB
Python

import onnx
from onnx import helper
from onnx import TensorProto
from enum import Enum
def GenerateModel(model_name, sign, has_zp=True, bias=False):
    """Build a MatMulInteger -> Cast -> Mul test graph and save it to disk.

    The graph is named "DynamicQuantizeMatMul_fusion" and contains, in order:
    a MatMulInteger node over "A" and "B", a Mul node combining "a_scale" and
    "b_scale" into "multiplier", a Cast of the integer matmul result to
    float, and a Mul applying "multiplier" to the cast result. When ``bias``
    is set, a trailing Add node adds the "bias" input and produces "Y".

    Args:
        model_name: Path handed to ``onnx.save`` for the generated model.
        sign: When truthy, "B" (and "b_zero_point", if present) are declared
            as INT8; otherwise they are declared as UINT8. "A" is always
            UINT8.
        has_zp: When truthy, "a_zero_point" and "b_zero_point" are wired into
            the MatMulInteger node and declared as graph inputs.
        bias: When truthy, append the Add node and declare the "bias" graph
            input.
    """
    # "B" and its zero point always share the same element type.
    b_elem_type = TensorProto.INT8 if sign else TensorProto.UINT8

    matmul_inputs = ["A", "B"]
    if has_zp:
        matmul_inputs += ["a_zero_point", "b_zero_point"]

    # Without a bias the scaling Mul writes the graph output directly;
    # with a bias it feeds the trailing Add instead.
    scaled_output = "mul_bottom_output" if bias else "Y"

    nodes = [
        helper.make_node(
            "MatMulInteger",
            matmul_inputs,
            ["matmul_output_int32"],
            name="MatMulInteger",
        ),
        helper.make_node(
            "Mul",
            ["a_scale", "b_scale"],
            ["multiplier"],
            name="mul_right",
        ),
        helper.make_node(
            "Cast",
            ["matmul_output_int32"],
            ["matmul_output_float"],
            name="cast",
            to=TensorProto.FLOAT,  # same value as the literal 1
        ),
        helper.make_node(
            "Mul",
            ["matmul_output_float", "multiplier"],
            [scaled_output],
            name="mul_bottom",
        ),
    ]

    # Graph input order: A, B, scales, then optional zero points, then
    # optional bias.
    inputs = [
        helper.make_tensor_value_info('A', TensorProto.UINT8, ['M', 'K']),
        helper.make_tensor_value_info('B', b_elem_type, ['K', 'N']),
        helper.make_tensor_value_info('a_scale', TensorProto.FLOAT, [1]),
        helper.make_tensor_value_info('b_scale', TensorProto.FLOAT, [1]),
    ]
    if has_zp:
        inputs.append(
            helper.make_tensor_value_info('a_zero_point', TensorProto.UINT8, [1]))
        inputs.append(
            helper.make_tensor_value_info('b_zero_point', b_elem_type, [1]))

    if bias:
        nodes.append(
            helper.make_node("Add", ["mul_bottom_output", "bias"], ["Y"], name="add"))
        inputs.append(
            helper.make_tensor_value_info('bias', TensorProto.FLOAT, ['N']))

    outputs = [
        helper.make_tensor_value_info('Y', TensorProto.FLOAT, ['M', 'N']),
    ]

    graph = helper.make_graph(nodes, "DynamicQuantizeMatMul_fusion", inputs, outputs)
    onnx.save(helper.make_model(graph), model_name)
if __name__ == "__main__":
    # Each tuple holds the positional arguments for one GenerateModel call:
    # (model_name, sign[, has_zp, bias]). Omitted entries use the defaults.
    model_specs = (
        ('matmul_integer_to_float_int8.onnx', True),
        ('matmul_integer_to_float_uint8.onnx', False),
        ('matmul_integer_to_float_int8_bias.onnx', True, False, True),
        ('matmul_integer_to_float_uint8_bias.onnx', False, False, True),
    )
    for spec in model_specs:
        GenerateModel(*spec)