mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-16 21:00:14 +00:00
Add support for full ViT optimization (#15289)
Add support for ViT optimization in optimizer.py. As the ViT architecture follows BERT rather closely, we can easily reuse the BERT fusions for ViT. The only difference is that ViT does not have an attention mask, which means there is no Add node in the qk paths. Make the necessary changes in onnx_exporter.py so that the optimizations can be covered by a test.
This commit is contained in:
parent
1c1d386561
commit
207c57219a
7 changed files with 63 additions and 14 deletions
|
|
@ -492,6 +492,23 @@ class FusionAttention(Fusion):
|
|||
if child.op_type == "LayerNormalization":
|
||||
root_input = child.output[0]
|
||||
|
||||
"""
|
||||
When Add before the LayerNormalization produces an output
|
||||
that is consumed by some other nodes other than the LayerNormalization itself,
|
||||
fused SkipLayerNormalization will have several outputs.
|
||||
In this case we need to pick the one used in Attention
|
||||
|
||||
For example, this is the case for ViT
|
||||
|
||||
SkipLayerNormalization --> Attention --> MatMul --> Add --> SkipLayerNormalization
|
||||
| |
|
||||
| |
|
||||
+---------------------------------------------------------------------+
|
||||
"""
|
||||
parent_node = output_name_to_node[root_input]
|
||||
if parent_node.op_type == "SkipLayerNormalization" and len(parent_node.output) == 4:
|
||||
root_input = parent_node.output[0]
|
||||
|
||||
children = input_name_to_nodes[root_input]
|
||||
children_types = [child.op_type for child in children]
|
||||
if children_types.count("MatMul") != 3:
|
||||
|
|
@ -505,11 +522,13 @@ class FusionAttention(Fusion):
|
|||
|
||||
is_distill = False
|
||||
is_distill_add = False
|
||||
is_no_mask_attention = False
|
||||
qk_paths = {
|
||||
"path1": (["Softmax", "Add", "Div", "MatMul"], [0, 0, None, 0]),
|
||||
"path2": (["Softmax", "Add", "Mul", "MatMul"], [0, 0, None, 0]),
|
||||
"path3": (["Softmax", "Where", "MatMul", "Div"], [0, 0, 2, 0]),
|
||||
"path4": (["Softmax", "Add", "Where", "MatMul"], [0, 0, 0, 2]),
|
||||
"path5": (["Softmax", "Div", "MatMul"], [0, 0, 0]),
|
||||
}
|
||||
|
||||
qk_nodes = None
|
||||
|
|
@ -521,6 +540,8 @@ class FusionAttention(Fusion):
|
|||
is_distill = True
|
||||
if k == "path4":
|
||||
is_distill_add = True
|
||||
if k == "path5":
|
||||
is_no_mask_attention = True
|
||||
break
|
||||
|
||||
if qk_nodes is None:
|
||||
|
|
@ -534,6 +555,8 @@ class FusionAttention(Fusion):
|
|||
(_, where_qk, matmul_qk, _) = qk_nodes
|
||||
elif is_distill_add:
|
||||
(_, add_qk, where_qk, matmul_qk) = qk_nodes
|
||||
elif is_no_mask_attention:
|
||||
(_, _, matmul_qk) = qk_nodes
|
||||
else:
|
||||
(_, add_qk, _, matmul_qk) = qk_nodes
|
||||
|
||||
|
|
@ -591,6 +614,8 @@ class FusionAttention(Fusion):
|
|||
if add_qk_str is None:
|
||||
logger.debug(f"fuse_attention: failed to verify shape inference of {add_qk}")
|
||||
return
|
||||
elif is_no_mask_attention:
|
||||
pass
|
||||
else:
|
||||
_, mask_nodes, _ = self.model.match_parent_paths(
|
||||
add_qk,
|
||||
|
|
@ -603,17 +628,17 @@ class FusionAttention(Fusion):
|
|||
],
|
||||
output_name_to_node,
|
||||
)
|
||||
if mask_nodes is None:
|
||||
if not is_no_mask_attention and mask_nodes is None:
|
||||
logger.debug("fuse_attention: failed to match mask path")
|
||||
return
|
||||
|
||||
if len(mask_nodes) > 1 and mask_nodes[0].op_type == "Mul":
|
||||
if not is_no_mask_attention and len(mask_nodes) > 1 and mask_nodes[0].op_type == "Mul":
|
||||
_, mul_val = self.model.get_constant_input(mask_nodes[0])
|
||||
if mul_val != -10000:
|
||||
self.mask_filter_value = mul_val
|
||||
|
||||
if matmul_v.input[0] == root_input and matmul_q.input[0] == root_input and matmul_k.input[0] == root_input:
|
||||
mask_index = self.attention_mask.process_mask(mask_nodes[-1].input[0])
|
||||
mask_index = self.attention_mask.process_mask(mask_nodes[-1].input[0]) if not is_no_mask_attention else None
|
||||
|
||||
attention_last_node = reshape_qkv if einsum_node is None else transpose_qkv
|
||||
|
||||
|
|
|
|||
|
|
@ -46,9 +46,11 @@ class FusionOptions:
|
|||
|
||||
# Set default to sequence length for BERT model to use fused attention to speed up.
|
||||
# Note that embed layer normalization will convert 2D mask to 1D when mask type is MaskIndexEnd.
|
||||
self.attention_mask_format = (
|
||||
AttentionMaskFormat.MaskIndexEnd if model_type == "bert" else AttentionMaskFormat.AttentionMask
|
||||
)
|
||||
self.attention_mask_format = AttentionMaskFormat.AttentionMask
|
||||
if model_type == "bert":
|
||||
self.attention_mask_format = AttentionMaskFormat.MaskIndexEnd
|
||||
elif model_type == "vit":
|
||||
self.attention_mask_format = AttentionMaskFormat.NoMask
|
||||
|
||||
# options for stable diffusion
|
||||
if model_type in ["unet", "vae", "clip"]:
|
||||
|
|
|
|||
|
|
@ -156,6 +156,7 @@ MODELS = {
|
|||
False,
|
||||
"bert",
|
||||
),
|
||||
"google/vit-base-patch16-224": (["pixel_values"], 12, False, "vit"),
|
||||
# "google/pegasus-xsum": (["input_ids"], 11, False, "bert"),
|
||||
# "google/pegasus-large": (["input_ids"], 11, False, "bert"),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ from benchmark_helper import OptimizerInfo, Precision, create_onnxruntime_sessio
|
|||
from huggingface_models import MODEL_CLASSES
|
||||
from quantize_helper import QuantizeHelper
|
||||
from torch_onnx_export_helper import torch_onnx_export
|
||||
from transformers import AutoConfig, AutoTokenizer, LxmertConfig, TransfoXLConfig
|
||||
from transformers import AutoConfig, AutoFeatureExtractor, AutoTokenizer, LxmertConfig, TransfoXLConfig
|
||||
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), "models", "gpt2"))
|
||||
from gpt2_helper import PRETRAINED_GPT2_MODELS, GPT2ModelNoPastState, TFGPT2ModelNoPastState # noqa: E402
|
||||
|
|
@ -439,7 +439,7 @@ def validate_and_optimize_onnx(
|
|||
model_fusion_statistics,
|
||||
)
|
||||
|
||||
return onnx_model_path, is_valid_onnx_model, config.vocab_size
|
||||
return onnx_model_path, is_valid_onnx_model, None if model_type == "vit" else config.vocab_size
|
||||
|
||||
|
||||
def export_onnx_model_from_pt(
|
||||
|
|
@ -465,12 +465,21 @@ def export_onnx_model_from_pt(
|
|||
# config, model = load_pt_model_from_tf(model_name)
|
||||
model.cpu()
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
max_input_size = (
|
||||
tokenizer.max_model_input_sizes[model_name] if model_name in tokenizer.max_model_input_sizes else 1024
|
||||
)
|
||||
example_inputs = None
|
||||
max_input_size = None
|
||||
|
||||
example_inputs = tokenizer.encode_plus("This is a sample input", return_tensors="pt")
|
||||
if model_type == "vit":
|
||||
image_processor = AutoFeatureExtractor.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
data = numpy.random.randint(low=0, high=256, size=224 * 224 * 3, dtype=numpy.uint8).reshape(224, 224, 3)
|
||||
|
||||
example_inputs = image_processor(data, return_tensors="pt")
|
||||
else:
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
max_input_size = (
|
||||
tokenizer.max_model_input_sizes[model_name] if model_name in tokenizer.max_model_input_sizes else 1024
|
||||
)
|
||||
|
||||
example_inputs = tokenizer.encode_plus("This is a sample input", return_tensors="pt")
|
||||
|
||||
example_inputs = filter_inputs(example_inputs, input_names)
|
||||
|
||||
|
|
@ -497,7 +506,13 @@ def export_onnx_model_from_pt(
|
|||
logger.info(f"Exporting ONNX model to {onnx_model_path}")
|
||||
Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
dynamic_axes, output_names = build_dynamic_axes(example_inputs, example_outputs_flatten)
|
||||
dynamic_axes = None
|
||||
output_names = None
|
||||
|
||||
if model_type == "vit":
|
||||
dynamic_axes, output_names = {key: {0: "pixel_values"} for key in example_inputs}, ["logits"]
|
||||
else:
|
||||
dynamic_axes, output_names = build_dynamic_axes(example_inputs, example_outputs_flatten)
|
||||
|
||||
replace_torch_functions()
|
||||
torch_onnx_export(
|
||||
|
|
|
|||
|
|
@ -56,6 +56,7 @@ MODEL_TYPES = {
|
|||
"unet": (UnetOnnxModel, "pytorch", 1),
|
||||
"vae": (VaeOnnxModel, "pytorch", 1),
|
||||
"clip": (ClipOnnxModel, "pytorch", 1),
|
||||
"vit": (BertOnnxModel, "pytorch", 1),
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ packaging
|
|||
transformers >= 4.18.0
|
||||
scipy
|
||||
sentencepiece
|
||||
pillow
|
||||
|
||||
# please follow https://pytorch.org/ to install PyTorch for your OS
|
||||
torch >= 1.13.1
|
||||
|
|
|
|||
|
|
@ -304,6 +304,10 @@ class TestModelOptimization(unittest.TestCase):
|
|||
def test_huggingface_bart_fusion(self):
|
||||
self._test_optimizer_on_huggingface_model("facebook/bart-base", [0, 0, 0, 0, 12, 2, 30])
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_huggingface_vit_fusion(self):
|
||||
self._test_optimizer_on_huggingface_model("google/vit-base-patch16-224", [0, 11, 0, 0, 12, 1, 24])
|
||||
|
||||
|
||||
@unittest.skipUnless(is_tf_available(), "skip TestBertOptimizationTF since tensorflow is not available")
|
||||
class TestTensorflowModelOptimization(unittest.TestCase):
|
||||
|
|
|
|||
Loading…
Reference in a new issue