diff --git a/onnxruntime/python/tools/transformers/fusion_attention.py b/onnxruntime/python/tools/transformers/fusion_attention.py
index 093d3c4304..47af945509 100644
--- a/onnxruntime/python/tools/transformers/fusion_attention.py
+++ b/onnxruntime/python/tools/transformers/fusion_attention.py
@@ -492,6 +492,23 @@ class FusionAttention(Fusion):
                 if child.op_type == "LayerNormalization":
                     root_input = child.output[0]

+        """
+        When the Add before the LayerNormalization produces an output
+        that is consumed by nodes other than the LayerNormalization itself,
+        the fused SkipLayerNormalization will have several outputs.
+        In this case we need to pick the one used in Attention.
+
+        For example, this is the case for ViT:
+
+        SkipLayerNormalization --> Attention --> MatMul --> Add --> SkipLayerNormalization
+         |                                                   |
+         |                                                   |
+         +---------------------------------------------------+
+        """
+        parent_node = output_name_to_node[root_input]
+        if parent_node.op_type == "SkipLayerNormalization" and len(parent_node.output) == 4:
+            root_input = parent_node.output[0]
+
         children = input_name_to_nodes[root_input]
         children_types = [child.op_type for child in children]
         if children_types.count("MatMul") != 3:
@@ -505,11 +522,13 @@ class FusionAttention(Fusion):

         is_distill = False
         is_distill_add = False
+        is_no_mask_attention = False
         qk_paths = {
             "path1": (["Softmax", "Add", "Div", "MatMul"], [0, 0, None, 0]),
             "path2": (["Softmax", "Add", "Mul", "MatMul"], [0, 0, None, 0]),
             "path3": (["Softmax", "Where", "MatMul", "Div"], [0, 0, 2, 0]),
             "path4": (["Softmax", "Add", "Where", "MatMul"], [0, 0, 0, 2]),
+            "path5": (["Softmax", "Div", "MatMul"], [0, 0, 0]),
         }

         qk_nodes = None
@@ -521,6 +540,8 @@ class FusionAttention(Fusion):
                 is_distill = True
             if k == "path4":
                 is_distill_add = True
+            if k == "path5":
+                is_no_mask_attention = True
             break

         if qk_nodes is None:
@@ -534,6 +555,8 @@ class FusionAttention(Fusion):
             (_, where_qk, matmul_qk, _) = qk_nodes
         elif is_distill_add:
             (_, add_qk, where_qk, matmul_qk) = qk_nodes
+        elif is_no_mask_attention:
+            (_, _, matmul_qk) = qk_nodes
         else:
             (_, add_qk, _, matmul_qk) = qk_nodes

@@ -591,6 +614,8 @@ class FusionAttention(Fusion):
             if add_qk_str is None:
                 logger.debug(f"fuse_attention: failed to verify shape inference of {add_qk}")
                 return
+        elif is_no_mask_attention:
+            pass
         else:
             _, mask_nodes, _ = self.model.match_parent_paths(
@@ -603,17 +628,17 @@ class FusionAttention(Fusion):
                 add_qk,
                 [
                     (["Mul", "Sub", "Cast", "Unsqueeze", "Unsqueeze"], [None, 0, 1, 0, 0]),
                     (["Mul", "Sub", "Unsqueeze", "Unsqueeze"], [None, 0, 1, 0]),
                 ],
                 output_name_to_node,
             )
-            if mask_nodes is None:
+            if not is_no_mask_attention and mask_nodes is None:
                 logger.debug("fuse_attention: failed to match mask path")
                 return

-        if len(mask_nodes) > 1 and mask_nodes[0].op_type == "Mul":
+        if not is_no_mask_attention and len(mask_nodes) > 1 and mask_nodes[0].op_type == "Mul":
             _, mul_val = self.model.get_constant_input(mask_nodes[0])
             if mul_val != -10000:
                 self.mask_filter_value = mul_val

         if matmul_v.input[0] == root_input and matmul_q.input[0] == root_input and matmul_k.input[0] == root_input:
-            mask_index = self.attention_mask.process_mask(mask_nodes[-1].input[0])
+            mask_index = self.attention_mask.process_mask(mask_nodes[-1].input[0]) if not is_no_mask_attention else None

             attention_last_node = reshape_qkv if einsum_node is None else transpose_qkv
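For context, the new `path5` pattern matches attention with no mask at all. A minimal sketch of a mask-free attention block whose ONNX export produces the `MatMul -> Div -> Softmax` chain matched above (illustrative only; this module and its names are not taken from the patch or from the HuggingFace ViT source):

```python
import math

import torch


class NoMaskSelfAttention(torch.nn.Module):
    """Mask-free self-attention: exporting forward() to ONNX yields
    MatMul -> Div -> Softmax -> MatMul, the shape of "path5" above."""

    def __init__(self, hidden_size: int, num_heads: int):
        super().__init__()
        self.q_proj = torch.nn.Linear(hidden_size, hidden_size)
        self.k_proj = torch.nn.Linear(hidden_size, hidden_size)
        self.v_proj = torch.nn.Linear(hidden_size, hidden_size)
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch, seq, hidden = x.shape

        def split_heads(t: torch.Tensor) -> torch.Tensor:
            # (batch, seq, hidden) -> (batch, heads, seq, head_dim)
            return t.view(batch, seq, self.num_heads, self.head_dim).transpose(1, 2)

        q = split_heads(self.q_proj(x))
        k = split_heads(self.k_proj(x))
        v = split_heads(self.v_proj(x))

        # No attention mask: the scores flow straight from Div into Softmax,
        # so no Add/Where node separates them in the exported graph.
        scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.head_dim)
        probs = torch.softmax(scores, dim=-1)
        context = torch.matmul(probs, v)
        return context.transpose(1, 2).reshape(batch, seq, hidden)
```

Because there is no mask, the existing mask-matching logic must be skipped whenever `is_no_mask_attention` is set, which is what the guarded conditions above do.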
diff --git a/onnxruntime/python/tools/transformers/fusion_options.py b/onnxruntime/python/tools/transformers/fusion_options.py
index bbc0b731ae..dfcedfe691 100644
--- a/onnxruntime/python/tools/transformers/fusion_options.py
+++ b/onnxruntime/python/tools/transformers/fusion_options.py
@@ -46,9 +46,11 @@ class FusionOptions:

         # Set default to sequence length for BERT model to use fused attention to speed up.
         # Note that embed layer normalization will convert 2D mask to 1D when mask type is MaskIndexEnd.
-        self.attention_mask_format = (
-            AttentionMaskFormat.MaskIndexEnd if model_type == "bert" else AttentionMaskFormat.AttentionMask
-        )
+        self.attention_mask_format = AttentionMaskFormat.AttentionMask
+        if model_type == "bert":
+            self.attention_mask_format = AttentionMaskFormat.MaskIndexEnd
+        elif model_type == "vit":
+            self.attention_mask_format = AttentionMaskFormat.NoMask

         # options for stable diffusion
         if model_type in ["unet", "vae", "clip"]:
diff --git a/onnxruntime/python/tools/transformers/huggingface_models.py b/onnxruntime/python/tools/transformers/huggingface_models.py
index cdf75efb1e..9f766ccee3 100644
--- a/onnxruntime/python/tools/transformers/huggingface_models.py
+++ b/onnxruntime/python/tools/transformers/huggingface_models.py
@@ -156,6 +156,7 @@ MODELS = {
         False,
         "bert",
     ),
+    "google/vit-base-patch16-224": (["pixel_values"], 12, False, "vit"),
     # "google/pegasus-xsum": (["input_ids"], 11, False, "bert"),
     # "google/pegasus-large": (["input_ids"], 11, False, "bert"),
 }
diff --git a/onnxruntime/python/tools/transformers/onnx_exporter.py b/onnxruntime/python/tools/transformers/onnx_exporter.py
index 801fdb080e..60c2dda992 100644
--- a/onnxruntime/python/tools/transformers/onnx_exporter.py
+++ b/onnxruntime/python/tools/transformers/onnx_exporter.py
@@ -16,7 +16,7 @@ from benchmark_helper import OptimizerInfo, Precision, create_onnxruntime_session
 from huggingface_models import MODEL_CLASSES
 from quantize_helper import QuantizeHelper
 from torch_onnx_export_helper import torch_onnx_export
-from transformers import AutoConfig, AutoTokenizer, LxmertConfig, TransfoXLConfig
+from transformers import AutoConfig, AutoFeatureExtractor, AutoTokenizer, LxmertConfig, TransfoXLConfig

 sys.path.append(os.path.join(os.path.dirname(__file__), "models", "gpt2"))
 from gpt2_helper import PRETRAINED_GPT2_MODELS, GPT2ModelNoPastState, TFGPT2ModelNoPastState  # noqa: E402
@@ -439,7 +439,7 @@ def validate_and_optimize_onnx(
             model_fusion_statistics,
         )

-    return onnx_model_path, is_valid_onnx_model, config.vocab_size
+    return onnx_model_path, is_valid_onnx_model, None if model_type == "vit" else config.vocab_size


 def export_onnx_model_from_pt(
@@ -465,12 +465,21 @@ def export_onnx_model_from_pt(
         # config, model = load_pt_model_from_tf(model_name)
     model.cpu()

-    tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
-    max_input_size = (
-        tokenizer.max_model_input_sizes[model_name] if model_name in tokenizer.max_model_input_sizes else 1024
-    )
+    example_inputs = None
+    max_input_size = None

-    example_inputs = tokenizer.encode_plus("This is a sample input", return_tensors="pt")
+    if model_type == "vit":
+        image_processor = AutoFeatureExtractor.from_pretrained(model_name, cache_dir=cache_dir)
+        data = numpy.random.randint(low=0, high=256, size=224 * 224 * 3, dtype=numpy.uint8).reshape(224, 224, 3)
+
+        example_inputs = image_processor(data, return_tensors="pt")
+    else:
+        tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
+        max_input_size = (
+            tokenizer.max_model_input_sizes[model_name] if model_name in tokenizer.max_model_input_sizes else 1024
+        )
+
+        example_inputs = tokenizer.encode_plus("This is a sample input", return_tensors="pt")

     example_inputs = filter_inputs(example_inputs, input_names)

@@ -497,7 +506,13 @@ def export_onnx_model_from_pt(
     logger.info(f"Exporting ONNX model to {onnx_model_path}")
     Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)

-    dynamic_axes, output_names = build_dynamic_axes(example_inputs, example_outputs_flatten)
+    dynamic_axes = None
+    output_names = None
+
+    if model_type == "vit":
+        dynamic_axes, output_names = {key: {0: "pixel_values"} for key in example_inputs}, ["logits"]
+    else:
+        dynamic_axes, output_names = build_dynamic_axes(example_inputs, example_outputs_flatten)

     replace_torch_functions()
     torch_onnx_export(
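Taken together, the exporter changes above reduce, for ViT, to roughly the following standalone flow. This is a sketch assuming the standard `transformers`/`torch` APIs; the output path, axis label, and use of `ViTForImageClassification` are assumptions for illustration, not code from the patch:

```python
import numpy
import torch
from transformers import AutoFeatureExtractor, ViTForImageClassification

model_name = "google/vit-base-patch16-224"
model = ViTForImageClassification.from_pretrained(model_name).eval()
extractor = AutoFeatureExtractor.from_pretrained(model_name)

# Random uint8 HWC image, matching the dummy input built in export_onnx_model_from_pt.
data = numpy.random.randint(0, 256, size=(224, 224, 3), dtype=numpy.uint8)
inputs = extractor(data, return_tensors="pt")

torch.onnx.export(
    model,
    (inputs["pixel_values"],),
    "vit.onnx",  # placeholder output path
    input_names=["pixel_values"],
    output_names=["logits"],
    # Only the batch dimension is dynamic; the patch labels it via dynamic_axes too.
    dynamic_axes={"pixel_values": {0: "batch"}},
    opset_version=12,  # matches the opset in the new MODELS entry
)
```

Note that `validate_and_optimize_onnx` now returns `None` instead of `config.vocab_size` for ViT, since an image model has no vocabulary.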
diff --git a/onnxruntime/python/tools/transformers/optimizer.py b/onnxruntime/python/tools/transformers/optimizer.py
index a3c16ebcae..99ef58841d 100644
--- a/onnxruntime/python/tools/transformers/optimizer.py
+++ b/onnxruntime/python/tools/transformers/optimizer.py
@@ -56,6 +56,7 @@ MODEL_TYPES = {
     "unet": (UnetOnnxModel, "pytorch", 1),
     "vae": (VaeOnnxModel, "pytorch", 1),
     "clip": (ClipOnnxModel, "pytorch", 1),
+    "vit": (BertOnnxModel, "pytorch", 1),
 }
diff --git a/onnxruntime/python/tools/transformers/requirements.txt b/onnxruntime/python/tools/transformers/requirements.txt
index b8b7dac9e3..ce1380a757 100644
--- a/onnxruntime/python/tools/transformers/requirements.txt
+++ b/onnxruntime/python/tools/transformers/requirements.txt
@@ -8,6 +8,7 @@ packaging
 transformers >= 4.18.0
 scipy
 sentencepiece
+pillow
 # please follow https://pytorch.org/ to install PyTorch for your OS
 torch >= 1.13.1
diff --git a/onnxruntime/test/python/transformers/test_optimizer.py b/onnxruntime/test/python/transformers/test_optimizer.py
index 270c015c83..d1fb88c0d0 100644
--- a/onnxruntime/test/python/transformers/test_optimizer.py
+++ b/onnxruntime/test/python/transformers/test_optimizer.py
@@ -304,6 +304,10 @@ class TestModelOptimization(unittest.TestCase):
     def test_huggingface_bart_fusion(self):
         self._test_optimizer_on_huggingface_model("facebook/bart-base", [0, 0, 0, 0, 12, 2, 30])

+    @pytest.mark.slow
+    def test_huggingface_vit_fusion(self):
+        self._test_optimizer_on_huggingface_model("google/vit-base-patch16-224", [0, 11, 0, 0, 12, 1, 24])
+

 @unittest.skipUnless(is_tf_available(), "skip TestBertOptimizationTF since tensorflow is not available")
 class TestTensorflowModelOptimization(unittest.TestCase):
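Finally, with `"vit"` registered in `MODEL_TYPES` (mapped to `BertOnnxModel`, so the BERT fusion passes run on the ViT graph), offline optimization of an exported model would look roughly like this. File paths are placeholders; 12 heads and hidden size 768 are the ViT-Base values:

```python
from onnxruntime.transformers import optimizer

# "vit" reuses the BERT optimizer, so Attention and SkipLayerNormalization
# fusions apply; the new no-mask path handles the absence of an attention mask.
opt_model = optimizer.optimize_model(
    "vit.onnx",  # placeholder input path
    model_type="vit",
    num_heads=12,
    hidden_size=768,
)
opt_model.save_model_to_file("vit_opt.onnx")  # placeholder output path
```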