mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-01 23:30:35 +00:00
weight matching (#12607)
QDQ loss debug - Weights Matching. Part 2 of the QDQ loss debugging tool: given a float model and its QDQ model, return the matching of all weight tensors in the float model with their corresponding dequantized weights from the QDQ model.
This commit is contained in:
parent
8a038b9b0c
commit
f2db6bb293
16 changed files with 258 additions and 63 deletions
|
|
@ -12,6 +12,7 @@ from onnx import onnx_pb as onnx_proto
|
|||
|
||||
from .onnx_model import ONNXModel
|
||||
from .quant_utils import (
|
||||
TENSOR_NAME_QUANT_SUFFIX,
|
||||
QuantizationMode,
|
||||
QuantizedValue,
|
||||
QuantizedValueType,
|
||||
|
|
@ -580,7 +581,7 @@ class ONNXQuantizer:
|
|||
:return: List of newly created nodes in NodeProto format.
|
||||
"""
|
||||
input_name = node.input[input_index]
|
||||
output_name = input_name + "_quantized"
|
||||
output_name = input_name + TENSOR_NAME_QUANT_SUFFIX
|
||||
ql_node_name = input_name + "_QuantizeLinear"
|
||||
|
||||
if (given_scale_name is not None) and (given_zp_name is not None):
|
||||
|
|
@ -663,7 +664,7 @@ class ONNXQuantizer:
|
|||
# get bias
|
||||
bias_initializer = find_by_name(bias_name, self.model.initializer())
|
||||
bias_data = tensor_proto_to_array(bias_initializer)
|
||||
quantized_bias_name = bias_name + "_quantized"
|
||||
quantized_bias_name = bias_name + TENSOR_NAME_QUANT_SUFFIX
|
||||
|
||||
# get scale for input
|
||||
if input_name in self.quantized_value_map:
|
||||
|
|
@ -847,7 +848,7 @@ class ONNXQuantizer:
|
|||
quantized_value.scale_name,
|
||||
)
|
||||
|
||||
q_weight_name = weight.name + "_quantized"
|
||||
q_weight_name = weight.name + TENSOR_NAME_QUANT_SUFFIX
|
||||
zp_name = weight.name + "_zero_point"
|
||||
scale_name = weight.name + "_scale"
|
||||
|
||||
|
|
@ -933,7 +934,7 @@ class ONNXQuantizer:
|
|||
channel_weights = np.asarray(quantized_per_channel_data_list[i]).reshape(reshape_dims)
|
||||
quantized_weights = np.concatenate((quantized_weights, channel_weights), channel_axis)
|
||||
|
||||
q_weight_name = weight_name + "_quantized"
|
||||
q_weight_name = weight_name + TENSOR_NAME_QUANT_SUFFIX
|
||||
zp_name = weight_name + "_zero_point"
|
||||
scale_name = weight_name + "_scale"
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import onnx
|
||||
from onnx import onnx_pb as onnx_proto
|
||||
|
||||
from ..quant_utils import QuantizedValue, QuantizedValueType, attribute_to_kwarg, ms_domain
|
||||
from ..quant_utils import TENSOR_NAME_QUANT_SUFFIX, QuantizedValue, QuantizedValueType, attribute_to_kwarg, ms_domain
|
||||
from .base_operator import QuantOperatorBase
|
||||
from .qdq_base_operator import QDQOperatorBase
|
||||
|
||||
|
|
@ -57,7 +57,7 @@ class QLinearActivation(QuantOperatorBase):
|
|||
if not data_found or quantized_input_names is None:
|
||||
return super().quantize()
|
||||
|
||||
qlinear_activation_output = node.output[0] + "_quantized"
|
||||
qlinear_activation_output = node.output[0] + TENSOR_NAME_QUANT_SUFFIX
|
||||
qlinear_activation_name = ""
|
||||
if node.name != "":
|
||||
qlinear_activation_name = node.name + "_quant"
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import onnx
|
||||
from onnx import onnx_pb as onnx_proto
|
||||
|
||||
from ..quant_utils import QuantizedValue, QuantizedValueType, attribute_to_kwarg, ms_domain
|
||||
from ..quant_utils import TENSOR_NAME_QUANT_SUFFIX, QuantizedValue, QuantizedValueType, attribute_to_kwarg, ms_domain
|
||||
from .base_operator import QuantOperatorBase
|
||||
|
||||
|
||||
|
|
@ -28,7 +28,7 @@ class QLinearBinaryOp(QuantOperatorBase):
|
|||
if not data_found or quantized_input_names is None:
|
||||
return super().quantize()
|
||||
|
||||
qlinear_binary_math_output = node.output[0] + "_quantized"
|
||||
qlinear_binary_math_output = node.output[0] + TENSOR_NAME_QUANT_SUFFIX
|
||||
qlinear_binary_math_name = node.name + "_quant" if node.name != "" else ""
|
||||
|
||||
kwargs = {}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import onnx
|
||||
|
||||
from ..quant_utils import QuantizedValue, QuantizedValueType, attribute_to_kwarg, ms_domain
|
||||
from ..quant_utils import TENSOR_NAME_QUANT_SUFFIX, QuantizedValue, QuantizedValueType, attribute_to_kwarg, ms_domain
|
||||
from .base_operator import QuantOperatorBase
|
||||
from .qdq_base_operator import QDQOperatorBase
|
||||
|
||||
|
|
@ -32,7 +32,7 @@ class QLinearConcat(QuantOperatorBase):
|
|||
quantized_input_value = self.quantizer.quantized_value_map[node.input[0]]
|
||||
quantized_output_value = QuantizedValue(
|
||||
node.output[0],
|
||||
node.output[0] + "_quantized",
|
||||
node.output[0] + TENSOR_NAME_QUANT_SUFFIX,
|
||||
output_scale_name,
|
||||
output_zp_name,
|
||||
quantized_input_value.value_type,
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ import onnx
|
|||
from onnx import onnx_pb as onnx_proto
|
||||
|
||||
from ..quant_utils import (
|
||||
TENSOR_NAME_QUANT_SUFFIX,
|
||||
BiasToQuantize,
|
||||
QuantizedValue,
|
||||
QuantizedValueType,
|
||||
|
|
@ -168,7 +169,7 @@ class QLinearConv(QuantOperatorBase):
|
|||
quantized_bias_name = self.quantizer.quantize_bias_static(node.input[2], node.input[0], node.input[1])
|
||||
bias_present = True
|
||||
|
||||
qlinear_conv_output = node.output[0] + "_quantized"
|
||||
qlinear_conv_output = node.output[0] + TENSOR_NAME_QUANT_SUFFIX
|
||||
qlinear_conv_name = qlinear_conv_name = node.name + "_quant" if node.name != "" else ""
|
||||
|
||||
kwargs = {}
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from ..quant_utils import QuantizedValue, QuantizedValueType
|
||||
from ..quant_utils import TENSOR_NAME_QUANT_SUFFIX, QuantizedValue, QuantizedValueType
|
||||
from .base_operator import QuantOperatorBase
|
||||
from .qdq_base_operator import QDQOperatorBase
|
||||
|
||||
|
|
@ -22,7 +22,7 @@ class Direct8BitOp(QuantOperatorBase):
|
|||
|
||||
quantized_output_value = QuantizedValue(
|
||||
node.output[0],
|
||||
node.output[0] + "_quantized",
|
||||
node.output[0] + TENSOR_NAME_QUANT_SUFFIX,
|
||||
quantized_input_value.scale_name,
|
||||
quantized_input_value.zp_name,
|
||||
quantized_input_value.value_type,
|
||||
|
|
@ -51,7 +51,7 @@ class Direct8BitOp(QuantOperatorBase):
|
|||
# Create an entry for output quantized value
|
||||
quantized_output_value = QuantizedValue(
|
||||
node.output[0],
|
||||
node.output[0] + "_quantized",
|
||||
node.output[0] + TENSOR_NAME_QUANT_SUFFIX,
|
||||
scale_names[0],
|
||||
zero_point_names[0],
|
||||
QuantizedValueType.Input,
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from ..quant_utils import QuantizedValue, QuantizedValueType
|
||||
from ..quant_utils import TENSOR_NAME_QUANT_SUFFIX, QuantizedValue, QuantizedValueType
|
||||
from .base_operator import QuantOperatorBase
|
||||
from .qdq_base_operator import QDQOperatorBase
|
||||
|
||||
|
|
@ -30,7 +30,7 @@ class GatherQuant(QuantOperatorBase):
|
|||
if quantized_input_names is None:
|
||||
return super().quantize()
|
||||
|
||||
gather_new_output = node.output[0] + "_quantized"
|
||||
gather_new_output = node.output[0] + TENSOR_NAME_QUANT_SUFFIX
|
||||
|
||||
# Create an entry for this quantized value
|
||||
q_output = QuantizedValue(
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import onnx
|
||||
|
||||
from ..quant_utils import QuantizedValue, QuantizedValueType, attribute_to_kwarg, ms_domain
|
||||
from ..quant_utils import TENSOR_NAME_QUANT_SUFFIX, QuantizedValue, QuantizedValueType, attribute_to_kwarg, ms_domain
|
||||
from .base_operator import QuantOperatorBase
|
||||
|
||||
|
||||
|
|
@ -32,7 +32,7 @@ class QGlobalAveragePool(QuantOperatorBase):
|
|||
output_zp_name = output_zp_name_from_parameter if data_found else quantized_input_value.zp_name
|
||||
quantized_output_value = QuantizedValue(
|
||||
node.output[0],
|
||||
node.output[0] + "_quantized",
|
||||
node.output[0] + TENSOR_NAME_QUANT_SUFFIX,
|
||||
output_scale_name,
|
||||
output_zp_name,
|
||||
QuantizedValueType.Input,
|
||||
|
|
|
|||
|
|
@ -4,7 +4,15 @@ import numpy as np
|
|||
import onnx
|
||||
from onnx import onnx_pb as onnx_proto
|
||||
|
||||
from ..quant_utils import QuantizedValue, QuantizedValueType, attribute_to_kwarg, find_by_name, get_mul_node, ms_domain
|
||||
from ..quant_utils import (
|
||||
TENSOR_NAME_QUANT_SUFFIX,
|
||||
QuantizedValue,
|
||||
QuantizedValueType,
|
||||
attribute_to_kwarg,
|
||||
find_by_name,
|
||||
get_mul_node,
|
||||
ms_domain,
|
||||
)
|
||||
from .base_operator import QuantOperatorBase
|
||||
from .matmul import QOpMatMul
|
||||
from .qdq_base_operator import QDQOperatorBase
|
||||
|
|
@ -85,7 +93,7 @@ class QLinearGemm(QOpMatMul):
|
|||
node.input[2], node.input[0], node.input[1], get_beta(self.node)
|
||||
)
|
||||
|
||||
qgemm_output = node.output[0] + "_quantized"
|
||||
qgemm_output = node.output[0] + TENSOR_NAME_QUANT_SUFFIX
|
||||
qgemm_name = qgemm_name = node.name + "_quant" if node.name != "" else ""
|
||||
|
||||
kwargs = {}
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import itertools
|
|||
import onnx
|
||||
from onnx import onnx_pb as onnx_proto
|
||||
|
||||
from ..quant_utils import QuantizedValue, QuantizedValueType, find_by_name, get_mul_node
|
||||
from ..quant_utils import TENSOR_NAME_QUANT_SUFFIX, QuantizedValue, QuantizedValueType, find_by_name, get_mul_node
|
||||
from .base_operator import QuantOperatorBase
|
||||
from .qdq_base_operator import QDQOperatorBase
|
||||
|
||||
|
|
@ -129,7 +129,7 @@ class QLinearMatMul(QOpMatMul):
|
|||
if not data_found or quantized_input_names is None:
|
||||
return super().quantize()
|
||||
|
||||
qlinear_matmul_output = node.output[0] + "_quantized"
|
||||
qlinear_matmul_output = node.output[0] + TENSOR_NAME_QUANT_SUFFIX
|
||||
qlinear_matmul_name = node.name + "_quant" if node.name != "" else ""
|
||||
|
||||
qlinear_matmul_inputs = []
|
||||
|
|
|
|||
|
|
@ -1,7 +1,13 @@
|
|||
import numpy as np
|
||||
import onnx
|
||||
|
||||
from ..quant_utils import QuantizedValue, QuantizedValueType, attribute_to_kwarg, quantize_nparray
|
||||
from ..quant_utils import (
|
||||
TENSOR_NAME_QUANT_SUFFIX,
|
||||
QuantizedValue,
|
||||
QuantizedValueType,
|
||||
attribute_to_kwarg,
|
||||
quantize_nparray,
|
||||
)
|
||||
from .base_operator import QuantOperatorBase
|
||||
|
||||
|
||||
|
|
@ -46,7 +52,7 @@ class QPad(QuantOperatorBase):
|
|||
scale_value,
|
||||
zp_value,
|
||||
)
|
||||
quantized_padding_constant_name = node.input[2] + "_quantized"
|
||||
quantized_padding_constant_name = node.input[2] + TENSOR_NAME_QUANT_SUFFIX
|
||||
quantized_padding_constant_initializer = onnx.numpy_helper.from_array(
|
||||
quantized_padding_constant_array,
|
||||
quantized_padding_constant_name,
|
||||
|
|
@ -72,7 +78,7 @@ class QPad(QuantOperatorBase):
|
|||
# Create an entry for output quantized value
|
||||
quantized_output_value = QuantizedValue(
|
||||
node.output[0],
|
||||
node.output[0] + "_quantized",
|
||||
node.output[0] + TENSOR_NAME_QUANT_SUFFIX,
|
||||
quantized_input_value.scale_name,
|
||||
quantized_input_value.zp_name,
|
||||
QuantizedValueType.Input,
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import onnx
|
||||
|
||||
from ..quant_utils import QuantizedValue, QuantizedValueType, attribute_to_kwarg, ms_domain
|
||||
from ..quant_utils import TENSOR_NAME_QUANT_SUFFIX, QuantizedValue, QuantizedValueType, attribute_to_kwarg, ms_domain
|
||||
from .base_operator import QuantOperatorBase
|
||||
|
||||
|
||||
|
|
@ -32,7 +32,7 @@ class QLinearPool(QuantOperatorBase):
|
|||
return super().quantize()
|
||||
|
||||
# Create an entry for output quantized value.
|
||||
qlinear_output_name = node.output[0] + "_quantized"
|
||||
qlinear_output_name = node.output[0] + TENSOR_NAME_QUANT_SUFFIX
|
||||
quantized_output_value = QuantizedValue(
|
||||
node.output[0],
|
||||
qlinear_output_name,
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import onnx
|
||||
|
||||
from ..quant_utils import QuantizedValue, QuantizedValueType, attribute_to_kwarg, ms_domain
|
||||
from ..quant_utils import TENSOR_NAME_QUANT_SUFFIX, QuantizedValue, QuantizedValueType, attribute_to_kwarg, ms_domain
|
||||
from .base_operator import QuantOperatorBase
|
||||
from .qdq_base_operator import QDQOperatorBase
|
||||
|
||||
|
|
@ -36,7 +36,7 @@ class QLinearSoftmax(QuantOperatorBase):
|
|||
return super().quantize()
|
||||
|
||||
# Create an entry for output quantized value.
|
||||
qlinear_output_name = node.output[0] + "_quantized"
|
||||
qlinear_output_name = node.output[0] + TENSOR_NAME_QUANT_SUFFIX
|
||||
quantized_output_value = QuantizedValue(
|
||||
node.output[0],
|
||||
qlinear_output_name,
|
||||
|
|
|
|||
|
|
@ -39,6 +39,7 @@ is a list of tensors, one from each model run
|
|||
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Sequence, Union
|
||||
|
|
@ -50,7 +51,16 @@ from onnx import ModelProto, TensorProto, helper, numpy_helper
|
|||
import onnxruntime
|
||||
|
||||
from .calibrate import CalibraterBase, CalibrationDataReader
|
||||
from .quant_utils import DEQUANT_OUTPUT_SUFFIX, QUANT_INPUT_SUFFIX, clone_model_with_shape_infer
|
||||
from .onnx_model import ONNXModel
|
||||
from .quant_utils import (
|
||||
DEQUANT_OP_NAME,
|
||||
DEQUANT_OUTPUT_SUFFIX,
|
||||
QUANT_INPUT_SUFFIX,
|
||||
TENSOR_NAME_QUANT_SUFFIX,
|
||||
clone_model_with_shape_infer,
|
||||
find_by_name,
|
||||
load_model,
|
||||
)
|
||||
|
||||
_TENSOR_SAVE_POSTFIX = "_ReshapedSavedOutput"
|
||||
_TENSOR_SAVE_POSTFIX_LEN = len(_TENSOR_SAVE_POSTFIX)
|
||||
|
|
@ -222,3 +232,94 @@ def create_activation_matching(
|
|||
act_values["float"] = float_acts
|
||||
|
||||
return qdq_cmp
|
||||
|
||||
|
||||
def _run_dequantize_linear(
    weight_tensor: numpy.ndarray, weight_scale: numpy.ndarray, weight_zp: numpy.ndarray, channel_axis: int
) -> Optional[numpy.ndarray]:
    """Dequantize a tensor as ``(weight_tensor - weight_zp) * weight_scale``.

    Two layouts are supported:

    * per-tensor: ``weight_zp`` holds a single element, which is applied to
      every value of ``weight_tensor`` (``channel_axis`` is ignored);
    * per-channel: ``weight_zp`` / ``weight_scale`` are 1-D and entry ``i``
      is applied to slice ``i`` of ``weight_tensor`` along ``channel_axis``.

    Args:
        weight_tensor: Quantized values.
        weight_scale: Scale(s); must have the same shape as ``weight_zp``.
        weight_zp: Zero point(s).
        channel_axis: Axis of ``weight_tensor`` the per-channel parameters
            run along. Negative values index from the end.

    Returns:
        The dequantized array, with the shape of ``weight_tensor`` in the
        per-channel case, or ``None`` when per-channel quantization is
        requested on an axis of length 0.

    Raises:
        AssertionError: If scale and zero point shapes differ, or if
            per-channel parameters are not 1-D.
        IndexError: If ``channel_axis`` is out of range, or fewer per-channel
            parameters than channels are supplied.
    """
    assert weight_scale.shape == weight_zp.shape

    # Result dtype that NumPy promotion gives for these three operand dtypes,
    # so callers keep receiving e.g. float32 for int8 weights with float32 scale.
    out_dtype = numpy.result_type(weight_tensor.dtype, weight_zp.dtype, weight_scale.dtype)

    # Subtract in float64 rather than in the quantized dtype: uint8/int8
    # arithmetic silently wraps around (e.g. uint8 0 - 128 == 128), which
    # corrupts the dequantized values. float64 holds every integer of up to
    # 32 bits exactly, so no precision is lost before the final cast.
    weights = weight_tensor.astype(numpy.float64)
    zero_point = weight_zp.astype(numpy.float64)
    scale = weight_scale.astype(numpy.float64)

    if weight_zp.size == 1:
        return ((weights - zero_point) * scale).astype(out_dtype)

    assert weight_zp.ndim == 1
    channel_count = weight_tensor.shape[channel_axis]
    if channel_count == 0:
        return None
    if weight_zp.size < channel_count:
        raise IndexError(
            f"per-channel parameters have {weight_zp.size} entries but axis {channel_axis} has {channel_count} channels"
        )

    # Shape the 1-D parameters so they broadcast along channel_axis only.
    # Surplus trailing parameters are ignored, as they were never read before.
    broadcast_shape = [1] * weight_tensor.ndim
    broadcast_shape[channel_axis] = channel_count
    zero_point = zero_point[:channel_count].reshape(broadcast_shape)
    scale = scale[:channel_count].reshape(broadcast_shape)

    return ((weights - zero_point) * scale).astype(out_dtype)
|
||||
|
||||
|
||||
def create_weight_matching(float_model_path: str, qdq_model_path: str) -> Dict[str, Dict[str, numpy.ndarray]]:
    """Compare weight values to help debug accuracy loss due to quantization.

    This function takes the float model and the qdq model, and provides a data
    structure for comparing their corresponding weights to locate quantization
    errors.

    Every DequantizeLinear node of the qdq model whose first input is an
    initializer named ``<weight><TENSOR_NAME_QUANT_SUFFIX>`` is dequantized and
    paired with the initializer ``<weight>`` of the float model. Nodes that
    cannot be matched are reported through ``logging.error`` and skipped.

    Args:
        float_model_path: Path to the floating point model.
        qdq_model_path: Path to the qdq model.

    Returns:
        Dict for comparing weight tensors, keyed by the float weight name. E.g.
        ```
        qdq_weight_cmp = create_weight_matching(float_model, qdq_model)
        print(qdq_weight_cmp['weight1']['float'])
        print(qdq_weight_cmp['weight1']['dequantized'])
        ```
    """
    float_onnx_model = ONNXModel(load_model(Path(float_model_path), need_optimize=False))
    qdq_onnx_model = ONNXModel(load_model(Path(qdq_model_path), need_optimize=False))

    matched_weights: Dict[str, Dict[str, numpy.ndarray]] = {}
    initializers = qdq_onnx_model.initializer()
    float_initializers = float_onnx_model.initializer()
    for node in qdq_onnx_model.nodes():
        if node.op_type != DEQUANT_OP_NAME:
            continue  # Only care about DQ node
        weight_name: str = node.input[0]
        weight_values = find_by_name(weight_name, initializers)
        if weight_values is None:
            continue  # Only care about DQ node with const inputs
        if not weight_name.endswith(TENSOR_NAME_QUANT_SUFFIX):
            logging.error(f"Model Error in '{qdq_model_path}': Dequantized tensor name '{weight_name}' not recognized!")
            continue

        # The ONNX DequantizeLinear operator defaults `axis` to 1 when the
        # attribute is absent; it only matters for per-channel parameters.
        axis = 1
        for attr in node.attribute:
            if attr.name == "axis":
                axis = attr.i

        scale_values = find_by_name(node.input[1], initializers)
        if scale_values is None:
            # Without a constant scale the weight cannot be dequantized offline.
            logging.error(f"Model Error in '{qdq_model_path}': scale tensor '{node.input[1]}' not found!")
            continue

        weight_tensor = numpy_helper.to_array(weight_values)
        weight_scale = numpy_helper.to_array(scale_values)
        # The zero point is an optional input: it may be omitted entirely or
        # passed as an empty name, in both cases it defaults to zero.
        if len(node.input) > 2 and node.input[2]:
            zp_values = find_by_name(node.input[2], initializers)
            if zp_values is None:
                logging.error(f"Model Error in '{qdq_model_path}': zero point tensor '{node.input[2]}' not found!")
                continue
            weight_zp = numpy_helper.to_array(zp_values)
        else:
            weight_zp = numpy.zeros(weight_scale.shape, dtype=numpy.int32)

        # Perform dequantization:
        weight_quant = _run_dequantize_linear(weight_tensor, weight_scale, weight_zp, channel_axis=axis)
        # Strip the suffix to recover the name used by the float model.
        weight_name = weight_name[: -len(TENSOR_NAME_QUANT_SUFFIX)]
        if weight_quant is None:
            logging.error(f"Model Error in '{qdq_model_path}': '{weight_name}' per-channel quantization on 0 channel")
            continue

        float_values = find_by_name(weight_name, float_initializers)
        if float_values is None:
            logging.error(f"Model Error in '{float_model_path}': weight tensor '{weight_name}' not found!")
            continue
        weight_float = numpy_helper.to_array(float_values)
        matched_weights[weight_name] = {"float": weight_float, "dequantized": weight_quant}

    return matched_weights
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ QUANT_OP_NAME = "QuantizeLinear"
|
|||
QUANT_INPUT_SUFFIX = "_QuantizeLinear_Input"
|
||||
DEQUANT_OP_NAME = "DequantizeLinear"
|
||||
DEQUANT_OUTPUT_SUFFIX = "_DequantizeLinear_Output"
|
||||
TENSOR_NAME_QUANT_SUFFIX = "_quantized"
|
||||
|
||||
|
||||
type_to_name = {
|
||||
|
|
|
|||
|
|
@ -13,7 +13,8 @@ from typing import Dict, List
|
|||
|
||||
import numpy as np
|
||||
import onnx
|
||||
from onnx import TensorProto, helper, numpy_helper
|
||||
from onnx import TensorProto, helper
|
||||
from op_test_utils import generate_random_initializer
|
||||
|
||||
import onnxruntime
|
||||
from onnxruntime.quantization import QuantFormat, QuantType, quantize_static
|
||||
|
|
@ -21,19 +22,11 @@ from onnxruntime.quantization.calibrate import CalibrationDataReader
|
|||
from onnxruntime.quantization.qdq_loss_debug import (
|
||||
collect_activations,
|
||||
create_activation_matching,
|
||||
create_weight_matching,
|
||||
modify_model_output_intermediate_tensors,
|
||||
)
|
||||
|
||||
|
||||
def generate_input_initializer(tensor_shape, tensor_dtype, input_name):
|
||||
"""
|
||||
Helper function to generate initializers for test inputs
|
||||
"""
|
||||
tensor = np.random.normal(0, 0.3, tensor_shape).astype(tensor_dtype)
|
||||
init = numpy_helper.from_array(tensor, input_name)
|
||||
return init
|
||||
|
||||
|
||||
def construct_test_model1(test_model_path: str, activations_as_outputs=False):
|
||||
""" Create an ONNX model shaped as:
|
||||
```
|
||||
|
|
@ -62,32 +55,28 @@ def construct_test_model1(test_model_path: str, activations_as_outputs=False):
|
|||
x4_output = helper.make_tensor_value_info("Conv2Out", TensorProto.FLOAT, [1, 3, 1, 3])
|
||||
x5_output = helper.make_tensor_value_info("Conv3Out", TensorProto.FLOAT, [1, 3, 1, 3])
|
||||
x6_output = helper.make_tensor_value_info("AddOut", TensorProto.FLOAT, [1, 3, 1, 3])
|
||||
w1 = generate_input_initializer([3, 3, 1, 1], np.float32, "W1")
|
||||
b1 = generate_input_initializer([3], np.float32, "B1")
|
||||
w3 = generate_input_initializer([3, 3, 1, 1], np.float32, "W3")
|
||||
b3 = generate_input_initializer([3], np.float32, "B3")
|
||||
w5 = generate_input_initializer([3, 3, 1, 1], np.float32, "W5")
|
||||
b5 = generate_input_initializer([3], np.float32, "B5")
|
||||
relu_node_1 = helper.make_node("Relu", ["input"], ["Relu1Out"], name="Relu1")
|
||||
conv_node_1 = helper.make_node("Conv", ["Relu1Out", "W1", "B1"], ["Conv1Out"], name="Conv1")
|
||||
relu_node_2 = helper.make_node("Relu", ["Conv1Out"], ["Relu2Out"], name="Relu2")
|
||||
conv_node_2 = helper.make_node("Conv", ["Relu2Out", "W3", "B3"], ["Conv2Out"], name="Conv2")
|
||||
conv_node_3 = helper.make_node("Conv", ["Relu1Out", "W5", "B5"], ["Conv3Out"], name="Conv3")
|
||||
add_node = helper.make_node("Add", ["Conv2Out", "Conv3Out"], ["AddOut"], name="Add")
|
||||
|
||||
initializer = []
|
||||
initializer.append(generate_random_initializer("W1", [3, 3, 1, 1], np.float32))
|
||||
initializer.append(generate_random_initializer("B1", [3], np.float32))
|
||||
initializer.append(generate_random_initializer("W3", [3, 3, 1, 1], np.float32))
|
||||
initializer.append(generate_random_initializer("B3", [3], np.float32))
|
||||
initializer.append(generate_random_initializer("W5", [3, 3, 1, 1], np.float32))
|
||||
initializer.append(generate_random_initializer("B5", [3], np.float32))
|
||||
|
||||
nodes = []
|
||||
nodes.append(helper.make_node("Relu", ["input"], ["Relu1Out"], name="Relu1"))
|
||||
nodes.append(helper.make_node("Conv", ["Relu1Out", "W1", "B1"], ["Conv1Out"], name="Conv1"))
|
||||
nodes.append(helper.make_node("Relu", ["Conv1Out"], ["Relu2Out"], name="Relu2"))
|
||||
nodes.append(helper.make_node("Conv", ["Relu2Out", "W3", "B3"], ["Conv2Out"], name="Conv2"))
|
||||
nodes.append(helper.make_node("Conv", ["Relu1Out", "W5", "B5"], ["Conv3Out"], name="Conv3"))
|
||||
nodes.append(helper.make_node("Add", ["Conv2Out", "Conv3Out"], ["AddOut"], name="Add"))
|
||||
|
||||
# we are keeping all tensors in the output anyway for verification purpose
|
||||
outputs = [x6_output]
|
||||
if activations_as_outputs:
|
||||
outputs.extend([x1_output, x2_output, x3_output, x4_output, x5_output])
|
||||
graph = helper.make_graph(
|
||||
[relu_node_1, conv_node_1, relu_node_2, conv_node_2, conv_node_3, add_node], "test_graph_4", [input_vi], outputs
|
||||
)
|
||||
graph.initializer.add().CopyFrom(w1)
|
||||
graph.initializer.add().CopyFrom(b1)
|
||||
graph.initializer.add().CopyFrom(w3)
|
||||
graph.initializer.add().CopyFrom(b3)
|
||||
graph.initializer.add().CopyFrom(w5)
|
||||
graph.initializer.add().CopyFrom(b5)
|
||||
graph = helper.make_graph(nodes, "test_graph_relu_conv", [input_vi], outputs, initializer=initializer)
|
||||
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||
onnx.save(model, test_model_path)
|
||||
|
||||
|
|
@ -95,13 +84,13 @@ def construct_test_model1(test_model_path: str, activations_as_outputs=False):
|
|||
class TestDataReader(CalibrationDataReader):
|
||||
"""Random Data Input Generator"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, input_shape=[1, 3, 1, 3]):
|
||||
self.preprocess_flag = True
|
||||
self.enum_data_dicts = []
|
||||
self.count = 2
|
||||
self.input_data_list = []
|
||||
for _ in range(self.count):
|
||||
self.input_data_list.append(np.random.normal(0, 0.33, [1, 3, 1, 3]).astype(np.float32))
|
||||
self.input_data_list.append(np.random.normal(0, 0.33, input_shape).astype(np.float32))
|
||||
|
||||
def get_next(self):
|
||||
if self.preprocess_flag:
|
||||
|
|
@ -215,6 +204,94 @@ class TestSaveActivations(unittest.TestCase):
|
|||
|
||||
self.assertFalse(compare_dict.get("Conv1Out"))
|
||||
|
||||
def test_create_weight_matching(self):
    """Per-tensor QDQ model: every weight and bias is matched with equal shapes."""
    model_dir = Path(self._tmp_model_dir.name)

    # Float reference model.
    float_model_path = str(model_dir / "float_model3.onnx")
    construct_test_model1(float_model_path, activations_as_outputs=False)

    # Statically quantized QDQ counterpart of the float model.
    calibration_reader = TestDataReader()
    qdq_model_path = str(model_dir / "qdq_model3.onnx")
    quantize_options = {
        "quant_format": QuantFormat.QDQ,
        "per_channel": False,
        "reduce_range": False,
        "activation_type": QuantType.QInt8,
        "weight_type": QuantType.QInt8,
    }
    quantize_static(float_model_path, qdq_model_path, calibration_reader, **quantize_options)

    # Function under test: each initializer of the float model must be paired
    # with a dequantized tensor of the same shape.
    matched = create_weight_matching(float_model_path, qdq_model_path)
    for name in ("W1", "W3", "W5", "B1", "B3", "B5"):
        float_array = matched[name]["float"]
        dq_array = matched[name]["dequantized"]
        self.assertEqual(float_array.shape, dq_array.shape)
|
||||
|
||||
def test_create_weight_matching_per_channel(self):
    """Per-channel QDQ model: every weight is matched with equal shapes."""
    # float model
    #         (input)
    #            |
    #           Add
    #       /    |    \
    #  MatMul MatMul MatMul
    #     |      |      |
    # (output)(output)(output)
    model_dir = Path(self._tmp_model_dir.name)
    float_model_path = str(model_dir / "float_model4.onnx")

    graph_input = helper.make_tensor_value_info("input", TensorProto.FLOAT, [5, 5])
    graph_outputs = [helper.make_tensor_value_info(name, TensorProto.FLOAT, [5, 5]) for name in ("M", "N", "O")]

    # One random 5x5 weight for the Add and for each of the first two MatMuls;
    # the third MatMul weight "S" reuses the data of "R".
    weight_data = {}
    for name in ("P", "Q", "R"):
        weight_data[name] = np.random.normal(0, 0.1, [5, 5]).astype(np.float32)
    weight_data["S"] = weight_data["R"]
    initializers = [onnx.numpy_helper.from_array(weight_data[name], name=name) for name in ("P", "Q", "R", "S")]

    nodes = [helper.make_node("Add", ["input", "P"], ["T"], name="Add")]
    for node_name, weight, output in (("MatMul1", "Q", "M"), ("MatMul2", "R", "N"), ("MatMul3", "S", "O")):
        nodes.append(helper.make_node("MatMul", ["T", weight], [output], name=node_name))

    graph = helper.make_graph(nodes, "QDQ_Test", [graph_input], graph_outputs, initializer=initializers)
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
    onnx.save(model, float_model_path)

    # Statically quantized, per-channel QDQ counterpart of the float model.
    qdq_model_path = str(model_dir / "qdq_model4.onnx")
    quantize_options = {
        "quant_format": QuantFormat.QDQ,
        "per_channel": True,
        "reduce_range": False,
        "activation_type": QuantType.QInt8,
        "weight_type": QuantType.QInt8,
    }
    quantize_static(float_model_path, qdq_model_path, TestDataReader([5, 5]), **quantize_options)

    # Function under test: each initializer of the float model must be paired
    # with a dequantized tensor of the same shape.
    matched = create_weight_matching(float_model_path, qdq_model_path)
    for name in ("P", "Q", "R", "S"):
        float_array = matched[name]["float"]
        dq_array = matched[name]["dequantized"]
        self.assertEqual(float_array.shape, dq_array.shape)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
|||
Loading…
Reference in a new issue