mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-01 23:30:35 +00:00
[QNN QDQ Quant] Utils to generate mixed-precision quant overrides (#20028)
### Description - Adds a utility to the QNN quantization scripts that "fixes" an initial set of tensor quantization overrides for mixed-precision QDQ models. Follow-up to https://github.com/microsoft/onnxruntime/pull/19925 - Moves existing overrides for QNN compatibility (matmul, layernorm, sigmoid, tanh) to separate functions. PR adds missing unit tests for these. - Adds `weight_symmetric=None` parameter to the `get_qnn_qdq_config()` function to enable user specification (instead of always using default behavior). - If weight_symmetric is set to `None`, it will be set to `weight_symmetric = weight_type in (QInt8, QInt16)`. - Otherwise, the user's value is used. #### Example Float model: ``` input_0 --> Op1 --> Op3 --> Op5 --> Op6 --> output_0 ^ | input_1 --> Op2 -+-> Op4 ----+ | +-> Op7 --> output_1 | +-> Op8 --> output_2 ``` If we'd like to quantize this model to uint8 precision, but would like to make sure tensor "Op4_out" is quantized to 16-bit, then we would specify the following initial tensor quantization overrides: ```python # Op4_out could be an inaccurate tensor that should be upgraded to 16bit initial_overrides = {"Op4_out": [{"quant_type": QuantType.QUInt16}]} ``` These initial overrides may not create a valid model because Op4 and Op5 may require both the input and output to be the same type (e.g., uint16). 
This helper fixes the overrides so that input/output data types are valid: ```python qnn_config = get_qnn_qdq_config( float_model_path, data_reader, activation_type=QuantType.QUInt8, weight_type=QuantType.QUInt8, init_overrides=initial_overrides, # These initial overrides will be "fixed" ) ``` The above snippet generates the following "fixed" overrides (get via `qnn_config.extra_options["TensorQuantOverrides"]`): ```python { "Op2_out": [{"quant_type": QUInt8, "convert": {"quant_type": QUInt16, "recv_nodes": {"Op4"}}}], "Op3_out": [{"quant_type": QUInt8, "convert": {"quant_type": QUInt16, "recv_nodes": {"Op5"}}}], "Op4_out": [{"quant_type": QUInt16}], "Op5_out": [{"quant_type": QUInt16, "convert": {"quant_type": QUInt8, "recv_nodes": {"Op6"}}}] } ``` How to interpret the fixed overrides: - Op2's output is consumed by Op4, Op7, and Op8. Op4 consumes the converted u16 type, but Op7 and Op8 consume the original u8 type. - Op3's output is converted from u8 to u16. Op5 consumes the converted u16 type. - Op4's output is just u16 (not converted). All consumers of Op4_out get the u16 type. - Op5's output is converted from u16 to u8. Op6 consumes the u8 type. ### Motivation and Context Generating mixed-precision quantization overrides is currently a manual process. This PR adds a utility that helps generate valid overrides.
This commit is contained in:
parent
d30c81d270
commit
7d976cf720
5 changed files with 1371 additions and 66 deletions
|
|
@ -0,0 +1,413 @@
|
|||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for
|
||||
# license information.
|
||||
# --------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
|
||||
import onnx
|
||||
|
||||
from ...quant_utils import QuantType
|
||||
from ...tensor_quant_overrides import QuantTypeInfo, TensorQuantOverridesHelper
|
||||
|
||||
|
||||
@dataclass
class TensorTypeRequest:
    """
    Pair of quantization type requests for a single tensor: the type wanted where the tensor is
    produced, and the type wanted by a named group of nodes that consume it.
    """

    # Requested quant type on the producing side of the tensor.
    # None means no explicit request (the default activation quant type is assumed).
    producer: QuantTypeInfo | None

    # Requested quant type on the consuming side, as a (type, consumer node names) pair.
    # consumers[0] is the requested type; consumers[1] is the set of consumer node names it applies to.
    # None means no explicit request (all consumers are assumed to take the default activation quant type).
    consumers: tuple[QuantTypeInfo, set[str]] | None
|
||||
|
||||
|
||||
class MixedPrecisionTensorQuantOverridesFixer:
    """
    Helper that generates tensor quantization overrides for mixed-precision QDQ models.

    Specifically, this helper fixes an initial set of quantization overrides that assign a non-default
    activation quantization type to one or more tensors by doing the following:
     - Inferring which other tensors need to be overridden to the non-default activation quantization type.
     - Inserting quantization data type conversions.

    Example:
    --------

    Float model:

    input_0 --> Op1 --> Op3 --> Op5 --> Op6 --> output_0
                                 ^
                                 |
    input_1 --> Op2 -+-> Op4 ----+
                     |
                     +-> Op7 --> output_1
                     |
                     +-> Op8 --> output_2

    If we'd like to quantize this model to uint8 precision, but would like to make sure tensor "Op4_out"
    is quantized to 16-bit, then we would specify the following initial tensor quantization overrides:

    ```
    init_overrides = {"Op4_out": [{"quant_type": QuantType.QUInt16}]}
    ```

    These initial overrides may not create a valid model because Op4 and Op5 may require both the input and output
    to be the same type (e.g., uint16). This helper fixes the overrides so that input/output data types
    are valid:

    ```
    overrides = TensorQuantOverridesHelper(init_overrides)

    fixer = MixedPrecisionTensorQuantOverridesFixer.create_from_model(overrides, model, QuantType.QUInt8)
    fixer.apply(
        default_activation_qtype=QuantType.QUInt8,
        default_activation_symmetric=False,
    )
    ```

    The above snippet generates the following "fixed" overrides (get via overrides.get_dict()):

    {
      "Op2_out": [{"quant_type": QUInt8, "convert": {"quant_type": QUInt16, "recv_nodes": {"Op4"}}}],
      "Op3_out": [{"quant_type": QUInt8, "convert": {"quant_type": QUInt16, "recv_nodes": {"Op5"}}}],
      "Op4_out": [{"quant_type": QUInt16}],
      "Op5_out": [{"quant_type": QUInt16, "convert": {"quant_type": QUInt8, "recv_nodes": {"Op6"}}}]
    }

    How to interpret the fixed overrides:
     - Op2's output is consumed by Op4, Op7, and Op8. Op4 consumes the converted u16 type,
       but Op7 and Op8 consume the original u8 type.
     - Op3's output is converted from u8 to u16. Op5 consumes the converted u16 type.
     - Op4's output is just u16 (not converted). All consumers of Op4_out get the u16 type.
     - Op5's output is converted from u16 to u8. Op6 consumes the u8 type.
    """

    def __init__(
        self,
        overrides: TensorQuantOverridesHelper,
        producers: dict[str, onnx.NodeProto],
        consumers: dict[str, list[onnx.NodeProto]],
        value_infos: dict[str, onnx.ValueInfoProto],
        initializers: dict[str, onnx.TensorProto],
    ):
        """
        Params:
            overrides: The initial tensor quantization overrides to fix.
            producers: Dictionary that maps a tensor name to the producer node that generates the tensor.
            consumers: Dictionary that maps a tensor name to the consumer nodes that take the tensor as input.
            value_infos: Dictionary that maps a tensor name to its onnx.ValueInfoProto.
            initializers: Dictionary that maps an initializer name to its onnx.TensorProto.
        """
        self.overrides = overrides
        self.consumers = consumers
        self.producers = producers
        self.value_infos = value_infos
        self.initializers = initializers

    @staticmethod
    def create_from_model(
        overrides: TensorQuantOverridesHelper, model: onnx.ModelProto, default_activation_qtype: QuantType
    ) -> MixedPrecisionTensorQuantOverridesFixer:
        """
        Helper function that creates an instance of this class from a loaded ONNX model.

        Params:
            overrides: The initial tensor quantization overrides to fix.
            model: Loaded ONNX model
            default_activation_qtype: The intended default activation quantization type.
                                      Used to validate the initial overrides.

        Returns:
            Initialized MixedPrecisionTensorQuantOverridesFixer object

        Raises:
            ValueError: If `overrides.is_valid()` reports the initial overrides as invalid.
        """
        model = onnx.shape_inference.infer_shapes(model)  # Need to infer shapes to get value_infos

        # Build dictionaries that enable convenient lookups of initializers and value_infos by name.
        # Graph outputs and inputs are merged into value_infos so they can be looked up the same way.
        initializers = {initializer.name: initializer for initializer in model.graph.initializer}
        value_infos = {vi.name: vi for vi in model.graph.value_info}
        value_infos.update({ot.name: ot for ot in model.graph.output})
        value_infos.update({it.name: it for it in model.graph.input})

        # Ensure that the user-provided initial overrides are actually valid.
        valid, err = overrides.is_valid(set(initializers), set(value_infos), default_activation_qtype)
        if not valid:
            pprint_overrides = overrides.pprint_str(indent=4)
            logging.error(f"Provided invalid tensor quantization overrides:\n{pprint_overrides}")
            raise ValueError(err)

        consumers: dict[str, list[onnx.NodeProto]] = {}
        producers: dict[str, onnx.NodeProto] = {}

        # Build dictionaries that map a tensor name to the consumer or producer nodes.
        for node in model.graph.node:
            for input_name in node.input:
                if input_name:  # Empty names are skipped
                    consumers.setdefault(input_name, []).append(node)

            for output_name in node.output:
                producers[output_name] = node

        return MixedPrecisionTensorQuantOverridesFixer(overrides, producers, consumers, value_infos, initializers)

    def apply(
        self,
        default_activation_qtype: QuantType,
        default_activation_symmetric: bool,
    ):
        """
        Fixes the initial tensor quantization overrides (in-place) for use in mixed-precision QDQ models.

        Params:
            default_activation_qtype: The intended default activation quantization type.
            default_activation_symmetric: The intended default symmetry used to quantize activations.

        Raises:
            ValueError: If the type requests conflict with the existing overrides, or if the resulting
                        overrides are reported invalid by `self.overrides.is_valid()`.
        """
        type_requests = self.get_desired_tensor_types(default_activation_qtype, default_activation_symmetric)

        # Use type requests to "fix" tensor quantization overrides by adding
        # quantization type conversions where necessary.
        for tensor_name, type_req in type_requests.items():
            all_consumers = {node.name for node in self.consumers.get(tensor_name, [])}
            has_producer_req = type_req.producer is not None
            has_consumer_req = bool(type_req.consumers)

            # Only producer type: Add conversion back to default activation type
            if has_producer_req and not has_consumer_req:
                self._update_converted_tensor(
                    tensor_name, type_req.producer, QuantTypeInfo(default_activation_qtype), all_consumers
                )
            # Only consumers
            elif not has_producer_req and has_consumer_req:
                prod_type_info = self.overrides.get_node_output_qtype_info(tensor_name, default_activation_qtype)
                consumer_type_info = type_req.consumers[0]

                if prod_type_info != consumer_type_info:
                    self._update_converted_tensor(
                        tensor_name, prod_type_info, consumer_type_info, type_req.consumers[1]
                    )
                else:
                    # Producer already has the requested type, so the requesting consumers must not be
                    # listed as receivers of an existing conversion away from that type.
                    if not self._check_nodes_are_not_convert_consumers(tensor_name, type_req.consumers[1]):
                        raise ValueError(
                            f"Tensor override for '{tensor_name}' converts the type for consumers that need the original type."
                        )
            # Both producer and consumers
            elif has_producer_req and has_consumer_req:
                prod_type_info = type_req.producer
                consumer_type_info = type_req.consumers[0]

                if prod_type_info != consumer_type_info:
                    self._update_converted_tensor(
                        tensor_name, prod_type_info, consumer_type_info, type_req.consumers[1]
                    )
                else:
                    consumers_for_original_type = all_consumers.difference(type_req.consumers[1])

                    if not consumers_for_original_type:
                        # All consumers want the overridden type, so no need for convert nodes!
                        # Just add the override if not already present. An empty override list is
                        # treated as "not present", matching _update_converted_tensor().
                        if tensor_name not in self.overrides or not self.overrides[tensor_name]:
                            self.overrides[tensor_name] = [{}]
                            prod_type_info.save_to_dict(self.overrides[tensor_name][0])

                        # Raise (rather than assert) because this depends on the provided overrides
                        # and must still be checked when Python runs with assertions disabled (-O).
                        if "convert" in self.overrides[tensor_name][0]:
                            raise ValueError(
                                f"Tensor override for '{tensor_name}' has a 'convert' entry, but all of its "
                                "consumers require the producer's quantization type."
                            )
                    else:
                        # Some consumers don't want the overridden type.
                        self._update_converted_tensor(
                            tensor_name,
                            prod_type_info,
                            QuantTypeInfo(default_activation_qtype),
                            consumers_for_original_type,
                        )
            else:
                raise ValueError(f"TypeRequest for tensor {tensor_name} has no producer or consumers.")

        # Done. Check if the overrides are valid.
        valid, err = self.overrides.is_valid(set(self.initializers), set(self.value_infos), default_activation_qtype)
        if not valid:
            pprint_overrides = self.overrides.pprint_str(indent=4)
            logging.error(
                f"Generated invalid tensor quantization overrides for mixed-precision QDQ model:\n{pprint_overrides}"
            )
            raise ValueError(err)

    def get_desired_tensor_types(
        self,
        default_activation_qtype: QuantType,
        default_activation_symmetric: bool,
    ) -> dict[str, TensorTypeRequest]:
        """
        Iterates through the initial tensor quantization overrides and builds a set of TensorTypeRequests objects
        that describe the quantization types required at each tensor. These TensorTypeRequests objects are ultimately
        used to generate the "fixed" overrides.

        Params:
            default_activation_qtype: The intended default activation quantization type.
            default_activation_symmetric: The intended default symmetry used to quantize activations.

        Returns:
            TensorTypeRequest objects as a dict that maps a tensor name to its requested types.
        """
        type_requests: dict[str, TensorTypeRequest] = {}
        default_activation_type_info = QuantTypeInfo(default_activation_qtype, default_activation_symmetric)

        # Scan tensor overrides for type conversion requests.
        for tensor_name, override_list in self.overrides.items():
            if not self.__is_tensor_quantizable(tensor_name):
                continue  # Skip non-quantizable tensors (e.g., not a float)

            if tensor_name in self.initializers:
                continue  # Skip initializers

            if not override_list or len(override_list) > 1:
                continue  # Skip per-channel stuff

            override_dict = override_list[0]
            quant_type_info = QuantTypeInfo.load_from_dict(override_dict, default_activation_type_info.quant_type)
            producer_node = self.producers.get(tensor_name)  # None if this is a model input

            # Only overrides that differ from the default and do not already carry an explicit
            # "convert" entry generate type requests for the neighboring nodes.
            if quant_type_info != default_activation_type_info and "convert" not in override_dict:
                if producer_node is not None:
                    self._add_type_requests_for_node(type_requests, quant_type_info, producer_node)

                # Find all consumer nodes of `tensor_name` and update their inputs/outputs to the new type.
                for consumer_node in self.consumers.get(tensor_name, []):
                    self._add_type_requests_for_node(type_requests, quant_type_info, consumer_node)

        return type_requests

    def _add_type_requests_for_node(
        self,
        type_requests: dict[str, TensorTypeRequest],
        quant_type_info: QuantTypeInfo,
        node: onnx.NodeProto,
    ):
        """
        Adds TensorTypeRequest objects for a given node, assuming that we want all its inputs and outputs
        to have the same quantization type (as specified by the `quant_type_info` parameter).

        Params:
            type_requests: Dictionary of type requests to append to for this node.
            quant_type_info: The quantization type to use for inputs and outputs.
            node: The node for which the TensorTypeRequest objects are created and added to type_requests.

        Raises:
            ValueError: If a tensor ends up with conflicting producer or consumer type requests, or if
                        the node has a quantizable non-initializer input but no name.
        """
        # Add output side
        for output_name in node.output:
            if not self.__is_tensor_quantizable(output_name):
                continue

            if output_name not in type_requests:
                type_requests[output_name] = TensorTypeRequest(quant_type_info, None)
            else:
                if (
                    type_requests[output_name].producer is not None
                    and type_requests[output_name].producer != quant_type_info
                ):
                    raise ValueError(f"Tensor {output_name} has multiple types.")

                type_requests[output_name].producer = quant_type_info

        # Add the consumer side
        for input_name in node.input:
            if input_name and input_name not in self.initializers and self.__is_tensor_quantizable(input_name):
                if input_name not in type_requests:
                    type_requests[input_name] = TensorTypeRequest(None, None)

                if type_requests[input_name].consumers is None:
                    type_requests[input_name].consumers = (quant_type_info, set())

                if type_requests[input_name].consumers[0] != quant_type_info:
                    raise ValueError(f"Tensor {input_name} has consumers requesting different types.")

                # Consumers are tracked by node name, so an unnamed node cannot be recorded.
                if not node.name:
                    raise ValueError(
                        f"Node of type {node.op_type} with output 0 {node.output[0]} does not have a name!"
                    )

                type_requests[input_name].consumers[1].add(node.name)

    def _update_converted_tensor(
        self,
        tensor_name: str,
        producer_type_info: QuantTypeInfo,
        consumer_type_info: QuantTypeInfo,
        consumer_names: set[str],
    ):
        """
        Updates the tensor quantization overrides for a tensor that is converted from one type to another.

        Params:
            tensor_name: The name of the tensor for which to update overrides.
            producer_type_info: Info for the tensor's produced type.
            consumer_type_info: Info for the tensor's consumed (i.e., converted) type.
            consumer_names: Nodes names of consumers that consume the converted type.

        Raises:
            ValueError: If an existing override (or its existing "convert" entry) does not match the
                        desired producer (or consumer) type.
        """
        if tensor_name not in self.overrides or not self.overrides[tensor_name]:
            self.overrides[tensor_name] = [{}]
            producer_type_info.save_to_dict(self.overrides[tensor_name][0])

        overrides = self.overrides[tensor_name][0]
        if producer_type_info != QuantTypeInfo.load_from_dict(overrides):
            raise ValueError(f"Desired producer quant_type for {tensor_name} doesn't match existing type.")

        if consumer_names:
            if "convert" not in overrides:
                overrides["convert"] = {}
                consumer_type_info.save_to_dict(overrides["convert"])

            convert_dict = overrides["convert"]
            if consumer_type_info != QuantTypeInfo.load_from_dict(convert_dict):
                raise ValueError(f"Desired consumer quant_type for {tensor_name} doesn't match existing type.")

            if "recv_nodes" not in convert_dict:
                convert_dict["recv_nodes"] = set()

            convert_dict["recv_nodes"].update(consumer_names)

    def _check_nodes_are_not_convert_consumers(self, tensor_name: str, node_names: set[str]):
        """
        Returns true if the given nodes do not consume/receive a converted quantization type.

        Params:
            tensor_name: The name of the tensor to check.
            node_names: Set of node names that should not be consumers of the converted type.
        """
        if tensor_name not in self.overrides or not self.overrides[tensor_name]:
            return True

        overrides = self.overrides[tensor_name][0]

        if "convert" not in overrides:
            return True

        convert_dict = overrides["convert"]

        # A "convert" entry without "recv_nodes" is treated here as applying to the given nodes too.
        if "recv_nodes" not in convert_dict:
            return False

        return not convert_dict["recv_nodes"].intersection(node_names)

    def __is_tensor_quantizable(self, tensor_name: str) -> bool:
        """
        Returns True if the named tensor is an initializer with FLOAT/FLOAT16 data, or a tensor whose
        value info declares a FLOAT/FLOAT16 tensor element type. Unknown tensors return False.

        Params:
            tensor_name: The name of the tensor to check.
        """
        weight = self.initializers.get(tensor_name)
        if weight is not None:
            if weight.data_type in (onnx.TensorProto.FLOAT, onnx.TensorProto.FLOAT16):
                return True
        elif tensor_name in self.value_infos:
            vi = self.value_infos[tensor_name]
            if vi.type.HasField("tensor_type") and vi.type.tensor_type.elem_type in (
                onnx.TensorProto.FLOAT,
                onnx.TensorProto.FLOAT16,
            ):
                return True

        return False
|
||||
|
|
@ -3,6 +3,10 @@
|
|||
# Licensed under the MIT License. See License.txt in the project root for
|
||||
# license information.
|
||||
# --------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
|
@ -11,6 +15,8 @@ import onnx
|
|||
from ...calibrate import CalibrationDataReader, CalibrationMethod
|
||||
from ...quant_utils import QuantType
|
||||
from ...quantize import StaticQuantConfig
|
||||
from ...tensor_quant_overrides import TensorQuantOverridesHelper
|
||||
from .mixed_precision_overrides_utils import MixedPrecisionTensorQuantOverridesFixer
|
||||
|
||||
Q16_TYPES = {QuantType.QInt16, QuantType.QUInt16}
|
||||
Q8_TYPES = {QuantType.QInt8, QuantType.QUInt8}
|
||||
|
|
@ -18,6 +24,20 @@ OP_TYPES_TO_EXCLUDE = {"Cast"}
|
|||
MODEL_SIZE_THRESHOLD = 2147483648 # Quant model should use external data if >= 2GB
|
||||
|
||||
|
||||
def warn_unable_to_override(
    node: onnx.NodeProto,
    what_str: str,
    tensor_name: str,
    io_kind: str,
):
    """
    Logs a warning that an override could not be applied to one of a node's tensors because that
    tensor already has an override.

    Params:
        node: The node whose input/output could not be overridden (its op_type and name are logged).
        what_str: Description of the setting that could not be overridden (e.g., "quant_type/symmetric").
        tensor_name: Name of the affected tensor.
        io_kind: Description of the tensor's role for the node (e.g., "input weight").
    """
    message = (
        f"Unable to override {what_str} for {node.op_type} node's {io_kind} "
        "because it has already been overridden! Check the initial quantization overrides provided "
        "to get_qnn_qdq_config() if the generated QDQ model does not run on QNN EP. "
        f"Node name: {node.name}, {io_kind} name: {tensor_name}"
    )
    logging.warning(message)
|
||||
|
||||
|
||||
def get_qnn_qdq_config(
|
||||
model_input: Path,
|
||||
calibration_data_reader: CalibrationDataReader,
|
||||
|
|
@ -25,14 +45,20 @@ def get_qnn_qdq_config(
|
|||
activation_type=QuantType.QUInt8,
|
||||
weight_type=QuantType.QUInt8,
|
||||
per_channel=False,
|
||||
init_overrides=None,
|
||||
add_qtype_converts=True,
|
||||
activation_symmetric=False,
|
||||
weight_symmetric=None,
|
||||
):
|
||||
if per_channel:
|
||||
raise ValueError("QNN EP does not yet support per-channel quantization.")
|
||||
|
||||
if weight_symmetric is None:
|
||||
weight_symmetric = weight_type in {QuantType.QInt8, QuantType.QInt16}
|
||||
|
||||
model = onnx.load_model(model_input, load_external_data=False)
|
||||
|
||||
op_types = set()
|
||||
tensor_quant_overrides = {}
|
||||
model_has_external_data = False
|
||||
name_to_initializer = {}
|
||||
|
||||
|
|
@ -43,52 +69,40 @@ def get_qnn_qdq_config(
|
|||
if onnx.external_data_helper.uses_external_data(initializer):
|
||||
model_has_external_data = True
|
||||
|
||||
# Setup quantization overrides for specific operator types
|
||||
overrides_helper = TensorQuantOverridesHelper(copy.deepcopy(init_overrides) if init_overrides else {})
|
||||
|
||||
if not overrides_helper.empty() and add_qtype_converts:
|
||||
# Fix mixed-precision overrides.
|
||||
overrides_fixer = MixedPrecisionTensorQuantOverridesFixer.create_from_model(
|
||||
overrides_helper, model, activation_type
|
||||
)
|
||||
overrides_fixer.apply(activation_type, activation_symmetric)
|
||||
|
||||
# Setup quantization overrides for specific operator types to ensure compatibility with QNN EP.
|
||||
qnn_compat = QnnCompatibilityOverrides(
|
||||
activation_type,
|
||||
weight_type,
|
||||
activation_symmetric,
|
||||
weight_symmetric,
|
||||
overrides_helper,
|
||||
name_to_initializer,
|
||||
)
|
||||
|
||||
for node in model.graph.node:
|
||||
op_types.add(node.op_type)
|
||||
|
||||
if node.op_type == "MatMul" and activation_type in Q16_TYPES and weight_type in Q8_TYPES:
|
||||
weight_symmetric = weight_type == QuantType.QInt8
|
||||
|
||||
# Override initializers to use the weight_type
|
||||
for input_name in node.input:
|
||||
if input_name in name_to_initializer:
|
||||
tensor_quant_overrides[input_name] = [{"quant_type": weight_type, "symmetric": weight_symmetric}]
|
||||
elif node.op_type == "LayerNormalization" and activation_type in Q16_TYPES and weight_type in Q8_TYPES:
|
||||
weight_symmetric = weight_type == QuantType.QInt8
|
||||
|
||||
# Override initializers to use the weight_type. Don't override the bias input.
|
||||
for i in range(2):
|
||||
input_name = node.input[i]
|
||||
if input_name in name_to_initializer:
|
||||
tensor_quant_overrides[input_name] = [{"quant_type": weight_type, "symmetric": weight_symmetric}]
|
||||
elif node.op_type == "Sigmoid":
|
||||
if activation_type == QuantType.QUInt16:
|
||||
tensor_quant_overrides[node.output[0]] = [
|
||||
{"scale": np.array(1.0 / 65536.0, dtype=np.float32), "zero_point": np.array(0, dtype=np.uint16)}
|
||||
]
|
||||
elif activation_type == QuantType.QInt16:
|
||||
tensor_quant_overrides[node.output[0]] = [
|
||||
{"scale": np.array(1.0 / 32768.0, dtype=np.float32), "zero_point": np.array(0, dtype=np.int16)}
|
||||
]
|
||||
elif node.op_type == "Tanh":
|
||||
if activation_type == QuantType.QUInt16:
|
||||
tensor_quant_overrides[node.output[0]] = [
|
||||
{"scale": np.array(1.0 / 32768.0, dtype=np.float32), "zero_point": np.array(32768, dtype=np.uint16)}
|
||||
]
|
||||
elif activation_type == QuantType.QInt16:
|
||||
tensor_quant_overrides[node.output[0]] = [
|
||||
{"scale": np.array(1.0 / 32768.0, dtype=np.float32), "zero_point": np.array(0, dtype=np.int16)}
|
||||
]
|
||||
qnn_compat.process_node(node)
|
||||
|
||||
extra_options = {
|
||||
"MinimumRealRange": 0.0001,
|
||||
"DedicatedQDQPair": False, # Let ORT optimizer duplicate DQ nodes
|
||||
"TensorQuantOverrides": tensor_quant_overrides,
|
||||
"TensorQuantOverrides": overrides_helper.get_dict(),
|
||||
"ActivationSymmetric": activation_symmetric,
|
||||
"WeightSymmetric": weight_symmetric,
|
||||
}
|
||||
|
||||
# TODO: Remove this extra option once ORT uses an ONNX version that supports 16-bit Q/DQ ops.
|
||||
if activation_type in Q16_TYPES or weight_type in Q16_TYPES:
|
||||
overrides_have_int16 = any(t in Q16_TYPES for t in overrides_helper.get_quant_types())
|
||||
if activation_type in Q16_TYPES or weight_type in Q16_TYPES or overrides_have_int16:
|
||||
extra_options["UseQDQContribOps"] = True
|
||||
|
||||
return StaticQuantConfig(
|
||||
|
|
@ -100,3 +114,163 @@ def get_qnn_qdq_config(
|
|||
use_external_data_format=(model_has_external_data or model.ByteSize() >= MODEL_SIZE_THRESHOLD),
|
||||
extra_options=extra_options,
|
||||
)
|
||||
|
||||
|
||||
class QnnCompatibilityOverrides:
|
||||
"""
|
||||
Helper that processes nodes to generate quantization overrides that make the resulting QDQ model
|
||||
compatible with QNN EP.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
default_activation_qtype: QuantType,
|
||||
default_weight_qtype: QuantType,
|
||||
activation_symmetric: bool,
|
||||
weight_symmetric: bool,
|
||||
overrides: TensorQuantOverridesHelper,
|
||||
initializers: dict[str, onnx.TensorProto],
|
||||
):
|
||||
self.default_activation_qtype = default_activation_qtype
|
||||
self.default_weight_qtype = default_weight_qtype
|
||||
self.activation_symmetric = activation_symmetric
|
||||
self.weight_symmetric = weight_symmetric
|
||||
self.overrides = overrides
|
||||
self.initializers = initializers
|
||||
|
||||
self.process_fns = {
|
||||
"MatMul": self._process_matmul,
|
||||
"LayerNormalization": self._process_layernorm,
|
||||
"Sigmoid": self._process_sigmoid,
|
||||
"Tanh": self._process_tanh,
|
||||
}
|
||||
|
||||
def process_node(self, node: onnx.NodeProto):
|
||||
process_fn = self.process_fns.get(node.op_type)
|
||||
|
||||
if process_fn is not None:
|
||||
process_fn(node)
|
||||
|
||||
def _process_matmul(self, node: onnx.NodeProto):
|
||||
"""
|
||||
Overrides MatMul's initializer input(s) to use the default weight type if:
|
||||
- The default weight type is 8-bit
|
||||
- One of the inputs is a 16-bit activation
|
||||
"""
|
||||
assert node.op_type == "MatMul", f"Expected MatMul, but got {node.op_type}"
|
||||
if self.default_weight_qtype not in Q8_TYPES:
|
||||
return
|
||||
|
||||
input_16bit_act = None
|
||||
input_wgt = None
|
||||
|
||||
for input_name in node.input:
|
||||
if input_name and input_name not in self.initializers:
|
||||
qtype = self.overrides.get_node_input_qtype_info(
|
||||
input_name, node.name, self.default_activation_qtype
|
||||
).quant_type
|
||||
if qtype in Q16_TYPES:
|
||||
input_16bit_act = input_name
|
||||
else:
|
||||
input_wgt = input_name
|
||||
|
||||
# Override initializer to use the default weight type.
|
||||
if input_16bit_act and input_wgt:
|
||||
did_update = self.overrides.update_tensor_overrides(
|
||||
input_wgt,
|
||||
{"quant_type": self.default_weight_qtype, "symmetric": self.weight_symmetric},
|
||||
overwrite=False,
|
||||
)
|
||||
|
||||
if not did_update:
|
||||
warn_unable_to_override(node, "quant_type/symmetric", input_wgt, "input weight")
|
||||
|
||||
def _process_layernorm(self, node: onnx.NodeProto):
|
||||
"""
|
||||
Overrides LayerNormalization's initializer input(s), except for bias, to use the default weight type if:
|
||||
- The default weight type is 8-bit
|
||||
- One of the inputs is a 16-bit activation
|
||||
"""
|
||||
assert node.op_type == "LayerNormalization", f"Expected LayerNormalization, but got {node.op_type}"
|
||||
if self.default_weight_qtype not in Q8_TYPES:
|
||||
return
|
||||
|
||||
has_q16_activation = False
|
||||
for input_name in node.input:
|
||||
if input_name and input_name not in self.initializers:
|
||||
qtype = self.overrides.get_node_input_qtype_info(
|
||||
input_name, node.name, self.default_activation_qtype
|
||||
).quant_type
|
||||
if qtype in Q16_TYPES:
|
||||
has_q16_activation = True
|
||||
break
|
||||
|
||||
# Override initializers to use the self.default_weight_qtype. Don't override the bias input.
|
||||
if has_q16_activation:
|
||||
for i in range(2):
|
||||
input_name = node.input[i]
|
||||
if input_name and input_name in self.initializers:
|
||||
did_update = self.overrides.update_tensor_overrides(
|
||||
input_name,
|
||||
{"quant_type": self.default_weight_qtype, "symmetric": self.weight_symmetric},
|
||||
overwrite=False,
|
||||
)
|
||||
|
||||
if not did_update:
|
||||
warn_unable_to_override(node, "quant_type/symmetric", input_name, "input weight")
|
||||
|
||||
def _process_sigmoid(self, node: onnx.NodeProto):
|
||||
"""
|
||||
Overrides 16-bit Sigmoid's output scale and zero-point as per QNN requirements.
|
||||
"""
|
||||
assert node.op_type == "Sigmoid", f"Expected Sigmoid, but got {node.op_type}"
|
||||
output_type = self.overrides.get_node_output_qtype_info(
|
||||
node.output[0], self.default_activation_qtype
|
||||
).quant_type
|
||||
|
||||
if output_type == QuantType.QUInt16:
|
||||
self.overrides.update_tensor_overrides(
|
||||
node.output[0],
|
||||
{
|
||||
"quant_type": output_type,
|
||||
"scale": np.array(1.0 / 65536.0, dtype=np.float32),
|
||||
"zero_point": np.array(0, dtype=np.uint16),
|
||||
},
|
||||
)
|
||||
elif output_type == QuantType.QInt16:
|
||||
self.overrides.update_tensor_overrides(
|
||||
node.output[0],
|
||||
{
|
||||
"quant_type": output_type,
|
||||
"scale": np.array(1.0 / 32768.0, dtype=np.float32),
|
||||
"zero_point": np.array(0, dtype=np.int16),
|
||||
},
|
||||
)
|
||||
|
||||
def _process_tanh(self, node: onnx.NodeProto):
    """
    Pins the output quantization parameters of a 16-bit Tanh, as per QNN requirements.

    The quant type of the node's first output is looked up in the overrides (falling back to
    the default activation type). For a 16-bit output, a fixed scale/zero-point override is recorded:
      - QUInt16: scale = 1/32768, zero_point = 32768 (uint16)
      - QInt16:  scale = 1/32768, zero_point = 0 (int16)
    Any other output type is left untouched.
    """
    assert node.op_type == "Tanh", f"Expected Tanh, but got {node.op_type}"

    out_name = node.output[0]
    out_qtype = self.overrides.get_node_output_qtype_info(out_name, self.default_activation_qtype).quant_type

    # Select the fixed quantization parameters for the resolved output type.
    if out_qtype == QuantType.QUInt16:
        fixed_scale = np.array(1.0 / 32768.0, dtype=np.float32)
        fixed_zero_point = np.array(32768, dtype=np.uint16)
    elif out_qtype == QuantType.QInt16:
        fixed_scale = np.array(1.0 / 32768.0, dtype=np.float32)
        fixed_zero_point = np.array(0, dtype=np.int16)
    else:
        # Not a 16-bit output: nothing to override.
        return

    self.overrides.update_tensor_overrides(
        out_name,
        {"quant_type": out_qtype, "scale": fixed_scale, "zero_point": fixed_zero_point},
    )
|
||||
|
|
|
|||
|
|
@ -7,11 +7,52 @@ from __future__ import annotations
|
|||
|
||||
import json
|
||||
from collections.abc import MutableMapping
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from .quant_utils import QuantType
|
||||
|
||||
|
||||
@dataclass
class QuantTypeInfo:
    """
    Quantization type settings attached to a tensor override.

    Equality is intentionally lenient: `symmetric` and `reduce_range` only take part in a
    comparison when BOTH sides specify them; a value of None on either side matches anything.
    """

    quant_type: QuantType
    symmetric: bool | None = None  # None: not specified, the default is assumed.
    reduce_range: bool | None = None  # None: not specified, the default is assumed.

    def __eq__(self, other: object):
        # Defer to the other operand for anything that is not a QuantTypeInfo.
        if not isinstance(other, QuantTypeInfo):
            return NotImplemented

        symmetric_matches = self.symmetric is None or other.symmetric is None or self.symmetric == other.symmetric
        reduce_range_matches = (
            self.reduce_range is None or other.reduce_range is None or self.reduce_range == other.reduce_range
        )
        return self.quant_type == other.quant_type and symmetric_matches and reduce_range_matches

    @staticmethod
    def load_from_dict(
        raw_dict: dict[str, Any],
        default_activation_qtype: QuantType | None = None,
        default_activation_symmetric: bool | None = None,
        default_activation_reduce_range: bool | None = None,
    ) -> QuantTypeInfo:
        """
        Builds a QuantTypeInfo from the "quant_type", "symmetric" and "reduce_range" keys of `raw_dict`,
        substituting the corresponding default argument for every key that is absent.
        """
        qtype = raw_dict.get("quant_type", default_activation_qtype)
        symmetric = raw_dict.get("symmetric", default_activation_symmetric)
        reduce_range = raw_dict.get("reduce_range", default_activation_reduce_range)
        return QuantTypeInfo(qtype, symmetric, reduce_range)

    def save_to_dict(self, raw_dict: dict[str, Any]):
        """
        Writes this info into `raw_dict` in place. "quant_type" is always written; "symmetric" and
        "reduce_range" are written only when they are not None.
        """
        raw_dict["quant_type"] = self.quant_type
        for key in ("symmetric", "reduce_range"):
            value = getattr(self, key)
            if value is not None:
                raw_dict[key] = value
|
||||
|
||||
|
||||
class TensorQuantOverridesHelper(MutableMapping):
|
||||
"""
|
||||
Utility wrapper over the tensor quantization overrides passed via extra_options.
|
||||
|
|
@ -184,9 +225,99 @@ class TensorQuantOverridesHelper(MutableMapping):
|
|||
|
||||
return True, None
|
||||
|
||||
def update_tensor_overrides(
    self,
    tensor_name: str,
    new_vals: dict[str, Any],
    channels: list[int] | None = None,
    overwrite: bool = True,
) -> bool:
    """
    Merges `new_vals` into the override dict(s) of `tensor_name`.

    :param tensor_name: Name of the tensor whose overrides are updated.
    :param new_vals: Key/value pairs to merge. An empty dict is a no-op that returns False.
    :param channels: Indices (positions in the tensor's override list) to update. None selects all.
    :param overwrite: If False, the update is applied only when none of the selected override dicts
                      already contains a key of `new_vals`. The update is all-or-nothing; partial
                      overwrites are never performed.
    :return: True if the update was applied, False otherwise.
    """
    if not new_vals:
        return False

    selected = None if channels is None else set(channels)
    existing = self.overrides.get(tensor_name)

    def is_selected(channel: int) -> bool:
        return selected is None or channel in selected

    # With `overwrite` disabled, bail out as soon as a selected channel already has one of the new keys.
    if not overwrite and existing:
        new_keys = set(new_vals)
        if any(
            is_selected(channel) and new_keys.intersection(set(entry)) for channel, entry in enumerate(existing)
        ):
            return False

    # A tensor without (non-empty) overrides starts out with a single empty override dict.
    if not existing:
        self.overrides[tensor_name] = [{}]

    for channel, entry in enumerate(self.overrides[tensor_name]):
        if is_selected(channel):
            entry.update(new_vals)

    return True
|
||||
|
||||
def get_node_output_qtype_info(
    self,
    output_name: str,
    default_qtype: QuantType | None,
    default_symmetric: bool | None = None,
) -> QuantTypeInfo:
    """
    Returns the quantization type info of a node output.

    :param output_name: Name of the output tensor.
    :param default_qtype: Quant type to report when the overrides do not specify "quant_type".
    :param default_symmetric: Symmetry to report when the overrides do not specify "symmetric".
    :return: QuantTypeInfo holding the resolved quant type and symmetry (reduce_range is left as None).
    """
    # A missing entry and an empty override list both mean "no overrides": use the defaults.
    # The empty-list check mirrors get_node_input_qtype_info and avoids an IndexError below.
    override_list = self.overrides.get(output_name)
    if not override_list:
        return QuantTypeInfo(default_qtype, default_symmetric)

    # Get the first overrides dict in the list. This works for both per-tensor and per-channel
    # quantization because all channels must use the same quant type.
    tensor_overrides = override_list[0]

    return QuantTypeInfo(
        tensor_overrides.get("quant_type", default_qtype),
        tensor_overrides.get("symmetric", default_symmetric),
    )
|
||||
|
||||
def get_node_input_qtype_info(
    self,
    input_name: str,
    node_name: str,
    default_qtype: QuantType | None,
    default_symmetric: bool | None = None,
    default_reduce_range: bool | None = None,
) -> QuantTypeInfo:
    """
    Returns the quantization type info that node `node_name` sees for its input `input_name`.

    Resolution rules, as implemented:
      - No (or empty) overrides for the tensor: all three defaults are returned.
      - Overrides without a "convert" entry: the tensor's "quant_type" (or the default) is returned
        together with the default symmetric/reduce_range values. Tensor-level "symmetric" and
        "reduce_range" entries are not consulted.
      - Overrides with a "convert" entry: symmetric/reduce_range come from the convert dict (or the
        defaults). The quant type is the converted type if the convert dict has no "recv_nodes" or
        `node_name` is listed in it; otherwise it is the producer's type.
    """
    has_overrides = input_name in self.overrides and bool(self.overrides[input_name])
    if not has_overrides:
        return QuantTypeInfo(default_qtype, default_symmetric, default_reduce_range)

    # The first overrides dict is representative for both per-tensor and per-channel
    # quantization because all channels must use the same quant type.
    first_overrides = self.overrides[input_name][0]
    producer_type = first_overrides.get("quant_type", default_qtype)

    if "convert" not in first_overrides:
        return QuantTypeInfo(producer_type, default_symmetric, default_reduce_range)

    # The tensor is converted: decide whether this node consumes the original or the converted type.
    convert_dict = first_overrides["convert"]
    node_gets_converted = ("recv_nodes" not in convert_dict) or (node_name in convert_dict["recv_nodes"])
    seen_qtype = convert_dict["quant_type"] if node_gets_converted else producer_type

    return QuantTypeInfo(
        seen_qtype,
        convert_dict.get("symmetric", default_symmetric),
        convert_dict.get("reduce_range", default_reduce_range),
    )
|
||||
|
||||
def pprint_str(self, indent=None) -> str:
    """
    Returns the overrides serialized as JSON text.

    `indent` is forwarded to json.dumps. Values that JSON cannot encode natively are
    rendered through str().
    """
    serialized = json.dumps(self.overrides, indent=indent, default=str)
    return serialized
|
||||
|
||||
def empty(self) -> bool:
    """
    Returns True when no tensor overrides are stored, False otherwise.
    """
    if self.overrides:
        return False
    return True
|
||||
|
||||
def get_dict(self) -> dict[str, list[dict[str, Any]]]:
    """
    Returns the underlying overrides dictionary. This is the stored object itself, not a copy.
    """
    raw_overrides = self.overrides
    return raw_overrides
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,171 @@
|
|||
#!/usr/bin/env python
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for
|
||||
# license information.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
import unittest
|
||||
|
||||
import onnx
|
||||
|
||||
from onnxruntime.quantization import QuantType
|
||||
from onnxruntime.quantization.execution_providers.qnn.mixed_precision_overrides_utils import (
|
||||
MixedPrecisionTensorQuantOverridesFixer,
|
||||
)
|
||||
from onnxruntime.quantization.tensor_quant_overrides import TensorQuantOverridesHelper
|
||||
|
||||
|
||||
class TestMixedPrecisionQuantOverridesFixer(unittest.TestCase):
    """
    Tests that MixedPrecisionTensorQuantOverridesFixer expands an initial set of tensor quantization
    overrides into overrides with the expected "convert" entries at precision boundaries.
    """

    def build_test_model_1(self, shape):
        """
        Builds the float test model (all tensors are float with the given shape):

            input_0 --> op1 --> op3 --> op5 --> op6 --> output_0
                                         ^
                                         |
            input_1 --> op2 -+-> op4 ----+
                             |
                             +-> op7 --> output_1
                             |
                             +-> op8 --> output_2
        """

        def float_value_info(name):
            return onnx.helper.make_tensor_value_info(name, onnx.TensorProto.FLOAT, shape)

        # (op_type, inputs, output, node name)
        node_specs = [
            ("Sigmoid", ["input_0"], "op1_out", "op1"),
            ("Cos", ["input_1"], "op2_out", "op2"),
            ("Sin", ["op1_out"], "op3_out", "op3"),
            ("Tanh", ["op2_out"], "op4_out", "op4"),
            ("Mul", ["op3_out", "op4_out"], "op5_out", "op5"),
            ("Relu", ["op5_out"], "output_0", "op6"),
            ("Cos", ["op2_out"], "output_1", "op7"),
            ("Sigmoid", ["op2_out"], "output_2", "op8"),
        ]
        nodes = [
            onnx.helper.make_node(op_type, inputs, [output], name=name)
            for op_type, inputs, output, name in node_specs
        ]

        graph = onnx.helper.make_graph(
            nodes,
            "mixed_prec_test",
            [float_value_info(name) for name in ("input_0", "input_1")],
            [float_value_info(name) for name in ("output_0", "output_1", "output_2")],
        )
        opset_imports = [
            onnx.helper.make_opsetid("", 18),
        ]
        model = onnx.helper.make_model(graph, opset_imports=opset_imports)
        return onnx.shape_inference.infer_shapes(model)

    def _fix_overrides(self, raw_overrides, default_act_qtype):
        """
        Runs the fixer on `raw_overrides` for the test model and returns the resulting overrides dict.

        The fixer is created from the in-memory model, so the model is not written to disk.
        """
        model = self.build_test_model_1((1, 2, 3))
        overrides = TensorQuantOverridesHelper(raw_overrides)
        fixer = MixedPrecisionTensorQuantOverridesFixer.create_from_model(overrides, model, default_act_qtype)
        fixer.apply(default_act_qtype, default_activation_symmetric=False)
        return overrides.get_dict()

    def test_fixer_1(self):
        """Upgrading op4_out to uint16 adds converts on op4/op5's inputs and on op5's output."""
        fixed = self._fix_overrides({"op4_out": [{"quant_type": QuantType.QUInt16}]}, QuantType.QUInt8)

        expected = {
            "op2_out": [
                {"quant_type": QuantType.QUInt8, "convert": {"quant_type": QuantType.QUInt16, "recv_nodes": {"op4"}}}
            ],
            "op3_out": [
                {"quant_type": QuantType.QUInt8, "convert": {"quant_type": QuantType.QUInt16, "recv_nodes": {"op5"}}}
            ],
            "op4_out": [{"quant_type": QuantType.QUInt16}],
            "op5_out": [
                {"quant_type": QuantType.QUInt16, "convert": {"quant_type": QuantType.QUInt8, "recv_nodes": {"op6"}}}
            ],
        }
        self.assertDictEqual(fixed, expected)

    def test_fixer_with_symmetric(self):
        """The "symmetric" setting of the initial override is carried into the generated overrides."""
        fixed = self._fix_overrides(
            {"op4_out": [{"quant_type": QuantType.QInt16, "symmetric": True}]}, QuantType.QInt8
        )

        expected = {
            "op2_out": [
                {
                    "quant_type": QuantType.QInt8,
                    "convert": {"quant_type": QuantType.QInt16, "symmetric": True, "recv_nodes": {"op4"}},
                }
            ],
            "op3_out": [
                {
                    "quant_type": QuantType.QInt8,
                    "convert": {"quant_type": QuantType.QInt16, "symmetric": True, "recv_nodes": {"op5"}},
                }
            ],
            "op4_out": [{"quant_type": QuantType.QInt16, "symmetric": True}],
            "op5_out": [
                {
                    "quant_type": QuantType.QInt16,
                    "symmetric": True,
                    "convert": {"quant_type": QuantType.QInt8, "recv_nodes": {"op6"}},
                }
            ],
        }
        self.assertDictEqual(fixed, expected)

    def test_fixer_upgrade_output(self):
        """With output_0 also upgraded to uint16, op5_out stays uint16 without a convert."""
        raw_overrides = {
            "op4_out": [{"quant_type": QuantType.QUInt16}],
            "output_0": [{"quant_type": QuantType.QUInt16}],
        }
        fixed = self._fix_overrides(raw_overrides, QuantType.QUInt8)

        expected = {
            "op2_out": [
                {"quant_type": QuantType.QUInt8, "convert": {"quant_type": QuantType.QUInt16, "recv_nodes": {"op4"}}}
            ],
            "op3_out": [
                {"quant_type": QuantType.QUInt8, "convert": {"quant_type": QuantType.QUInt16, "recv_nodes": {"op5"}}}
            ],
            "op4_out": [{"quant_type": QuantType.QUInt16}],
            "op5_out": [{"quant_type": QuantType.QUInt16}],
            "output_0": [{"quant_type": QuantType.QUInt16}],
        }
        self.assertDictEqual(fixed, expected)

    def test_fixer_upgrade_input(self):
        """With input_0 also upgraded to uint16, op1_out is uint16 and converted back to uint8 for op3."""
        raw_overrides = {
            "op4_out": [{"quant_type": QuantType.QUInt16}],
            "input_0": [{"quant_type": QuantType.QUInt16}],
        }
        fixed = self._fix_overrides(raw_overrides, QuantType.QUInt8)

        expected = {
            "input_0": [{"quant_type": QuantType.QUInt16}],
            "op1_out": [
                {"quant_type": QuantType.QUInt16, "convert": {"quant_type": QuantType.QUInt8, "recv_nodes": {"op3"}}}
            ],
            "op2_out": [
                {"quant_type": QuantType.QUInt8, "convert": {"quant_type": QuantType.QUInt16, "recv_nodes": {"op4"}}}
            ],
            "op3_out": [
                {"quant_type": QuantType.QUInt8, "convert": {"quant_type": QuantType.QUInt16, "recv_nodes": {"op5"}}}
            ],
            "op4_out": [{"quant_type": QuantType.QUInt16}],
            "op5_out": [
                {"quant_type": QuantType.QUInt16, "convert": {"quant_type": QuantType.QUInt8, "recv_nodes": {"op6"}}}
            ],
        }
        self.assertDictEqual(fixed, expected)
|
||||
|
|
@ -11,12 +11,12 @@ import unittest
|
|||
import numpy as np
|
||||
import onnx
|
||||
|
||||
from onnxruntime import quantization
|
||||
from onnxruntime.quantization import CalibrationDataReader, QuantFormat, QuantType, quantize_static
|
||||
from onnxruntime.quantization.execution_providers.qnn import get_qnn_qdq_config
|
||||
from onnxruntime.quantization.quant_utils import compute_scale_zp, get_qmin_qmax_for_qType, ms_domain
|
||||
|
||||
|
||||
class DummyDataReader(quantization.CalibrationDataReader):
|
||||
class DummyDataReader(CalibrationDataReader):
|
||||
def __init__(self, activations):
|
||||
self.iterator = ({"INP": act} for act in activations)
|
||||
|
||||
|
|
@ -81,11 +81,11 @@ class TestTensorQuantOverridesOption(unittest.TestCase):
|
|||
if activation_type is None:
|
||||
activation_type = self.default_act_qtype
|
||||
|
||||
quantization.quantize_static(
|
||||
quantize_static(
|
||||
model_input="model.onnx",
|
||||
model_output=output_model_name,
|
||||
calibration_data_reader=DummyDataReader(self.activations),
|
||||
quant_format=quantization.QuantFormat.QDQ,
|
||||
quant_format=QuantFormat.QDQ,
|
||||
activation_type=activation_type,
|
||||
weight_type=self.default_wgt_qtype,
|
||||
per_channel=per_channel,
|
||||
|
|
@ -223,8 +223,8 @@ class TestTensorQuantOverridesOption(unittest.TestCase):
|
|||
"SIG_OUT": [
|
||||
{"scale": np.array(1.0, dtype=np.float32), "zero_point": np.array(127, dtype=np.uint8)}
|
||||
],
|
||||
"WGT": [{"quant_type": quantization.QuantType.QInt8, "symmetric": True, "reduce_range": True}],
|
||||
"BIAS": [{"quant_type": quantization.QuantType.QInt8, "symmetric": True, "reduce_range": True}],
|
||||
"WGT": [{"quant_type": QuantType.QInt8, "symmetric": True, "reduce_range": True}],
|
||||
"BIAS": [{"quant_type": QuantType.QInt8, "symmetric": True, "reduce_range": True}],
|
||||
}
|
||||
},
|
||||
)
|
||||
|
|
@ -240,7 +240,7 @@ class TestTensorQuantOverridesOption(unittest.TestCase):
|
|||
self.assertEqual(sig_out_sc.float_data[0], np.float32(1.0))
|
||||
|
||||
# Weight should have different type, zero_point, and scale
|
||||
self.assertEqual(wgt_zp.data_type, quantization.QuantType.QInt8.tensor_type)
|
||||
self.assertEqual(wgt_zp.data_type, QuantType.QInt8.tensor_type)
|
||||
|
||||
wgt_qmin, wgt_qmax = get_qmin_qmax_for_qType(wgt_zp.data_type, reduce_range=True, symmetric=True)
|
||||
wgt_rmin, wgt_rmax = np.min(self.weight), np.max(self.weight)
|
||||
|
|
@ -249,7 +249,7 @@ class TestTensorQuantOverridesOption(unittest.TestCase):
|
|||
self.assertEqual(wgt_sc.float_data[0], np.float32(new_wgt_sc))
|
||||
|
||||
# Bias should now be treated as a weight and should have different type, zero_point, and scale
|
||||
self.assertEqual(bias_zp.data_type, quantization.QuantType.QInt8.tensor_type)
|
||||
self.assertEqual(bias_zp.data_type, QuantType.QInt8.tensor_type)
|
||||
|
||||
bias_qmin, bias_qmax = get_qmin_qmax_for_qType(bias_zp.data_type, reduce_range=True, symmetric=True)
|
||||
bias_rmin, bias_rmax = np.min(self.bias), np.max(self.bias)
|
||||
|
|
@ -375,7 +375,7 @@ class TestTensorQuantOverridesOption(unittest.TestCase):
|
|||
"""
|
||||
rmin_vals = [0.0, 0.2]
|
||||
rmax_vals = [1.0, 0.8]
|
||||
quant_type = quantization.QuantType.QUInt8
|
||||
quant_type = QuantType.QUInt8
|
||||
reduce_ranges = [True, False]
|
||||
(
|
||||
_,
|
||||
|
|
@ -434,8 +434,8 @@ class TestTensorQuantOverridesOption(unittest.TestCase):
|
|||
activation_type=onnx.TensorProto.UINT8, # Default to 8bit activations
|
||||
extra_options={
|
||||
"TensorQuantOverrides": {
|
||||
"INP": [{"quant_type": quantization.QuantType.QUInt16}],
|
||||
"SIG_OUT": [{"quant_type": quantization.QuantType.QUInt16}],
|
||||
"INP": [{"quant_type": QuantType.QUInt16}],
|
||||
"SIG_OUT": [{"quant_type": QuantType.QUInt16}],
|
||||
}
|
||||
},
|
||||
)
|
||||
|
|
@ -559,31 +559,446 @@ class TestTensorQuantOverridesOption(unittest.TestCase):
|
|||
|
||||
self.assertIn("option 'reduce_range' is invalid with 'scale' and 'zero_point'", str(context.exception))
|
||||
|
||||
def test_get_qnn_qdq_config(self):
|
||||
def test_get_qnn_qdq_config_sigmoid(self):
    """
    Test that the QNN-specific configs override the scale and zero-point of 16-bit Sigmoid,
    and that an 8-bit Sigmoid output does not receive a scale/zero-point override.
    """

    # Create float model with a Abs --> Sigmoid
    graph = onnx.helper.make_graph(
        [
            onnx.helper.make_node("Abs", ["input_0"], ["abs_out"], name="Abs_0"),
            onnx.helper.make_node("Sigmoid", ["abs_out"], ["output_0"], name="Sigmoid_0"),
        ],
        "sigmoid_graph",
        [onnx.helper.make_tensor_value_info("input_0", onnx.TensorProto.FLOAT, (1, 2, 3))],
        [onnx.helper.make_tensor_value_info("output_0", onnx.TensorProto.FLOAT, (1, 2, 3))],
    )
    opset_imports = [
        onnx.helper.make_opsetid("", 18),
    ]
    model = onnx.helper.make_model(graph, opset_imports=opset_imports)
    model = onnx.shape_inference.infer_shapes(model)
    float_model_path = "model.onnx"
    onnx.save_model(model, float_model_path)

    other_override_0 = {"abs_out": [{"symmetric": True}]}
    other_override_1 = {
        "abs_out": [
            {
                "quant_type": QuantType.QUInt8,
                "convert": {"quant_type": QuantType.QUInt16, "recv_nodes": {"Sigmoid_0"}},
            }
        ]
    }
    other_override_2 = {
        "abs_out": [
            {
                "quant_type": QuantType.QInt8,
                "convert": {"quant_type": QuantType.QInt16, "recv_nodes": {"Sigmoid_0"}},
            }
        ]
    }

    # Enumerate subtests (default_act_qtype, sigmoid_out_qtype, other_override)
    subtest_configs = [
        (QuantType.QUInt16, None, {}),  # Sigmoid gets new scale/zp
        (QuantType.QUInt16, None, other_override_0),  # Sigmoid gets new scale/zp
        (QuantType.QInt16, None, {}),  # Sigmoid gets new scale/zp
        (QuantType.QInt16, None, other_override_0),  # Sigmoid gets new scale/zp
        (QuantType.QUInt8, QuantType.QUInt16, other_override_1),  # Sigmoid gets new scale/zp
        (QuantType.QInt8, QuantType.QInt16, other_override_2),  # Sigmoid gets new scale/zp
        (QuantType.QUInt8, None, other_override_0),  # Sigmoid DOES NOT get new scale/zp
        (QuantType.QInt8, None, {}),  # Sigmoid DOES NOT get new scale/zp
        (QuantType.QInt8, QuantType.QInt8, {}),  # Sigmoid DOES NOT get new scale/zp
    ]

    # Test that Sigmoid's output scale and zp should be overridden for 16-bit Sigmoid.
    for default_act_qtype, sigmoid_out_qtype, abs_override in subtest_configs:
        with self.subTest(
            default_act_qtype=default_act_qtype, sigmoid_out_qtype=sigmoid_out_qtype, abs_override=abs_override
        ):
            init_overrides = {}
            init_overrides.update(abs_override)

            if sigmoid_out_qtype is not None:
                init_overrides["output_0"] = [{"quant_type": sigmoid_out_qtype}]

            qnn_config = get_qnn_qdq_config(
                float_model_path,
                DummyDataReader([]),
                activation_type=default_act_qtype,
                init_overrides=(init_overrides if init_overrides else None),
                add_qtype_converts=False,
            )

            self.assertEqual(set(qnn_config.op_types_to_quantize), {"Abs", "Sigmoid"})

            if default_act_qtype == QuantType.QUInt16 or sigmoid_out_qtype == QuantType.QUInt16:
                self.assertIn("TensorQuantOverrides", qnn_config.extra_options)
                self.assertIn("output_0", qnn_config.extra_options["TensorQuantOverrides"])
                self.assertEqual(
                    qnn_config.extra_options["TensorQuantOverrides"]["output_0"],
                    [
                        {
                            "quant_type": QuantType.QUInt16,
                            "scale": np.array(1.0 / 65536.0, dtype=np.float32),
                            "zero_point": np.array(0, dtype=np.uint16),
                        }
                    ],
                )
            elif default_act_qtype == QuantType.QInt16 or sigmoid_out_qtype == QuantType.QInt16:
                self.assertIn("TensorQuantOverrides", qnn_config.extra_options)
                self.assertIn("output_0", qnn_config.extra_options["TensorQuantOverrides"])
                self.assertEqual(
                    qnn_config.extra_options["TensorQuantOverrides"]["output_0"],
                    [
                        {
                            "quant_type": QuantType.QInt16,
                            "scale": np.array(1.0 / 32768.0, dtype=np.float32),
                            "zero_point": np.array(0, dtype=np.int16),
                        }
                    ],
                )
            else:
                # 8-bit Sigmoid: the output must not have been given a fixed scale/zero-point.
                # The overrides option (or the output_0 entry) may be absent altogether.
                tensor_overrides = qnn_config.extra_options.get("TensorQuantOverrides", {})
                for output_override in tensor_overrides.get("output_0", []):
                    self.assertNotIn("scale", output_override)
                    self.assertNotIn("zero_point", output_override)
|
||||
|
||||
def test_get_qnn_qdq_config_tanh(self):
    """
    Test that the QNN-specific configs override the scale and zero-point of 16-bit Tanh,
    and that an 8-bit Tanh output does not receive a scale/zero-point override.
    """

    # Create float model with a Abs --> Tanh
    graph = onnx.helper.make_graph(
        [
            onnx.helper.make_node("Abs", ["input_0"], ["abs_out"], name="Abs_0"),
            onnx.helper.make_node("Tanh", ["abs_out"], ["output_0"], name="Tanh_0"),
        ],
        "tanh_graph",
        [onnx.helper.make_tensor_value_info("input_0", onnx.TensorProto.FLOAT, (1, 2, 3))],
        [onnx.helper.make_tensor_value_info("output_0", onnx.TensorProto.FLOAT, (1, 2, 3))],
    )
    opset_imports = [
        onnx.helper.make_opsetid("", 18),
    ]
    model = onnx.helper.make_model(graph, opset_imports=opset_imports)
    model = onnx.shape_inference.infer_shapes(model)
    float_model_path = "model.onnx"
    onnx.save_model(model, float_model_path)

    other_override_0 = {"abs_out": [{"symmetric": True}]}
    other_override_1 = {
        "abs_out": [
            {"quant_type": QuantType.QUInt8, "convert": {"quant_type": QuantType.QUInt16, "recv_nodes": {"Tanh_0"}}}
        ]
    }
    other_override_2 = {
        "abs_out": [
            {"quant_type": QuantType.QInt8, "convert": {"quant_type": QuantType.QInt16, "recv_nodes": {"Tanh_0"}}}
        ]
    }

    # Enumerate subtests (default_act_qtype, tanh_out_qtype, other_override)
    subtest_configs = [
        (QuantType.QUInt16, None, {}),  # Tanh gets new scale/zp
        (QuantType.QUInt16, None, other_override_0),  # Tanh gets new scale/zp
        (QuantType.QInt16, None, {}),  # Tanh gets new scale/zp
        (QuantType.QInt16, None, other_override_0),  # Tanh gets new scale/zp
        (QuantType.QUInt8, QuantType.QUInt16, other_override_1),  # Tanh gets new scale/zp
        (QuantType.QInt8, QuantType.QInt16, other_override_2),  # Tanh gets new scale/zp
        (QuantType.QUInt8, None, other_override_0),  # Tanh DOES NOT get new scale/zp
        (QuantType.QInt8, None, {}),  # Tanh DOES NOT get new scale/zp
        (QuantType.QInt8, QuantType.QInt8, {}),  # Tanh DOES NOT get new scale/zp
    ]

    # Test that Tanh's output scale and zp should be overridden for 16-bit Tanh.
    for default_act_qtype, tanh_out_qtype, abs_override in subtest_configs:
        with self.subTest(
            default_act_qtype=default_act_qtype, tanh_out_qtype=tanh_out_qtype, abs_override=abs_override
        ):
            init_overrides = {}
            init_overrides.update(abs_override)

            if tanh_out_qtype is not None:
                init_overrides["output_0"] = [{"quant_type": tanh_out_qtype}]

            qnn_config = get_qnn_qdq_config(
                float_model_path,
                DummyDataReader([]),
                activation_type=default_act_qtype,
                init_overrides=(init_overrides if init_overrides else None),
                add_qtype_converts=False,
            )

            self.assertEqual(set(qnn_config.op_types_to_quantize), {"Abs", "Tanh"})

            if default_act_qtype == QuantType.QUInt16 or tanh_out_qtype == QuantType.QUInt16:
                self.assertIn("TensorQuantOverrides", qnn_config.extra_options)
                self.assertIn("output_0", qnn_config.extra_options["TensorQuantOverrides"])
                self.assertEqual(
                    qnn_config.extra_options["TensorQuantOverrides"]["output_0"],
                    [
                        {
                            "quant_type": QuantType.QUInt16,
                            "scale": np.array(1.0 / 32768.0, dtype=np.float32),
                            "zero_point": np.array(32768, dtype=np.uint16),
                        }
                    ],
                )
            elif default_act_qtype == QuantType.QInt16 or tanh_out_qtype == QuantType.QInt16:
                self.assertIn("TensorQuantOverrides", qnn_config.extra_options)
                self.assertIn("output_0", qnn_config.extra_options["TensorQuantOverrides"])
                self.assertEqual(
                    qnn_config.extra_options["TensorQuantOverrides"]["output_0"],
                    [
                        {
                            "quant_type": QuantType.QInt16,
                            "scale": np.array(1.0 / 32768.0, dtype=np.float32),
                            "zero_point": np.array(0, dtype=np.int16),
                        }
                    ],
                )
            else:
                # 8-bit Tanh: the output must not have been given a fixed scale/zero-point.
                # The overrides option (or the output_0 entry) may be absent altogether.
                tensor_overrides = qnn_config.extra_options.get("TensorQuantOverrides", {})
                for output_override in tensor_overrides.get("output_0", []):
                    self.assertNotIn("scale", output_override)
                    self.assertNotIn("zero_point", output_override)
|
||||
|
||||
def test_get_qnn_qdq_config_matmul(self):
    """
    Test that the QNN-specific configs override MatMul's initializer input type to 8-bit if
    the other input is 16-bit and the default weight type is 8-bit. In every other case the
    weight must not be overridden.
    """
    # Create float model with a Abs --> MatMul
    graph = onnx.helper.make_graph(
        [
            onnx.helper.make_node("Abs", ["input_0"], ["abs_0_out"], name="Abs_0"),
            onnx.helper.make_node("MatMul", ["abs_0_out", "weight"], ["matmul_0_out"], name="MatMul_0"),
            onnx.helper.make_node("Abs", ["matmul_0_out"], ["output_0"], name="Abs_1"),
        ],
        "matmul_graph",
        [onnx.helper.make_tensor_value_info("input_0", onnx.TensorProto.FLOAT, (2, 3))],
        [onnx.helper.make_tensor_value_info("output_0", onnx.TensorProto.FLOAT, (2, 2))],
        initializer=[onnx.numpy_helper.from_array(np.random.random((3, 2)).astype(np.float32), "weight")],
    )
    opset_imports = [
        onnx.helper.make_opsetid("", 18),
    ]
    model = onnx.helper.make_model(graph, opset_imports=opset_imports)
    model = onnx.shape_inference.infer_shapes(model)
    float_model_path = "model.onnx"
    onnx.save_model(model, float_model_path)

    q16_qtypes = {QuantType.QUInt16, QuantType.QInt16}
    q8_qtypes = {QuantType.QUInt8, QuantType.QInt8}
    symmetric_wgt_qtypes = {QuantType.QInt8, QuantType.QInt16}

    other_override_0 = {"output_0": [{"symmetric": True}]}
    other_override_1 = {
        "matmul_0_out": [
            {
                "quant_type": QuantType.QUInt16,
                "convert": {"quant_type": QuantType.QUInt8, "recv_nodes": {"Abs_1"}},
            }
        ]
    }
    other_override_2 = {
        "matmul_0_out": [
            {
                "quant_type": QuantType.QInt16,
                "convert": {"quant_type": QuantType.QInt8, "recv_nodes": {"Abs_1"}},
            }
        ]
    }
    convert_matmul_input = {
        "abs_0_out": [
            {
                "quant_type": QuantType.QUInt8,
                "convert": {"quant_type": QuantType.QUInt16, "recv_nodes": {"MatMul_0"}},
            }
        ]
    }

    # Enumerate subtests (default_act_qtype, default_wgt_qtype, matmul_in_qtype, other_override)
    subtest_configs = [
        (QuantType.QUInt8, QuantType.QUInt8, None, {}),
        (QuantType.QUInt8, QuantType.QUInt8, QuantType.QUInt16, {}),
        (QuantType.QUInt8, QuantType.QUInt8, QuantType.QUInt16, other_override_0),
        (QuantType.QUInt8, QuantType.QUInt8, QuantType.QUInt16, other_override_1),
        (QuantType.QInt8, QuantType.QInt8, QuantType.QInt16, other_override_2),
        (QuantType.QUInt16, QuantType.QUInt8, None, other_override_0),
        (QuantType.QInt16, QuantType.QInt8, None, {}),
        (QuantType.QUInt16, QuantType.QUInt16, None, other_override_0),
        (QuantType.QInt16, QuantType.QInt16, None, {}),
        (QuantType.QUInt8, QuantType.QUInt8, None, convert_matmul_input),
    ]

    # Test if MatMul's weight input is overridden.
    for default_act_qtype, default_wgt_qtype, matmul_input_qtype, other_override in subtest_configs:
        with self.subTest(
            default_act_qtype=default_act_qtype,
            default_wgt_qtype=default_wgt_qtype,
            matmul_input_qtype=matmul_input_qtype,
            other_override=other_override,
        ):
            init_overrides = {}
            init_overrides.update(other_override)

            if matmul_input_qtype is not None:
                init_overrides["abs_0_out"] = [{"quant_type": matmul_input_qtype}]

            qnn_config = get_qnn_qdq_config(
                float_model_path,
                DummyDataReader([]),
                activation_type=default_act_qtype,
                weight_type=default_wgt_qtype,
                init_overrides=(init_overrides if init_overrides else None),
                add_qtype_converts=False,
            )

            self.assertEqual(set(qnn_config.op_types_to_quantize), {"Abs", "MatMul"})
            input_is_16bit = (
                (default_act_qtype in q16_qtypes)
                or (matmul_input_qtype in q16_qtypes)
                or (other_override == convert_matmul_input)
            )
            weight_is_symmetric = default_wgt_qtype in symmetric_wgt_qtypes

            if input_is_16bit and default_wgt_qtype in q8_qtypes:
                self.assertIn("TensorQuantOverrides", qnn_config.extra_options)
                self.assertIn("weight", qnn_config.extra_options["TensorQuantOverrides"])
                self.assertEqual(
                    qnn_config.extra_options["TensorQuantOverrides"]["weight"],
                    [
                        {
                            "quant_type": default_wgt_qtype,
                            "symmetric": weight_is_symmetric,
                        }
                    ],
                )
            else:
                # The user-provided overrides must be passed through...
                if init_overrides:
                    self.assertIn("TensorQuantOverrides", qnn_config.extra_options)
                # ...and the weight must not be overridden, whether or not any overrides exist.
                self.assertNotIn("weight", qnn_config.extra_options.get("TensorQuantOverrides", {}))

            self.assertEqual(weight_is_symmetric, qnn_config.extra_options["WeightSymmetric"])
|
||||
|
||||
def test_get_qnn_qdq_config_layernorm(self):
    """
    Verify how get_qnn_qdq_config() treats the "weight" initializer of a LayerNormalization node.

    A "weight" override carrying the default weight type (and its symmetric flag) is expected
    whenever LayerNorm's activation input is 16-bit while the default weight type is 8-bit.
    In every other subtest that supplies initial overrides, no "weight" entry may appear.
    A "bias" override is never expected.
    """
    helper = onnx.helper
    tensor_shape = (2, 3)

    # Float model: Abs_0 --> LayerNorm_0 --> Abs_1
    nodes = [
        helper.make_node("Abs", ["input_0"], ["abs_0_out"], name="Abs_0"),
        helper.make_node(
            "LayerNormalization", ["abs_0_out", "weight", "bias"], ["layernorm_0_out"], name="LayerNorm_0"
        ),
        helper.make_node("Abs", ["layernorm_0_out"], ["output_0"], name="Abs_1"),
    ]
    graph_inputs = [helper.make_tensor_value_info("input_0", onnx.TensorProto.FLOAT, tensor_shape)]
    graph_outputs = [helper.make_tensor_value_info("output_0", onnx.TensorProto.FLOAT, tensor_shape)]
    initializers = [
        onnx.numpy_helper.from_array(np.random.random(tensor_shape).astype(np.float32), init_name)
        for init_name in ("weight", "bias")
    ]
    graph = helper.make_graph(nodes, "layernorm_graph", graph_inputs, graph_outputs, initializer=initializers)

    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 18)])
    model = onnx.shape_inference.infer_shapes(model)
    float_model_path = "model.onnx"
    onnx.save_model(model, float_model_path)

    q16_qtypes = {QuantType.QUInt16, QuantType.QInt16}
    q8_qtypes = {QuantType.QUInt8, QuantType.QInt8}
    symmetric_wgt_qtypes = {QuantType.QInt8, QuantType.QInt16}

    def converting_override(tensor_name, produced_qtype, converted_qtype, consumer):
        # Override in which `tensor_name` is produced as one type and converted for one consumer node.
        return {
            tensor_name: [
                {
                    "quant_type": produced_qtype,
                    "convert": {"quant_type": converted_qtype, "recv_nodes": {consumer}},
                }
            ]
        }

    other_override_0 = {"output_0": [{"symmetric": True}]}
    other_override_1 = converting_override("layernorm_0_out", QuantType.QUInt16, QuantType.QUInt8, "Abs_1")
    other_override_2 = converting_override("layernorm_0_out", QuantType.QInt16, QuantType.QInt8, "Abs_1")
    convert_layernorm_input = converting_override(
        "abs_0_out", QuantType.QUInt8, QuantType.QUInt16, "LayerNorm_0"
    )

    # Each row: (default_act_qtype, default_wgt_qtype, layernorm_in_qtype, other_override)
    subtest_configs = [
        (QuantType.QUInt8, QuantType.QUInt8, None, {}),
        (QuantType.QUInt8, QuantType.QUInt8, QuantType.QUInt16, {}),
        (QuantType.QUInt8, QuantType.QUInt8, QuantType.QUInt16, other_override_0),
        (QuantType.QUInt8, QuantType.QUInt8, QuantType.QUInt16, other_override_1),
        (QuantType.QInt8, QuantType.QInt8, QuantType.QInt16, other_override_2),
        (QuantType.QUInt16, QuantType.QUInt8, None, other_override_0),
        (QuantType.QInt16, QuantType.QInt8, None, {}),
        (QuantType.QUInt16, QuantType.QUInt16, None, other_override_0),
        (QuantType.QInt16, QuantType.QInt16, None, {}),
        (QuantType.QUInt8, QuantType.QUInt8, None, {}),
        (QuantType.QUInt8, QuantType.QUInt8, None, convert_layernorm_input),
    ]

    # Check whether LayerNorm's weight input receives an override in each configuration.
    for act_qtype, wgt_qtype, ln_input_qtype, extra_override in subtest_configs:
        with self.subTest(
            default_act_qtype=act_qtype,
            default_wgt_qtype=wgt_qtype,
            layernorm_input_qtype=ln_input_qtype,
            other_override=extra_override,
        ):
            # Shallow copy so the shared override dict is not given extra keys.
            init_overrides = dict(extra_override)
            if ln_input_qtype is not None:
                init_overrides["abs_0_out"] = [{"quant_type": ln_input_qtype}]

            qnn_config = get_qnn_qdq_config(
                float_model_path,
                DummyDataReader([]),
                activation_type=act_qtype,
                weight_type=wgt_qtype,
                init_overrides=(init_overrides or None),
                add_qtype_converts=False,
            )
            extra_options = qnn_config.extra_options

            self.assertEqual(set(qnn_config.op_types_to_quantize), {"Abs", "LayerNormalization"})

            # LayerNorm's activation input is 16-bit if the default says so, if the input tensor
            # is overridden directly, or if it is converted to 16-bit for LayerNorm_0.
            input_is_16bit = any(
                (
                    act_qtype in q16_qtypes,
                    ln_input_qtype in q16_qtypes,
                    extra_override == convert_layernorm_input,
                )
            )
            weight_is_symmetric = wgt_qtype in symmetric_wgt_qtypes

            if input_is_16bit and wgt_qtype in q8_qtypes:
                self.assertIn("TensorQuantOverrides", extra_options)
                self.assertIn("weight", extra_options["TensorQuantOverrides"])
                expected_weight_override = [{"quant_type": wgt_qtype, "symmetric": weight_is_symmetric}]
                self.assertEqual(extra_options["TensorQuantOverrides"]["weight"], expected_weight_override)
            elif init_overrides:
                self.assertIn("TensorQuantOverrides", extra_options)
                self.assertNotIn("weight", extra_options["TensorQuantOverrides"])

            self.assertEqual(weight_is_symmetric, extra_options["WeightSymmetric"])
            self.assertNotIn("bias", extra_options["TensorQuantOverrides"])
|
||||
|
||||
def test_get_qnn_qdq_config_ext_data(self):
|
||||
"""
|
||||
|
|
@ -613,6 +1028,7 @@ class TestTensorQuantOverridesOption(unittest.TestCase):
|
|||
)
|
||||
|
||||
qnn_config = get_qnn_qdq_config("add_ext_data.onnx", DummyDataReader(self.activations))
|
||||
self.assertEqual(set(qnn_config.op_types_to_quantize), {"Add"})
|
||||
self.assertTrue(qnn_config.use_external_data_format)
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue