diff --git a/onnxruntime/python/tools/quantization/quantize.py b/onnxruntime/python/tools/quantization/quantize.py index c493abbb35..2b440390db 100644 --- a/onnxruntime/python/tools/quantization/quantize.py +++ b/onnxruntime/python/tools/quantization/quantize.py @@ -348,7 +348,7 @@ class ONNXQuantizer: new_list += self._quantize_convolution(node, new_list) elif node.op_type == 'MatMul': new_list += self._quantize_matmul(node, new_list) - elif node.op_type == 'Gather' and self._is_valid_quantize_value(node.input[0]): + elif node.op_type == 'Gather' and self._is_valid_initializer_value(node.input[0]): new_list += self._quantize_gather_ops(node, new_list) elif node.op_type == 'Add' or node.op_type == 'Mul': new_list += self._quantize_binary_math_ops(node, new_list) @@ -395,6 +395,9 @@ class ONNXQuantizer: value_info = self.value_infos[value_name] return value_info.type.HasField( 'tensor_type') and value_info.type.tensor_type.elem_type == onnx_proto.TensorProto.FLOAT + return self._is_valid_initializer_value(value_name) + + def _is_valid_initializer_value(self, value_name): weight = _find_by_name(value_name, self.model.graph.initializer) return weight is not None and weight.data_type == onnx_proto.TensorProto.FLOAT