mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-23 22:13:38 +00:00
Merge remote-tracking branch 'origin/master' into gwang-msft/qdq_mul
This commit is contained in:
commit
a8c7dce22f
1 changed files with 12 additions and 0 deletions
|
|
@ -80,6 +80,7 @@ class TestOpArgMax(unittest.TestCase):
|
|||
weight_type_str = 'u8' if (weight_type == QuantType.QUInt8) else 's8'
|
||||
model_uint8_path = 'argmax_{}{}.onnx'.format(activation_type_str, weight_type_str)
|
||||
model_uint8_qdq_path = 'argmax_{}{}_qdq.onnx'.format(activation_type_str, weight_type_str)
|
||||
model_uint8_qdq_trt_path = 'argmax_{}{}_qdq_trt.onnx'.format(activation_type_str, weight_type_str)
|
||||
|
||||
# Verify QOperator mode
|
||||
data_reader = self.input_feeds(1, {'input': [1, 256, 128, 128]})
|
||||
|
|
@ -105,6 +106,17 @@ class TestOpArgMax(unittest.TestCase):
|
|||
data_reader.rewind()
|
||||
check_model_correctness(self, model_fp32_path, model_uint8_qdq_path, data_reader.get_next())
|
||||
|
||||
# Verify QDQ mode for TensorRT
|
||||
data_reader.rewind()
|
||||
quantize_static(model_fp32_path, model_uint8_qdq_trt_path, data_reader, quant_format=QuantFormat.QDQ,
|
||||
activation_type=activation_type, weight_type=weight_type, extra_options=extra_options,
|
||||
op_types_to_quantize=['ArgMax'])
|
||||
qdqnode_counts = {'QuantizeLinear': 1, 'DequantizeLinear': 1, 'ArgMax': 1}
|
||||
check_op_type_count(self, model_uint8_qdq_trt_path, **qdqnode_counts)
|
||||
qnode_io_qtypes = {'QuantizeLinear' : [['i', 2, activation_proto_qtype], ['o', 0, activation_proto_qtype]]}
|
||||
check_qtype_by_node_type(self, model_uint8_qdq_trt_path, qnode_io_qtypes)
|
||||
data_reader.rewind()
|
||||
check_model_correctness(self, model_fp32_path, model_uint8_qdq_trt_path, data_reader.get_next())
|
||||
|
||||
def test_quantize_argmax(self):
|
||||
self.quantize_argmax_test(QuantType.QUInt8, QuantType.QUInt8)
|
||||
|
|
|
|||
Loading…
Reference in a new issue