mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
Add quantization tool to python package (#4458)
* Add quantization tool to python package
This commit is contained in:
parent
0ca4f7eb30
commit
5dc7339be6
4 changed files with 51 additions and 0 deletions
|
|
@ -186,6 +186,11 @@ file(GLOB onnxruntime_python_tools_srcs CONFIGURE_DEPENDS
|
|||
file(GLOB onnxruntime_python_tools_featurizers_src CONFIGURE_DEPENDS
|
||||
"${ONNXRUNTIME_ROOT}/python/tools/featurizer_ops/*.py"
|
||||
)
|
||||
file(GLOB onnxruntime_python_quantization_src CONFIGURE_DEPENDS
|
||||
"${ONNXRUNTIME_ROOT}/python/tools/quantization/*.py"
|
||||
)
|
||||
list(REMOVE_ITEM onnxruntime_python_quantization_src
|
||||
"${ONNXRUNTIME_ROOT}/python/tools/quantization/test_calibrate.py")
|
||||
file(GLOB onnxruntime_python_datasets_srcs CONFIGURE_DEPENDS
|
||||
"${ONNXRUNTIME_ROOT}/python/datasets/*.py"
|
||||
)
|
||||
|
|
@ -204,6 +209,7 @@ add_custom_command(
|
|||
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${test_data_target}>/onnxruntime/datasets
|
||||
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${test_data_target}>/onnxruntime/tools
|
||||
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${test_data_target}>/onnxruntime/tools/featurizer_ops
|
||||
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${test_data_target}>/onnxruntime/quantization
|
||||
COMMAND ${CMAKE_COMMAND} -E copy
|
||||
${ONNXRUNTIME_ROOT}/__init__.py
|
||||
$<TARGET_FILE_DIR:${test_data_target}>/onnxruntime/
|
||||
|
|
@ -243,6 +249,9 @@ add_custom_command(
|
|||
COMMAND ${CMAKE_COMMAND} -E copy
|
||||
${onnxruntime_python_tools_featurizers_src}
|
||||
$<TARGET_FILE_DIR:${test_data_target}>/onnxruntime/tools/featurizer_ops/
|
||||
COMMAND ${CMAKE_COMMAND} -E copy
|
||||
${onnxruntime_python_quantization_src}
|
||||
$<TARGET_FILE_DIR:${test_data_target}>/onnxruntime/quantization/
|
||||
COMMAND ${CMAKE_COMMAND} -E copy
|
||||
${REPO_ROOT}/VERSION_NUMBER
|
||||
$<TARGET_FILE_DIR:${test_data_target}>
|
||||
|
|
|
|||
2
onnxruntime/python/tools/quantization/__init__.py
Normal file
2
onnxruntime/python/tools/quantization/__init__.py
Normal file
|
|
@ -0,0 +1,2 @@
|
|||
from onnxruntime.quantization.quantize import quantize
|
||||
from onnxruntime.quantization.quantize import QuantizationMode
|
||||
|
|
@ -295,7 +295,46 @@ class ONNXQuantizer:
|
|||
# Map of all original value names to quantized value names
|
||||
self.quantized_value_map = {}
|
||||
|
||||
def replace_gemm_with_matmul(self):
|
||||
nodes_to_remove = []
|
||||
nodes_to_add = []
|
||||
for node in self.model.graph.node:
|
||||
if node.op_type == 'Gemm':
|
||||
alpha = 1.0
|
||||
beta = 1.0
|
||||
transA = 0
|
||||
transB = 0
|
||||
for attr in node.attribute:
|
||||
if attr.name == 'alpha':
|
||||
alpha = onnx.helper.get_attribute_value(attr)
|
||||
elif attr.name == 'beta':
|
||||
beta = onnx.helper.get_attribute_value(attr)
|
||||
elif attr.name == 'transA':
|
||||
transA = onnx.helper.get_attribute_value(attr)
|
||||
elif attr.name == 'transB':
|
||||
transB = onnx.helper.get_attribute_value(attr)
|
||||
if alpha == 1.0 and beta == 1.0 and transA == 0 and transB == 0:
|
||||
matmul_node = onnx.helper.make_node(
|
||||
'MatMul',
|
||||
[node.input[0], node.input[1]],
|
||||
[node.output[0]+'_MatMul'],
|
||||
name=node.output[0]+'_MatMul')
|
||||
|
||||
add_node = onnx.helper.make_node(
|
||||
'Add',
|
||||
inputs=[node.output[0]+'_MatMul', node.input[2]],
|
||||
outputs=node.output,
|
||||
name=node.output[0]+'_Add')
|
||||
|
||||
nodes_to_remove.extend([node])
|
||||
nodes_to_add.extend([matmul_node, add_node])
|
||||
|
||||
self.model.graph.node.extend(nodes_to_add)
|
||||
for node in nodes_to_remove:
|
||||
self.model.graph.node.remove(node)
|
||||
|
||||
def quantize_model(self):
|
||||
self.replace_gemm_with_matmul()
|
||||
# Create a new topologically sorted list for quantizing a model
|
||||
new_list = []
|
||||
for node in self.model.graph.node:
|
||||
|
|
|
|||
1
setup.py
1
setup.py
|
|
@ -222,6 +222,7 @@ packages = [
|
|||
'onnxruntime.capi.training',
|
||||
'onnxruntime.datasets',
|
||||
'onnxruntime.tools',
|
||||
'onnxruntime.quantization',
|
||||
]
|
||||
|
||||
package_data = {}
|
||||
|
|
|
|||
Loading…
Reference in a new issue