From 498483b464ee454c8cd8641d2c005d970e330651 Mon Sep 17 00:00:00 2001
From: Zhang Lei
Date: Wed, 16 Sep 2020 22:52:24 -0700
Subject: [PATCH] MaxPool versioning in quantization tools. (#5194)

MaxPool versioning in quantization tools.
---
 onnxruntime/python/tools/quantization/onnx_quantizer.py    | 6 ++++--
 onnxruntime/python/tools/quantization/operators/maxpool.py | 4 ++++
 2 files changed, 8 insertions(+), 2 deletions(-)

diff --git a/onnxruntime/python/tools/quantization/onnx_quantizer.py b/onnxruntime/python/tools/quantization/onnx_quantizer.py
index d54e0be71b..8514058ff5 100644
--- a/onnxruntime/python/tools/quantization/onnx_quantizer.py
+++ b/onnxruntime/python/tools/quantization/onnx_quantizer.py
@@ -93,7 +93,7 @@ class ONNXQuantizer:
         self.op_types_to_quantize = op_types_to_quantize
         self.new_nodes = []
 
-        self.check_opset_version()
+        self.opset_version = self.check_opset_version()
 
         if not self.mode in quantization_modes:
             raise ValueError('unsupported quantization mode {}'.format(self.mode))
@@ -124,7 +124,7 @@ class ONNXQuantizer:
             print(
                 "Warning: The original model opset version is {}, which does not support node fusions. Please update the model to opset >= 11 for better performance."
                 .format(opset_version))
-            return
+            return 10
 
         if opset_version < 10:
             print(
@@ -132,8 +132,10 @@ class ONNXQuantizer:
                 .format(opset_version))
             self.model.model.opset_import.remove(ai_onnx_domain[0])
             self.model.model.opset_import.extend([onnx.helper.make_opsetid("", 11)])
+            opset_version = 11
 
         self.fuse_dynamic_quant = True
+        return opset_version
 
     def replace_gemm_with_matmul(self):
         nodes_to_remove = []
diff --git a/onnxruntime/python/tools/quantization/operators/maxpool.py b/onnxruntime/python/tools/quantization/operators/maxpool.py
index 9ba993f2a2..8cc1b7da7a 100644
--- a/onnxruntime/python/tools/quantization/operators/maxpool.py
+++ b/onnxruntime/python/tools/quantization/operators/maxpool.py
@@ -12,6 +12,10 @@ class QMaxPool(QuantOperatorBase):
         node = self.node
         assert (node.op_type == "MaxPool")
 
+        if self.quantizer.opset_version < 12:
+            super().quantize()
+            return
+
         # When mode is QLinearOps, the output quantization params are calculated based on outputs from
         # activation nodes, therefore these nodes can be removed from the graph if they follow a quantized op.
         # If input to this node is not quantized then keep this node