mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-04 04:07:22 +00:00
Attention fusion for stable diffusion clip model (#17445)
Add attention fusion for stable diffusion clip model to improve performance of SD or SDXL
This commit is contained in:
parent
4d753b74a5
commit
29a818caa0
7 changed files with 403 additions and 15 deletions
|
|
@ -33,18 +33,20 @@ def run_model(model_path, all_inputs, use_gpu, disable_optimization):
|
|||
return results, latency_list, output_names
|
||||
|
||||
|
||||
def compare(baseline_results, treatment_results, verbose, rtol=1e-3, atol=1e-4):
|
||||
def compare(baseline_results, treatment_results, verbose, rtol=1e-1, atol=1e-3):
|
||||
# Validate the output of baseline and treatment, to make sure the results are similar.
|
||||
diff_count = 0
|
||||
max_rel_diff = 0
|
||||
max_abs_diff = 0
|
||||
for test_case_id, results in enumerate(baseline_results):
|
||||
case_passed = True
|
||||
for i in range(len(results)):
|
||||
treatment_output = treatment_results[test_case_id][i]
|
||||
rel_diff = np.amax(np.abs((treatment_output - results[i]) / results[i]))
|
||||
abs_diff = np.amax(np.abs(treatment_output - results[i]))
|
||||
max_rel_diff = max(max_rel_diff, rel_diff)
|
||||
if verbose and abs_diff > atol:
|
||||
print("abs_diff", abs_diff)
|
||||
print("treatment", treatment_output)
|
||||
print("baseline", results[i])
|
||||
|
||||
max_abs_diff = max(max_abs_diff, abs_diff)
|
||||
if not np.allclose(results[i].tolist(), treatment_output.tolist(), rtol=rtol, atol=atol):
|
||||
if case_passed:
|
||||
|
|
@ -54,7 +56,7 @@ def compare(baseline_results, treatment_results, verbose, rtol=1e-3, atol=1e-4):
|
|||
if verbose:
|
||||
print(f"case {test_case_id} output {i}")
|
||||
print(f"baseline={results[i].tolist()}\ntreatment={treatment_output}")
|
||||
print(f"rel_diff={rel_diff} abs_diff={abs_diff}")
|
||||
print(f"abs_diff={abs_diff}")
|
||||
|
||||
if diff_count == 0:
|
||||
print(
|
||||
|
|
@ -70,8 +72,7 @@ def compare(baseline_results, treatment_results, verbose, rtol=1e-3, atol=1e-4):
|
|||
)
|
||||
|
||||
print(f"maximum absolute difference={max_abs_diff}")
|
||||
|
||||
print(f"maximum relative difference={max_rel_diff}")
|
||||
return max_abs_diff, case_passed
|
||||
|
||||
|
||||
def run_test(
|
||||
|
|
@ -133,7 +134,7 @@ def run_test(
|
|||
print(f"treatment average latency: {statistics.mean(treatment_latency) * 1000} ms")
|
||||
|
||||
# Validate the output of baseline and treatment, to make sure the results are similar.
|
||||
compare(baseline_results, treatment_results, verbose, rtol, atol)
|
||||
return compare(baseline_results, treatment_results, verbose, rtol, atol)
|
||||
|
||||
|
||||
def parse_arguments():
|
||||
|
|
|
|||
|
|
@ -196,7 +196,7 @@ class FusionAttention(Fusion):
|
|||
def get_add_qk_str(self, add_qk: NodeProto):
|
||||
shape_infer = self.model.infer_runtime_shape(update=True)
|
||||
if shape_infer is None:
|
||||
return
|
||||
return None
|
||||
|
||||
input_0_shape = shape_infer.get_edge_shape(add_qk.input[0])
|
||||
input_1_shape = shape_infer.get_edge_shape(add_qk.input[1])
|
||||
|
|
@ -697,6 +697,7 @@ class FusionAttention(Fusion):
|
|||
present_k: str = "",
|
||||
present_v: str = "",
|
||||
scale: Optional[float] = None,
|
||||
causal: bool = False,
|
||||
) -> Union[NodeProto, None]:
|
||||
"""Create an Attention node.
|
||||
|
||||
|
|
@ -717,6 +718,8 @@ class FusionAttention(Fusion):
|
|||
past_v (str): name of input for past V value
|
||||
present_k (str): name of output to store present K value
|
||||
present_v (str): name of output to store present V value
|
||||
scale: scale before softmax
|
||||
causal: whether it is uni-directional mask.
|
||||
|
||||
Returns:
|
||||
Union[NodeProto, None]: the node created or None if failed.
|
||||
|
|
@ -828,7 +831,7 @@ class FusionAttention(Fusion):
|
|||
|
||||
# For MultiHeadAttention operator, use separated inputs for query, key and value, and no weights.
|
||||
if self.use_multi_head_attention:
|
||||
if add_qk_str is not None:
|
||||
if add_qk_str:
|
||||
logger.debug("MultiHeadAttention does not support relative_position_bias: cannot fuse the attention.")
|
||||
return None
|
||||
|
||||
|
|
@ -864,7 +867,7 @@ class FusionAttention(Fusion):
|
|||
past_kv = self.concat_kv(past_k, past_v)
|
||||
attention_inputs.append(past_kv)
|
||||
|
||||
if add_qk_str is not None:
|
||||
if add_qk_str:
|
||||
# Convert 4d mask from (B,1,M,M) to (B,N,M,M)
|
||||
# B = batch size, M = max sequence length, N = num heads
|
||||
concat_node_name = self.model.create_node_name("Concat")
|
||||
|
|
@ -901,6 +904,9 @@ class FusionAttention(Fusion):
|
|||
attention_node.domain = "com.microsoft"
|
||||
attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)])
|
||||
|
||||
if causal:
|
||||
attention_node.attribute.extend([helper.make_attribute("unidirectional", 1)])
|
||||
|
||||
if scale is not None:
|
||||
attention_node.attribute.extend([helper.make_attribute("scale", scale)])
|
||||
|
||||
|
|
|
|||
218
onnxruntime/python/tools/transformers/fusion_attention_clip.py
Normal file
218
onnxruntime/python/tools/transformers/fusion_attention_clip.py
Normal file
|
|
@ -0,0 +1,218 @@
|
|||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
from logging import getLogger
|
||||
from typing import Tuple
|
||||
|
||||
from fusion_attention import AttentionMask, FusionAttention
|
||||
from fusion_options import AttentionMaskFormat
|
||||
from onnx import NodeProto
|
||||
from onnx_model import OnnxModel
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
class FusionAttentionClip(FusionAttention):
|
||||
"""
|
||||
Fuse Attention subgraph of Clip into one Attention node.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: OnnxModel,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
):
|
||||
attention_mask = AttentionMask(model)
|
||||
attention_mask.mask_format = AttentionMaskFormat.NoMask
|
||||
|
||||
super().__init__(
|
||||
model,
|
||||
hidden_size,
|
||||
num_heads,
|
||||
attention_mask,
|
||||
use_multi_head_attention=False,
|
||||
search_op_types=["SkipLayerNormalization"],
|
||||
)
|
||||
|
||||
def get_num_heads_and_hidden_size(self, reshape_q: NodeProto) -> Tuple[int, int]:
|
||||
"""Detect num_heads and hidden_size for ONNX model from MiDaS
|
||||
Args:
|
||||
reshape_q (NodeProto): reshape node for q
|
||||
Returns:
|
||||
Tuple[int, int]: num_heads and hidden_size
|
||||
"""
|
||||
concat = self.model.match_parent(reshape_q, "Concat", 1)
|
||||
if concat is None or len(concat.input) != 4:
|
||||
return self.num_heads, self.hidden_size
|
||||
|
||||
# The shape is a tensor like [?, ?, num_heads, head_size]
|
||||
num_head_value = self.model.get_constant_value(concat.input[2])
|
||||
if num_head_value is None:
|
||||
return self.num_heads, self.hidden_size # Fall back to user specified value
|
||||
|
||||
if len(num_head_value) != 1 or num_head_value[0] <= 0:
|
||||
return self.num_heads, self.hidden_size # Fall back to user specified value
|
||||
|
||||
num_heads = num_head_value[0]
|
||||
|
||||
head_size_value = self.model.get_constant_value(concat.input[3])
|
||||
if head_size_value is None:
|
||||
return self.num_heads, self.hidden_size # Fall back to user specified value
|
||||
|
||||
if len(head_size_value) != 1 or head_size_value[0] <= 0:
|
||||
return self.num_heads, self.hidden_size # Fall back to user specified value
|
||||
|
||||
head_size = head_size_value[0]
|
||||
|
||||
hidden_size = num_heads * head_size
|
||||
|
||||
if self.num_heads > 0 and num_heads != self.num_heads:
|
||||
if self.num_heads_warning:
|
||||
logger.warning(f"--num_heads is {self.num_heads}. Detected value is {num_heads}. Using detected value.")
|
||||
self.num_heads_warning = False # Do not show the warning more than once
|
||||
|
||||
if self.hidden_size > 0 and hidden_size != self.hidden_size:
|
||||
if self.hidden_size_warning:
|
||||
logger.warning(
|
||||
f"--hidden_size is {self.hidden_size}. Detected value is {hidden_size}. Using detected value."
|
||||
)
|
||||
self.hidden_size_warning = False # Do not show the warning more than once
|
||||
|
||||
return num_heads, hidden_size
|
||||
|
||||
def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
|
||||
skip_input_index = None
|
||||
node_before_layer_norm = None
|
||||
for i in [1, 0]:
|
||||
parent = self.model.match_parent(normalize_node, "SkipLayerNormalization", i)
|
||||
if parent is not None:
|
||||
skip_input_index = i
|
||||
node_before_layer_norm = parent
|
||||
|
||||
root_input = None
|
||||
if node_before_layer_norm is not None:
|
||||
root_input = node_before_layer_norm.output[0]
|
||||
else:
|
||||
# Deal with the first attention after the embedding layer.
|
||||
for i in [0, 1]:
|
||||
node_before_layer_norm = self.model.match_parent(normalize_node, "Add", i)
|
||||
if node_before_layer_norm is None:
|
||||
continue
|
||||
child = self.model.find_first_child_by_type(
|
||||
node_before_layer_norm, "LayerNormalization", input_name_to_nodes, False
|
||||
)
|
||||
if child is None:
|
||||
continue
|
||||
root_input = child.output[0]
|
||||
skip_input_index = i
|
||||
break
|
||||
|
||||
if skip_input_index is None:
|
||||
return
|
||||
|
||||
qkv_nodes = self.model.match_parent_path(
|
||||
normalize_node,
|
||||
["Add", "MatMul", "Reshape", "Transpose", "Reshape", "MatMul"],
|
||||
[1 - skip_input_index, None, None, 0, 0, 0],
|
||||
)
|
||||
if qkv_nodes is None:
|
||||
return
|
||||
|
||||
(_, _, reshape_qkv, transpose_qkv, _, matmul_qkv) = qkv_nodes
|
||||
|
||||
v_nodes = self.model.match_parent_path(
|
||||
matmul_qkv, ["Reshape", "Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, 0, None]
|
||||
)
|
||||
if v_nodes is None:
|
||||
logger.debug("fuse_attention: failed to match v path")
|
||||
return
|
||||
(_, _, reshape_v, add_v, matmul_v) = v_nodes
|
||||
|
||||
add_mask_indices = []
|
||||
qk_nodes = self.model.match_parent_path(
|
||||
matmul_qkv,
|
||||
["Softmax", "Reshape", "Add", "Reshape", "MatMul"],
|
||||
[0, 0, 0, None, 0],
|
||||
return_indice=add_mask_indices,
|
||||
)
|
||||
if qk_nodes is None:
|
||||
logger.debug("fuse_attention: failed to match qk path")
|
||||
return
|
||||
assert len(add_mask_indices) == 1
|
||||
causal_mask_input_index = 1 - add_mask_indices[0]
|
||||
|
||||
(_softmax_qk, _, add_mask, _, matmul_qk) = qk_nodes
|
||||
|
||||
q_nodes = self.model.match_parent_path(
|
||||
matmul_qk, ["Reshape", "Transpose", "Reshape", "Mul", "Add", "MatMul"], [0, 0, 0, 0, None, None]
|
||||
)
|
||||
if q_nodes is None:
|
||||
logger.debug("fuse_attention: failed to match q path")
|
||||
return
|
||||
(_, _transpose_q, reshape_q, mul_q, add_q, matmul_q) = q_nodes
|
||||
|
||||
k_nodes = self.model.match_parent_path(
|
||||
matmul_qk, ["Transpose", "Reshape", "Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, 0, 0, None]
|
||||
)
|
||||
if k_nodes is None:
|
||||
logger.debug("fuse_attention: failed to match k path")
|
||||
return
|
||||
|
||||
(_transpose_k, _reshape_k, _, _, add_k, matmul_k) = k_nodes
|
||||
if matmul_q.input[0] != root_input or matmul_k.input[0] != root_input or matmul_v.input[0] != root_input:
|
||||
logger.debug("fuse_attention: expect to have same input to q, k and v matmul")
|
||||
return
|
||||
|
||||
num_heads, hidden_size = self.get_num_heads_and_hidden_size(reshape_q)
|
||||
if num_heads <= 0 or hidden_size <= 0:
|
||||
logger.debug("fuse_attention: failed to detect num_heads or hidden_size")
|
||||
return
|
||||
|
||||
attention_last_node = reshape_qkv
|
||||
|
||||
# Here we do not match the whole subgraph since it is very complex. Instead, we just check whether a key path
|
||||
# of computing causal mask.
|
||||
causal_mask_nodes = self.model.match_parent_path(
|
||||
add_mask,
|
||||
["Concat", "Expand", "Unsqueeze", "Unsqueeze", "Where", "Less"],
|
||||
[causal_mask_input_index, 0, 0, 0, 0, 0],
|
||||
)
|
||||
if causal_mask_nodes is None:
|
||||
# If the model is exported with batch_size == 1, there is no Concat node
|
||||
causal_mask_nodes = self.model.match_parent_path(
|
||||
add_mask,
|
||||
["Expand", "Unsqueeze", "Unsqueeze", "Where", "Less"],
|
||||
[causal_mask_input_index, 0, 0, 0, 0],
|
||||
)
|
||||
if causal_mask_nodes is None:
|
||||
logger.debug("fuse_attention: failed to match causal mask subgraph")
|
||||
return
|
||||
|
||||
new_node = self.create_attention_node(
|
||||
mask_index=None,
|
||||
q_matmul=matmul_q,
|
||||
k_matmul=matmul_k,
|
||||
v_matmul=matmul_v,
|
||||
q_add=add_q,
|
||||
k_add=add_k,
|
||||
v_add=add_v,
|
||||
num_heads=num_heads,
|
||||
hidden_size=hidden_size,
|
||||
input=root_input,
|
||||
output=attention_last_node.output[0],
|
||||
add_qk_str=None,
|
||||
scale=None,
|
||||
causal=True,
|
||||
)
|
||||
if new_node is None:
|
||||
return
|
||||
|
||||
self.nodes_to_add.append(new_node)
|
||||
self.node_name_to_graph_name[new_node.name] = self.this_graph_name
|
||||
|
||||
self.nodes_to_remove.extend([attention_last_node, transpose_qkv])
|
||||
|
||||
# Use prune graph to remove nodes since they are shared by all attention nodes.
|
||||
self.prune_graph = True
|
||||
|
|
@ -45,6 +45,9 @@ class FusionOptions:
|
|||
self.enable_gemm_fast_gelu = False
|
||||
self.group_norm_channels_last = True
|
||||
|
||||
if model_type == "clip":
|
||||
self.enable_embed_layer_norm = False
|
||||
|
||||
# Set default to sequence length for BERT model to use fused attention to speed up.
|
||||
# Note that embed layer normalization will convert 2D mask to 1D when mask type is MaskIndexEnd.
|
||||
self.attention_mask_format = AttentionMaskFormat.AttentionMask
|
||||
|
|
|
|||
|
|
@ -384,8 +384,7 @@ Some kernels are enabled by MIOpen. We hereby thank for the AMD developers' coll
|
|||
|
||||
There are other optimizations might improve the performance or reduce memory footprint:
|
||||
* Export the whole pipeline into a single ONNX model. Currently, there are multiple ONNX models (CLIP, VAE and U-Net etc). Each model uses separated thread pool and memory allocator. Combine them into one model could share thread pool and memory allocator. The end result is more efficient and less memory footprint.
|
||||
* For Stable Diffusion 2.1, we disable TensorRT flash attention kernel and use only memory efficient attention. It is possible to add flash attention using Triton compiler to improve performance.
|
||||
* For Stable Diffusion 2.1, we disable TensorRT flash attention kernel and use only memory efficient attention. It is possible to add flash attention in Windows to improve performance.
|
||||
* Reduce GPU memory footprint by actively deleting buffers for intermediate results.
|
||||
* Attention fusion in CLIP
|
||||
* Safety Checker Optimization
|
||||
* Leverage FP8 in latest GPU
|
||||
|
|
|
|||
|
|
@ -5,15 +5,17 @@
|
|||
|
||||
from logging import getLogger
|
||||
|
||||
from fusion_attention_clip import FusionAttentionClip
|
||||
from onnx import ModelProto
|
||||
from onnx_model_unet import UnetOnnxModel
|
||||
from onnx_model_bert import BertOnnxModel
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
class ClipOnnxModel(UnetOnnxModel):
|
||||
class ClipOnnxModel(BertOnnxModel):
|
||||
def __init__(self, model: ModelProto, num_heads: int = 0, hidden_size: int = 0):
|
||||
super().__init__(model, num_heads=num_heads, hidden_size=hidden_size)
|
||||
self.clip_attention_fusion = FusionAttentionClip(self, self.hidden_size, self.num_heads)
|
||||
|
||||
def get_fused_operator_statistics(self):
|
||||
"""
|
||||
|
|
@ -31,3 +33,6 @@ class ClipOnnxModel(UnetOnnxModel):
|
|||
|
||||
logger.info(f"Optimized operators:{op_count}")
|
||||
return op_count
|
||||
|
||||
def fuse_attention(self):
|
||||
self.clip_attention_fusion.apply()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,156 @@
|
|||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for
|
||||
# license information.
|
||||
# --------------------------------------------------------------------------
|
||||
import os
|
||||
import shutil
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
from parity_utilities import find_transformers_source
|
||||
|
||||
if find_transformers_source():
|
||||
from compare_bert_results import run_test
|
||||
from fusion_options import FusionOptions
|
||||
from optimizer import optimize_model
|
||||
else:
|
||||
from onnxruntime.transformers.compare_bert_results import run_test
|
||||
from onnxruntime.transformers.fusion_options import FusionOptions
|
||||
from onnxruntime.transformers.optimizer import optimize_model
|
||||
|
||||
TINY_MODELS = {
|
||||
"stable-diffusion": "hf-internal-testing/tiny-stable-diffusion-torch",
|
||||
"stable-diffusion-xl": "echarlaix/tiny-random-stable-diffusion-xl",
|
||||
}
|
||||
|
||||
|
||||
class TestStableDiffusionOptimization(unittest.TestCase):
|
||||
def verify_node_count(self, onnx_model, expected_node_count, test_name):
|
||||
for op_type, count in expected_node_count.items():
|
||||
if len(onnx_model.get_nodes_by_op_type(op_type)) != count:
|
||||
print(f"Counters is not expected in test: {test_name}")
|
||||
for op, counter in expected_node_count.items():
|
||||
print(f"{op}: {len(onnx_model.get_nodes_by_op_type(op))} expected={counter}")
|
||||
|
||||
self.assertEqual(len(onnx_model.get_nodes_by_op_type(op_type)), count)
|
||||
|
||||
def verify_clip_optimizer(self, clip_onnx_path, optimized_clip_onnx_path, expected_counters, float16=False):
|
||||
fusion_options = FusionOptions("clip")
|
||||
m = optimize_model(
|
||||
clip_onnx_path,
|
||||
model_type="clip",
|
||||
num_heads=0,
|
||||
hidden_size=0,
|
||||
opt_level=0,
|
||||
optimization_options=fusion_options,
|
||||
use_gpu=True,
|
||||
)
|
||||
self.verify_node_count(m, expected_counters, "test_clip")
|
||||
|
||||
if float16:
|
||||
m.convert_float_to_float16(
|
||||
keep_io_types=True,
|
||||
)
|
||||
print(m.get_operator_statistics())
|
||||
m.save_model_to_file(optimized_clip_onnx_path)
|
||||
|
||||
threshold = 1e-2 if float16 else 3e-3
|
||||
max_abs_diff, passed = run_test(
|
||||
clip_onnx_path,
|
||||
optimized_clip_onnx_path,
|
||||
output_dir=None,
|
||||
batch_size=1,
|
||||
sequence_length=77,
|
||||
use_gpu=True,
|
||||
test_cases=10,
|
||||
seed=1,
|
||||
verbose=False,
|
||||
rtol=1e-1,
|
||||
atol=threshold,
|
||||
input_ids_name="input_ids",
|
||||
segment_ids_name=None,
|
||||
input_mask_name=None,
|
||||
mask_type=0,
|
||||
)
|
||||
|
||||
self.assertLess(max_abs_diff, threshold)
|
||||
self.assertTrue(passed)
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_clip_sd(self):
|
||||
save_directory = "tiny-random-stable-diffusion"
|
||||
if os.path.exists(save_directory):
|
||||
shutil.rmtree(save_directory, ignore_errors=True)
|
||||
|
||||
model_type = "stable-diffusion"
|
||||
model_name = TINY_MODELS[model_type]
|
||||
|
||||
from optimum.onnxruntime import ORTStableDiffusionPipeline
|
||||
|
||||
base = ORTStableDiffusionPipeline.from_pretrained(model_name, export=True)
|
||||
base.save_pretrained(save_directory)
|
||||
|
||||
clip_onnx_path = os.path.join(save_directory, "text_encoder", "model.onnx")
|
||||
optimized_clip_onnx_path = os.path.join(save_directory, "text_encoder", "opt.onnx")
|
||||
self.verify_clip_optimizer(
|
||||
clip_onnx_path,
|
||||
optimized_clip_onnx_path,
|
||||
expected_counters={
|
||||
"EmbedLayerNormalization": 0,
|
||||
"Attention": 5,
|
||||
"SkipLayerNormalization": 10,
|
||||
"LayerNormalization": 1,
|
||||
"Gelu": 0,
|
||||
"BiasGelu": 0,
|
||||
},
|
||||
float16=True,
|
||||
)
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_clip_sdxl(self):
|
||||
save_directory = "tiny-random-stable-diffusion-xl"
|
||||
if os.path.exists(save_directory):
|
||||
shutil.rmtree(save_directory, ignore_errors=True)
|
||||
|
||||
model_type = "stable-diffusion-xl"
|
||||
model_name = TINY_MODELS[model_type]
|
||||
|
||||
from optimum.onnxruntime import ORTStableDiffusionXLPipeline
|
||||
|
||||
base = ORTStableDiffusionXLPipeline.from_pretrained(model_name, export=True)
|
||||
base.save_pretrained(save_directory)
|
||||
|
||||
clip_onnx_path = os.path.join(save_directory, "text_encoder", "model.onnx")
|
||||
optimized_clip_onnx_path = os.path.join(save_directory, "text_encoder", "opt.onnx")
|
||||
self.verify_clip_optimizer(
|
||||
clip_onnx_path,
|
||||
optimized_clip_onnx_path,
|
||||
expected_counters={
|
||||
"EmbedLayerNormalization": 0,
|
||||
"Attention": 5,
|
||||
"SkipLayerNormalization": 10,
|
||||
"LayerNormalization": 1,
|
||||
"Gelu": 0,
|
||||
"BiasGelu": 5,
|
||||
},
|
||||
)
|
||||
|
||||
clip_onnx_path = os.path.join(save_directory, "text_encoder_2", "model.onnx")
|
||||
optimized_clip_onnx_path = os.path.join(save_directory, "text_encoder_2", "opt.onnx")
|
||||
self.verify_clip_optimizer(
|
||||
clip_onnx_path,
|
||||
optimized_clip_onnx_path,
|
||||
expected_counters={
|
||||
"EmbedLayerNormalization": 0,
|
||||
"Attention": 5,
|
||||
"SkipLayerNormalization": 10,
|
||||
"LayerNormalization": 1,
|
||||
"Gelu": 0,
|
||||
"BiasGelu": 5,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Loading…
Reference in a new issue