mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-05 04:17:53 +00:00
Some initializers are added without raw=True flag. That causes those tensors cannot be saved to external data. If those tensors exceed 2GB in total, optimized model cannot be saved due to protobuf limit. This change will save attention weights and bias in raw data. Note: it is optional to use raw data for shape tensor since they are tiny. ### Motivation and Context https://github.com/microsoft/onnxruntime/issues/17212 https://github.com/microsoft/onnxruntime/issues/15349
166 lines
6.6 KiB
Python
166 lines
6.6 KiB
Python
# -------------------------------------------------------------------------
|
|
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
# Licensed under the MIT License.
|
|
# --------------------------------------------------------------------------
|
|
|
|
from logging import getLogger
|
|
from typing import Dict, List
|
|
|
|
from fusion_base import Fusion
|
|
from fusion_utils import FusionUtils
|
|
from onnx import NodeProto, TensorProto, helper
|
|
from onnx_model import OnnxModel
|
|
|
|
logger = getLogger(__name__)
|
|
|
|
|
|
class FusionTranspose(Fusion):
|
|
def __init__(self, model: OnnxModel):
|
|
super().__init__(model, "Transpose", "Transpose")
|
|
|
|
def fuse(
|
|
self,
|
|
transpose_node: NodeProto,
|
|
input_name_to_nodes: Dict[str, List[NodeProto]],
|
|
output_name_to_node: Dict[str, NodeProto],
|
|
):
|
|
"""
|
|
Note that onnxruntime will do comprehensive transpose optimization after loading model.
|
|
The purpose of this fusion is to make graph clean before running onnxruntime.
|
|
|
|
Case 1:
|
|
(input)-->Transpose(perm=a)-->Transpose(perm=b)-->
|
|
After:
|
|
(input)-->Transpose(perm=a)--> (this path can be removed if the output is not used anymore)
|
|
|
|
|
+----->Transpose(perm=a*b)-->
|
|
|
|
Case 2 (Cast has only one child):
|
|
(input)-->Transpose(perm=a)--> Cast -->Transpose(perm=b)-->
|
|
After:
|
|
(input)-->Transpose(perm=a)--> (this path can be removed if the output is not used anymore)
|
|
|
|
|
+----->Cast --> Transpose(perm=a*b)-->
|
|
"""
|
|
transpose_b = transpose_node
|
|
if transpose_b.input[0] not in output_name_to_node:
|
|
return
|
|
|
|
transpose_a = output_name_to_node[transpose_b.input[0]]
|
|
if transpose_a.op_type != "Cast":
|
|
cast_node = None
|
|
else:
|
|
cast_node = transpose_a
|
|
|
|
cast_children = self.model.get_children(cast_node, input_name_to_nodes)
|
|
if cast_children and len(cast_children) > 1:
|
|
return
|
|
|
|
if cast_node.input[0] not in output_name_to_node:
|
|
return
|
|
|
|
transpose_a = output_name_to_node[cast_node.input[0]]
|
|
|
|
if transpose_a.op_type != "Transpose":
|
|
return
|
|
|
|
permutation = OnnxModel.get_node_attribute(transpose_b, "perm")
|
|
assert isinstance(permutation, list)
|
|
|
|
parent_permutation = OnnxModel.get_node_attribute(transpose_a, "perm")
|
|
assert isinstance(parent_permutation, list)
|
|
|
|
assert len(parent_permutation) == len(permutation)
|
|
|
|
output_permutation = []
|
|
for _j, index in enumerate(permutation):
|
|
output_permutation.append(parent_permutation[index])
|
|
|
|
if cast_node is None:
|
|
if FusionUtils.skip_parent(self.model, transpose_b, transpose_a, input_name_to_nodes):
|
|
self.nodes_to_remove.append(transpose_a)
|
|
else:
|
|
if FusionUtils.skip_parent(self.model, cast_node, transpose_a, input_name_to_nodes):
|
|
self.nodes_to_remove.append(transpose_a)
|
|
transpose_b.ClearField("attribute")
|
|
transpose_b.attribute.extend([helper.make_attribute("perm", output_permutation)])
|
|
|
|
|
|
class FusionInsertTranspose(Fusion):
|
|
def __init__(self, model: OnnxModel):
|
|
super().__init__(model, "", "GroupNorm")
|
|
|
|
def create_transpose_node(self, input_name: str, perm: List[int], output_name=None):
|
|
"""Append a Transpose node after an input"""
|
|
node_name = self.model.create_node_name("Transpose")
|
|
if output_name is None:
|
|
output_name = node_name + "_out" + "-" + input_name
|
|
transpose_node = helper.make_node("Transpose", inputs=[input_name], outputs=[output_name], name=node_name)
|
|
transpose_node.attribute.extend([helper.make_attribute("perm", perm)])
|
|
return transpose_node
|
|
|
|
def fuse(
|
|
self,
|
|
group_norm_node: NodeProto,
|
|
input_name_to_nodes: Dict[str, List[NodeProto]],
|
|
output_name_to_node: Dict[str, NodeProto],
|
|
):
|
|
"""
|
|
This optimization will insert an Transpose, and onnxruntime transpose optimizer will remove it together with
|
|
another Transpose so that we can get effect of reducing one Transpose after onnxruntime optimization.
|
|
Before:
|
|
--> Gemm --> Unsqueeze(axes=[2]) --> Unsqueeze(axes=[3]) --> Add --> Transpose([0,2,3,1]) --> GroupNorm
|
|
After:
|
|
--> Gemm --> Unsqueeze(axes=[1]) --> Unsqueeze(axes=[2]) -->Transpose([0,3,1,2]) --> Add --> Transpose([0,2,3,1]) --> GroupNorm
|
|
"""
|
|
gemm_path = self.model.match_parent_path(
|
|
group_norm_node, ["Transpose", "Add", "Unsqueeze", "Unsqueeze", "Gemm"], [0, 0, None, 0, 0]
|
|
)
|
|
if gemm_path is None:
|
|
return
|
|
transpose, add, unsqueeze_3, unsqueeze_2, gemm = gemm_path
|
|
if self.model.find_graph_output(unsqueeze_3.output[0]):
|
|
return
|
|
|
|
permutation = OnnxModel.get_node_attribute(transpose, "perm")
|
|
assert isinstance(permutation, list)
|
|
if permutation != [0, 2, 3, 1]:
|
|
return
|
|
|
|
if not (
|
|
self.model.get_constant_value(unsqueeze_3.input[1]) == 3
|
|
and self.model.get_constant_value(unsqueeze_2.input[1]) == 2
|
|
and len(self.model.get_children(gemm, input_name_to_nodes)) == 1
|
|
and len(self.model.get_children(unsqueeze_3, input_name_to_nodes)) == 1
|
|
and len(self.model.get_children(unsqueeze_2, input_name_to_nodes)) == 1
|
|
):
|
|
return
|
|
|
|
# Here we use hard-coded name so that it could be shared for the whole model.
|
|
axes_1 = "ort_const_unsqueeze_axes_1"
|
|
if self.model.get_initializer(axes_1) is None:
|
|
self.add_initializer(
|
|
name=axes_1,
|
|
data_type=TensorProto.INT64,
|
|
dims=[1],
|
|
vals=[1],
|
|
raw=False,
|
|
)
|
|
|
|
axes_2 = "ort_const_unsqueeze_axes_2"
|
|
if self.model.get_initializer(axes_2) is None:
|
|
self.add_initializer(
|
|
name=axes_2,
|
|
data_type=TensorProto.INT64,
|
|
dims=[1],
|
|
vals=[2],
|
|
raw=False,
|
|
)
|
|
|
|
unsqueeze_3.input[1] = "ort_const_unsqueeze_axes_2"
|
|
unsqueeze_2.input[1] = "ort_const_unsqueeze_axes_1"
|
|
transpose_output_name = self.model.create_node_name("Transpose") + "_NCHW"
|
|
self.model.replace_input_of_all_nodes(unsqueeze_3.output[0], transpose_output_name)
|
|
new_transpose = self.create_transpose_node(unsqueeze_3.output[0], [0, 3, 1, 2], transpose_output_name)
|
|
self.model.add_node(new_transpose, self.this_graph_name)
|
|
self.increase_counter("Insert Transpose")
|