mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-02 03:55:34 +00:00
[Nuphar] Parse node doc_string for quantize info (#8746)
This commit is contained in:
parent
47b3ecb53b
commit
6ecf626a9c
1 changed file with 48 additions and 1 deletion
|
|
@ -52,6 +52,53 @@ class QuantizeConfig:
|
|||
('QuantizationType', 'Signed' if self.sign_bit_ else 'Unsigned'),
|
||||
('ReservedBit', self.reserved_bits_)])
|
||||
|
||||
def parse_custom_attributes(in_node):
    """Extract the ``custom_attributes`` quantization config from a node's doc_string.

    Some models carry the quantization config as the node description, e.g.::

        {"custom_attributes":{
            "FutureContextLength":10,
            "IntermediateBit":32,
            "PerRowQuantization":true,
            "QuantizeBitOfVector":8,
            "VectorQuantizationType":"AsymmetricUnsigned",
            "QuantizeBitOfMatrix":8,
            "ReservedBitOfVector":1,
            "MatrixQuantizationType":"AsymmetricSigned",
            "ReservedBitOfMatrix":0}}

    Args:
        in_node: object exposing a ``doc_string`` attribute.

    Returns:
        The non-empty value stored under the top-level ``custom_attributes``
        key, or ``None`` when the doc_string is empty, does not mention that
        key, is not parsable as JSON (even after the key-quoting fix-up), or
        holds an empty config.
    """
    qcfg_str = in_node.doc_string
    # The doc_string is free-form text; only try to parse it when it mentions the key.
    if not qcfg_str or 'custom_attributes' not in qcfg_str:
        return None

    # some fixes to make it a valid JSON string, when model keys are not string
    # (detected by the text starting with `{c...` instead of `{"c...`)
    if qcfg_str[1] == 'c':
        qcfg_str = qcfg_str.replace('{', '{"')
        qcfg_str = qcfg_str.replace(',', ',"')
        qcfg_str = qcfg_str.replace(':', '":')
        qcfg_str = qcfg_str.replace('{"}', '{}')

    try:
        parsed = json.loads(qcfg_str)
    except ValueError:
        # json.JSONDecodeError is a ValueError: an ordinary description that
        # merely contains the words 'custom_attributes' is not a config.
        return None

    # The config is only recognised as a top-level key of a JSON object.
    if not isinstance(parsed, dict):
        return None
    return parsed.get('custom_attributes') or None
|
||||
|
||||
def parse_node_description(in_node):
    """Build a per-node quantization config from the node's doc_string.

    Args:
        in_node: object exposing a ``doc_string`` attribute.

    Returns:
        ``None`` when the node has no doc_string or when
        ``parse_custom_attributes`` finds no config. Otherwise a dict with
        keys ``'W'`` and ``'X'`` (each ``dict(QuantizeConfig(...))``) and
        ``'Symmetric'`` (0 or 1).

    Raises:
        AssertionError: if the config is not the accepted combination
            (IntermediateBit of 32, truthy PerRowQuantization, equal
            vector/matrix bit widths, matching "Asymmetric" markers).
        KeyError: if an expected attribute is missing from the config.
    """
    if not in_node.doc_string:
        return None

    attrs = parse_custom_attributes(in_node)
    if not attrs:
        return None

    # Only this combination of settings is accepted.
    assert attrs['IntermediateBit'] == 32
    assert attrs['PerRowQuantization']
    assert attrs['QuantizeBitOfVector'] == attrs['QuantizeBitOfMatrix']
    bit_width = attrs['QuantizeBitOfVector']

    # Vector (X) and matrix (W) must agree on whether they are asymmetric.
    vector_type = attrs['VectorQuantizationType']
    vector_is_asymmetric = "Asymmetric" in vector_type
    matrix_type = attrs['MatrixQuantizationType']
    matrix_is_asymmetric = "Asymmetric" in matrix_type
    assert vector_is_asymmetric == matrix_is_asymmetric

    # Flags are kept as 0/1 integers.
    symmetric_flag = int(not vector_is_asymmetric)
    vector_signed = int("Unsigned" not in vector_type)
    matrix_signed = int("Unsigned" not in matrix_type)

    vector_reserved = attrs['ReservedBitOfVector']
    matrix_reserved = attrs['ReservedBitOfMatrix']

    weight_cfg = dict(QuantizeConfig(signed=matrix_signed,
                                     reserved_bits=matrix_reserved,
                                     type_bits=bit_width))
    input_cfg = dict(QuantizeConfig(signed=vector_signed,
                                    reserved_bits=vector_reserved,
                                    type_bits=bit_width))
    return {'W': weight_cfg,
            'X': input_cfg,
            'Symmetric': symmetric_flag}
|
||||
|
||||
def quantize_matmul_2d_with_weight(in_node, in_graph, nf, converted_weights, quantized_inputs, qcfg_dict, update_qcfg_dict, default_qcfg, onnx_opset_ver):
|
||||
assert in_node.op_type == 'MatMul'
|
||||
|
||||
|
|
@ -68,7 +115,7 @@ def quantize_matmul_2d_with_weight(in_node, in_graph, nf, converted_weights, qua
|
|||
if in_node.output[0] in qcfg_dict:
|
||||
node_qcfg = qcfg_dict[in_node.output[0]]
|
||||
else:
|
||||
node_qcfg = None
|
||||
node_qcfg = parse_node_description(in_node)
|
||||
if not node_qcfg:
|
||||
if not update_qcfg_dict and qcfg_dict:
|
||||
# when qcfg_dict is readonly, raise warning if qcfg is not found for this node
|
||||
|
|
|
|||
Loading…
Reference in a new issue