From 207c57219abe618bc82ee86ce856dc5abdbc1ad2 Mon Sep 17 00:00:00 2001 From: Anton Korablin <39565513+iAnt0n@users.noreply.github.com> Date: Tue, 4 Apr 2023 23:05:24 +0200 Subject: [PATCH] Add support for full ViT optimization (#15289) Add support for ViT optimization in optimizer.py As ViT architecture follows BERT rather closely, we can easily reuse BERT fusions for ViT. The only difference is that ViT does not have attention mask, which means there is no Add node in qk paths. Make the necessary changes in onnx_exporter.py to be able to cover optimizations with test. --- .../tools/transformers/fusion_attention.py | 31 +++++++++++++++++-- .../tools/transformers/fusion_options.py | 8 +++-- .../tools/transformers/huggingface_models.py | 1 + .../tools/transformers/onnx_exporter.py | 31 ++++++++++++++----- .../python/tools/transformers/optimizer.py | 1 + .../tools/transformers/requirements.txt | 1 + .../python/transformers/test_optimizer.py | 4 +++ 7 files changed, 63 insertions(+), 14 deletions(-) diff --git a/onnxruntime/python/tools/transformers/fusion_attention.py b/onnxruntime/python/tools/transformers/fusion_attention.py index 093d3c4304..47af945509 100644 --- a/onnxruntime/python/tools/transformers/fusion_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_attention.py @@ -492,6 +492,23 @@ class FusionAttention(Fusion): if child.op_type == "LayerNormalization": root_input = child.output[0] + """ + When Add before the LayerNormalization produces an output + that is consumed by some other nodes other than the LayerNormalization itself, + fused SkipLayerNormalization will have several outputs. + In this case we need to pick the one used in Attention + + For example, this is the case for ViT + + SkipLayerNormalization --> Attention --> MatMul --> Add --> SkipLayerNormalization + | | + | | + +---------------------------------------------------------------------+ + """ + parent_node = output_name_to_node[root_input] + if parent_node.op_type == "SkipLayerNormalization" and len(parent_node.output) == 4: + root_input = parent_node.output[0] + children = input_name_to_nodes[root_input] children_types = [child.op_type for child in children] if children_types.count("MatMul") != 3: @@ -505,11 +522,13 @@ class FusionAttention(Fusion): is_distill = False is_distill_add = False + is_no_mask_attention = False qk_paths = { "path1": (["Softmax", "Add", "Div", "MatMul"], [0, 0, None, 0]), "path2": (["Softmax", "Add", "Mul", "MatMul"], [0, 0, None, 0]), "path3": (["Softmax", "Where", "MatMul", "Div"], [0, 0, 2, 0]), "path4": (["Softmax", "Add", "Where", "MatMul"], [0, 0, 0, 2]), + "path5": (["Softmax", "Div", "MatMul"], [0, 0, 0]), } qk_nodes = None @@ -521,6 +540,8 @@ class FusionAttention(Fusion): is_distill = True if k == "path4": is_distill_add = True + if k == "path5": + is_no_mask_attention = True break if qk_nodes is None: @@ -534,6 +555,8 @@ class FusionAttention(Fusion): (_, where_qk, matmul_qk, _) = qk_nodes elif is_distill_add: (_, add_qk, where_qk, matmul_qk) = qk_nodes + elif is_no_mask_attention: + (_, _, matmul_qk) = qk_nodes else: (_, add_qk, _, matmul_qk) = qk_nodes @@ -591,6 +614,8 @@ class FusionAttention(Fusion): if add_qk_str is None: logger.debug(f"fuse_attention: failed to verify shape inference of {add_qk}") return + elif is_no_mask_attention: + pass else: _, mask_nodes, _ = self.model.match_parent_paths( add_qk, @@ -603,17 +628,17 @@ class FusionAttention(Fusion): ], output_name_to_node, ) - if mask_nodes is None: + if not is_no_mask_attention and mask_nodes is None: logger.debug("fuse_attention: failed to match mask path") return - if len(mask_nodes) > 1 and mask_nodes[0].op_type == "Mul": + if not is_no_mask_attention and len(mask_nodes) > 1 and mask_nodes[0].op_type == "Mul": _, mul_val = self.model.get_constant_input(mask_nodes[0]) if mul_val != -10000: self.mask_filter_value = mul_val if matmul_v.input[0] == root_input and matmul_q.input[0] == root_input and matmul_k.input[0] == root_input: - mask_index = self.attention_mask.process_mask(mask_nodes[-1].input[0]) + mask_index = self.attention_mask.process_mask(mask_nodes[-1].input[0]) if not is_no_mask_attention else None attention_last_node = reshape_qkv if einsum_node is None else transpose_qkv diff --git a/onnxruntime/python/tools/transformers/fusion_options.py b/onnxruntime/python/tools/transformers/fusion_options.py index bbc0b731ae..dfcedfe691 100644 --- a/onnxruntime/python/tools/transformers/fusion_options.py +++ b/onnxruntime/python/tools/transformers/fusion_options.py @@ -46,9 +46,11 @@ class FusionOptions: # Set default to sequence length for BERT model to use fused attention to speed up. # Note that embed layer normalization will convert 2D mask to 1D when mask type is MaskIndexEnd. - self.attention_mask_format = ( - AttentionMaskFormat.MaskIndexEnd if model_type == "bert" else AttentionMaskFormat.AttentionMask - ) + self.attention_mask_format = AttentionMaskFormat.AttentionMask + if model_type == "bert": + self.attention_mask_format = AttentionMaskFormat.MaskIndexEnd + elif model_type == "vit": + self.attention_mask_format = AttentionMaskFormat.NoMask # options for stable diffusion if model_type in ["unet", "vae", "clip"]: diff --git a/onnxruntime/python/tools/transformers/huggingface_models.py b/onnxruntime/python/tools/transformers/huggingface_models.py index cdf75efb1e..9f766ccee3 100644 --- a/onnxruntime/python/tools/transformers/huggingface_models.py +++ b/onnxruntime/python/tools/transformers/huggingface_models.py @@ -156,6 +156,7 @@ MODELS = { False, "bert", ), + "google/vit-base-patch16-224": (["pixel_values"], 12, False, "vit"), # "google/pegasus-xsum": (["input_ids"], 11, False, "bert"), # "google/pegasus-large": (["input_ids"], 11, False, "bert"), } diff --git a/onnxruntime/python/tools/transformers/onnx_exporter.py b/onnxruntime/python/tools/transformers/onnx_exporter.py index 801fdb080e..60c2dda992 100644 --- a/onnxruntime/python/tools/transformers/onnx_exporter.py +++ b/onnxruntime/python/tools/transformers/onnx_exporter.py @@ -16,7 +16,7 @@ from benchmark_helper import OptimizerInfo, Precision, create_onnxruntime_sessio from huggingface_models import MODEL_CLASSES from quantize_helper import QuantizeHelper from torch_onnx_export_helper import torch_onnx_export -from transformers import AutoConfig, AutoTokenizer, LxmertConfig, TransfoXLConfig +from transformers import AutoConfig, AutoFeatureExtractor, AutoTokenizer, LxmertConfig, TransfoXLConfig sys.path.append(os.path.join(os.path.dirname(__file__), "models", "gpt2")) from gpt2_helper import PRETRAINED_GPT2_MODELS, GPT2ModelNoPastState, TFGPT2ModelNoPastState # noqa: E402 @@ -439,7 +439,7 @@ def validate_and_optimize_onnx( model_fusion_statistics, ) - return onnx_model_path, is_valid_onnx_model, config.vocab_size + return onnx_model_path, is_valid_onnx_model, None if model_type == "vit" else config.vocab_size def export_onnx_model_from_pt( @@ -465,12 +465,21 @@ def export_onnx_model_from_pt( # config, model = load_pt_model_from_tf(model_name) model.cpu() - tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir) - max_input_size = ( - tokenizer.max_model_input_sizes[model_name] if model_name in tokenizer.max_model_input_sizes else 1024 - ) + example_inputs = None + max_input_size = None - example_inputs = tokenizer.encode_plus("This is a sample input", return_tensors="pt") + if model_type == "vit": + image_processor = AutoFeatureExtractor.from_pretrained(model_name, cache_dir=cache_dir) + data = numpy.random.randint(low=0, high=256, size=224 * 224 * 3, dtype=numpy.uint8).reshape(224, 224, 3) + + example_inputs = image_processor(data, return_tensors="pt") + else: + tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir) + max_input_size = ( + tokenizer.max_model_input_sizes[model_name] if model_name in tokenizer.max_model_input_sizes else 1024 + ) + + example_inputs = tokenizer.encode_plus("This is a sample input", return_tensors="pt") example_inputs = filter_inputs(example_inputs, input_names) @@ -497,7 +506,13 @@ def export_onnx_model_from_pt( logger.info(f"Exporting ONNX model to {onnx_model_path}") Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True) - dynamic_axes, output_names = build_dynamic_axes(example_inputs, example_outputs_flatten) + dynamic_axes = None + output_names = None + + if model_type == "vit": + dynamic_axes, output_names = {key: {0: "pixel_values"} for key in example_inputs}, ["logits"] + else: + dynamic_axes, output_names = build_dynamic_axes(example_inputs, example_outputs_flatten) replace_torch_functions() torch_onnx_export( diff --git a/onnxruntime/python/tools/transformers/optimizer.py b/onnxruntime/python/tools/transformers/optimizer.py index a3c16ebcae..99ef58841d 100644 --- a/onnxruntime/python/tools/transformers/optimizer.py +++ b/onnxruntime/python/tools/transformers/optimizer.py @@ -56,6 +56,7 @@ MODEL_TYPES = { "unet": (UnetOnnxModel, "pytorch", 1), "vae": (VaeOnnxModel, "pytorch", 1), "clip": (ClipOnnxModel, "pytorch", 1), + "vit": (BertOnnxModel, "pytorch", 1), } diff --git a/onnxruntime/python/tools/transformers/requirements.txt b/onnxruntime/python/tools/transformers/requirements.txt index b8b7dac9e3..ce1380a757 100644 --- a/onnxruntime/python/tools/transformers/requirements.txt +++ b/onnxruntime/python/tools/transformers/requirements.txt @@ -8,6 +8,7 @@ packaging transformers >= 4.18.0 scipy sentencepiece +pillow # please follow https://pytorch.org/ to install PyTorch for your OS torch >= 1.13.1 diff --git a/onnxruntime/test/python/transformers/test_optimizer.py b/onnxruntime/test/python/transformers/test_optimizer.py index 270c015c83..d1fb88c0d0 100644 --- a/onnxruntime/test/python/transformers/test_optimizer.py +++ b/onnxruntime/test/python/transformers/test_optimizer.py @@ -304,6 +304,10 @@ class TestModelOptimization(unittest.TestCase): def test_huggingface_bart_fusion(self): self._test_optimizer_on_huggingface_model("facebook/bart-base", [0, 0, 0, 0, 12, 2, 30]) + @pytest.mark.slow + def test_huggingface_vit_fusion(self): + self._test_optimizer_on_huggingface_model("google/vit-base-patch16-224", [0, 11, 0, 0, 12, 1, 24]) + @unittest.skipUnless(is_tf_available(), "skip TestBertOptimizationTF since tensorflow is not available") class TestTensorflowModelOptimization(unittest.TestCase):