[Nuphar] Parse node doc_string for quantize info (#8746)

This commit is contained in:
KeDengMS 2021-08-17 11:29:03 -07:00 committed by GitHub
parent 47b3ecb53b
commit 6ecf626a9c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -52,6 +52,53 @@ class QuantizeConfig:
('QuantizationType', 'Signed' if self.sign_bit_ else 'Unsigned'),
('ReservedBit', self.reserved_bits_)])
def parse_custom_attributes(in_node):
if in_node.doc_string:
# some models have node description as qcfg, e.g.
# {"custom_attributes":{
# "FutureContextLength":10,
# "IntermediateBit":32,
# "PerRowQuantization":true,
# "QuantizeBitOfVector":8,
# "VectorQuantizationType":"AsymmetricUnsigned",
# "QuantizeBitOfMatrix":8,
# "ReservedBitOfVector":1,
# "MatrixQuantizationType":"AsymmetricSigned",
# "ReservedBitOfMatrix":0}}
qcfg_str = in_node.doc_string
# make sure it's the string we can parse
if 'custom_attributes' in qcfg_str:
# some fixes to make it a valid JSON string, when model keys are not string
if qcfg_str[1] == 'c':
qcfg_str = qcfg_str.replace('{', '{"')
qcfg_str = qcfg_str.replace(',', ',"')
qcfg_str = qcfg_str.replace(':', '":')
qcfg_str = qcfg_str.replace('{"}', '{}')
qcfg = json.loads(qcfg_str)['custom_attributes']
if qcfg:
return qcfg
return None
def parse_node_description(in_node):
if not in_node.doc_string:
return None
custom_qcfg = parse_custom_attributes(in_node)
if custom_qcfg:
assert custom_qcfg['IntermediateBit'] == 32
assert custom_qcfg['PerRowQuantization']
assert custom_qcfg['QuantizeBitOfVector'] == custom_qcfg['QuantizeBitOfMatrix']
qbits = custom_qcfg['QuantizeBitOfVector']
assert ("Asymmetric" in custom_qcfg['VectorQuantizationType']) == ("Asymmetric" in custom_qcfg['MatrixQuantizationType'])
symmetric = 0 if "Asymmetric" in custom_qcfg['VectorQuantizationType'] else 1
x_signed = 0 if "Unsigned" in custom_qcfg['VectorQuantizationType'] else 1
w_signed = 0 if "Unsigned" in custom_qcfg['MatrixQuantizationType'] else 1
x_reserved_bits = custom_qcfg['ReservedBitOfVector']
w_reserved_bits = custom_qcfg['ReservedBitOfMatrix']
return {'W' : dict(QuantizeConfig(signed=w_signed, reserved_bits=w_reserved_bits, type_bits=qbits)),
'X' : dict(QuantizeConfig(signed=x_signed, reserved_bits=x_reserved_bits, type_bits=qbits)),
'Symmetric' : symmetric}
return None
def quantize_matmul_2d_with_weight(in_node, in_graph, nf, converted_weights, quantized_inputs, qcfg_dict, update_qcfg_dict, default_qcfg, onnx_opset_ver):
assert in_node.op_type == 'MatMul'
@ -68,7 +115,7 @@ def quantize_matmul_2d_with_weight(in_node, in_graph, nf, converted_weights, qua
if in_node.output[0] in qcfg_dict:
node_qcfg = qcfg_dict[in_node.output[0]]
else:
node_qcfg = None
node_qcfg = parse_node_description(in_node)
if not node_qcfg:
if not update_qcfg_dict and qcfg_dict:
# when qcfg_dict is readonly, raise warning if qcfg is not found for this node