From 6ecf626a9c491caffe3bd481d638b27d14534a6f Mon Sep 17 00:00:00 2001 From: KeDengMS Date: Tue, 17 Aug 2021 11:29:03 -0700 Subject: [PATCH] [Nuphar] Parse node doc_string for quantize info (#8746) --- .../nuphar/scripts/model_quantizer.py | 49 ++++++++++++++++++- 1 file changed, 48 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/nuphar/scripts/model_quantizer.py b/onnxruntime/core/providers/nuphar/scripts/model_quantizer.py index f7fdd60797..4c1f84d34e 100644 --- a/onnxruntime/core/providers/nuphar/scripts/model_quantizer.py +++ b/onnxruntime/core/providers/nuphar/scripts/model_quantizer.py @@ -52,6 +52,53 @@ class QuantizeConfig: ('QuantizationType', 'Signed' if self.sign_bit_ else 'Unsigned'), ('ReservedBit', self.reserved_bits_)]) +def parse_custom_attributes(in_node): + if in_node.doc_string: + # some models have node description as qcfg, e.g. + # {"custom_attributes":{ + # "FutureContextLength":10, + # "IntermediateBit":32, + # "PerRowQuantization":true, + # "QuantizeBitOfVector":8, + # "VectorQuantizationType":"AsymmetricUnsigned", + # "QuantizeBitOfMatrix":8, + # "ReservedBitOfVector":1, + # "MatrixQuantizationType":"AsymmetricSigned", + # "ReservedBitOfMatrix":0}} + qcfg_str = in_node.doc_string + # make sure it's the string we can parse + if 'custom_attributes' in qcfg_str: + # some fixes to make it a valid JSON string, when model keys are not string + if qcfg_str[1] == 'c': + qcfg_str = qcfg_str.replace('{', '{"') + qcfg_str = qcfg_str.replace(',', ',"') + qcfg_str = qcfg_str.replace(':', '":') + qcfg_str = qcfg_str.replace('{"}', '{}') + qcfg = json.loads(qcfg_str)['custom_attributes'] + if qcfg: + return qcfg + return None + +def parse_node_description(in_node): + if not in_node.doc_string: + return None + custom_qcfg = parse_custom_attributes(in_node) + if custom_qcfg: + assert custom_qcfg['IntermediateBit'] == 32 + assert custom_qcfg['PerRowQuantization'] + assert custom_qcfg['QuantizeBitOfVector'] == custom_qcfg['QuantizeBitOfMatrix'] + qbits = custom_qcfg['QuantizeBitOfVector'] + assert ("Asymmetric" in custom_qcfg['VectorQuantizationType']) == ("Asymmetric" in custom_qcfg['MatrixQuantizationType']) + symmetric = 0 if "Asymmetric" in custom_qcfg['VectorQuantizationType'] else 1 + x_signed = 0 if "Unsigned" in custom_qcfg['VectorQuantizationType'] else 1 + w_signed = 0 if "Unsigned" in custom_qcfg['MatrixQuantizationType'] else 1 + x_reserved_bits = custom_qcfg['ReservedBitOfVector'] + w_reserved_bits = custom_qcfg['ReservedBitOfMatrix'] + return {'W' : dict(QuantizeConfig(signed=w_signed, reserved_bits=w_reserved_bits, type_bits=qbits)), + 'X' : dict(QuantizeConfig(signed=x_signed, reserved_bits=x_reserved_bits, type_bits=qbits)), + 'Symmetric' : symmetric} + return None + def quantize_matmul_2d_with_weight(in_node, in_graph, nf, converted_weights, quantized_inputs, qcfg_dict, update_qcfg_dict, default_qcfg, onnx_opset_ver): assert in_node.op_type == 'MatMul' @@ -68,7 +115,7 @@ def quantize_matmul_2d_with_weight(in_node, in_graph, nf, converted_weights, qua if in_node.output[0] in qcfg_dict: node_qcfg = qcfg_dict[in_node.output[0]] else: - node_qcfg = None + node_qcfg = parse_node_description(in_node) if not node_qcfg: if not update_qcfg_dict and qcfg_dict: # when qcfg_dict is readonly, raise warning if qcfg is not found for this node