def is_valid_quantize_weight(self, weight_name):
    '''
    Report whether weight_name names a FLOAT initializer.

    The name is first looked up among the initializers of this quantizer's
    own model. If it is not found there, the question is forwarded to
    self.parent, but only when subgraph quantization is enabled and a
    parent is set; otherwise the weight is reported as not valid.

        parameter weight_name: name of the candidate weight initializer.
        return: True when the initializer that was found has data type FLOAT,
                False when it has another type or cannot be resolved.
    '''
    initializer = find_by_name(weight_name, self.model.initializer())
    if initializer is None:
        # Not defined at this level: the enclosing (parent) quantizer is the
        # only remaining place the initializer could come from.
        can_ask_parent = self.enable_subgraph_quantization and self.parent is not None
        return self.parent.is_valid_quantize_weight(weight_name) if can_ask_parent else False
    return initializer.data_type == onnx_proto.TensorProto.FLOAT