onnxruntime/onnxruntime/python/tools/quantization/registry.py
Yufeng Li 6f86c4dbe3
Quantize LSTM (#5595)
Quantize LSTM:
1. Dynamically quantizes the MatMul inside the LSTM. It does not quantize the activation function.
2. Supports per-channel quantization of the input weight and the recurrent weight.
2020-11-18 11:21:49 -08:00

45 lines
1.5 KiB
Python

from .quant_utils import QuantizationMode
from .operators.base_operator import QuantOperatorBase
from .operators.matmul import MatMulInteger, QLinearMatMul
from .operators.attention import AttentionQuant
from .operators.embed_layernorm import EmbedLayerNormalizationQuant
from .operators.gather import GatherQuant
from .operators.conv import QLinearConv, ConvInteger
from .operators.activation import QLinearActivation
from .operators.binary_op import QLinearBinaryOp
from .operators.maxpool import QMaxPool
from. operators.lstm import LSTMQuant
# Op type -> quantizer class for operators registered in both registries below.
CommonOpsRegistry = {
    "Gather": GatherQuant,
    "EmbedLayerNormalization": EmbedLayerNormalizationQuant,
}

# Op type -> quantizer class consulted when the quantizer mode is
# QuantizationMode.IntegerOps. The common entries are merged in last, so the
# resulting contents and key order match a trailing ``update(CommonOpsRegistry)``.
IntegerOpsRegistry = {
    "Conv": ConvInteger,
    "MatMul": MatMulInteger,
    "Attention": AttentionQuant,
    "LSTM": LSTMQuant,
    **CommonOpsRegistry,
}

# Op type -> quantizer class consulted for every other quantization mode.
QLinearOpsRegistry = {
    "Conv": QLinearConv,
    "MatMul": QLinearMatMul,
    "Add": QLinearBinaryOp,
    "Mul": QLinearBinaryOp,
    "Relu": QLinearActivation,
    "Clip": QLinearActivation,
    "LeakyRelu": QLinearActivation,
    "Sigmoid": QLinearActivation,
    "MaxPool": QMaxPool,
    **CommonOpsRegistry,
}
def CreateDefaultOpQuantizer(onnx_quantizer, node):
    """Build the fallback quantizer wrapper for ``node``.

    Args:
        onnx_quantizer: The quantizer object forwarded to the wrapper.
        node: The graph node forwarded to the wrapper.

    Returns:
        A ``QuantOperatorBase`` constructed from ``onnx_quantizer`` and ``node``.
    """
    default_quantizer = QuantOperatorBase(onnx_quantizer, node)
    return default_quantizer
def CreateOpQuantizer(onnx_quantizer, node):
    """Create the quantizer wrapper registered for ``node``'s op type.

    The registry is chosen from ``onnx_quantizer.mode``: ``IntegerOpsRegistry``
    for ``QuantizationMode.IntegerOps``, ``QLinearOpsRegistry`` for any other
    mode. Op types missing from the chosen registry fall back to the default
    wrapper produced by ``CreateDefaultOpQuantizer``.

    Args:
        onnx_quantizer: Quantizer object; its ``mode`` attribute selects the
            registry, and it is forwarded to the chosen quantizer class.
        node: Graph node; its ``op_type`` attribute is the registry key, and
            it is forwarded to the chosen quantizer class.

    Returns:
        An instance of the registered quantizer class for ``node.op_type``,
        or a ``QuantOperatorBase`` when no class is registered.
    """
    if onnx_quantizer.mode == QuantizationMode.IntegerOps:
        registry = IntegerOpsRegistry
    else:
        registry = QLinearOpsRegistry

    # One dictionary lookup instead of a ``.keys()`` membership test followed
    # by a second indexing operation. Every registered class and the default
    # factory accept the same ``(onnx_quantizer, node)`` arguments.
    make_quantizer = registry.get(node.op_type, CreateDefaultOpQuantizer)
    return make_quantizer(onnx_quantizer, node)