Add support for full ViT optimization (#15289)

Add support for ViT optimization in optimizer.py.

As the ViT architecture follows BERT rather closely, we can easily reuse
the BERT fusions for ViT. The only difference is that ViT has no
attention mask, which means there is no Add node in the QK paths.

Also make the necessary changes in onnx_exporter.py so the new
optimizations are covered by a test.
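
For context, this is roughly how the new model type is exercised once a ViT ONNX file exists (a minimal sketch, not part of this diff; the file paths and the ViT-base sizes are illustrative):

# Minimal sketch, assuming a ViT model has already been exported to vit.onnx.
# num_heads/hidden_size are the ViT-base values; adjust for other checkpoints.
from onnxruntime.transformers.optimizer import optimize_model

opt_model = optimize_model("vit.onnx", model_type="vit", num_heads=12, hidden_size=768)
opt_model.save_model_to_file("vit_opt.onnx")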
Anton Korablin, 2023-04-04 23:05:24 +02:00, committed by GitHub
parent 1c1d386561, commit 207c57219a
7 changed files with 63 additions and 14 deletions

fusion_attention.py

@@ -492,6 +492,23 @@ class FusionAttention(Fusion):
             if child.op_type == "LayerNormalization":
                 root_input = child.output[0]
 
+        """
+        When the Add before the LayerNormalization produces an output
+        that is consumed by nodes other than the LayerNormalization itself,
+        the fused SkipLayerNormalization will have several outputs.
+        In this case we need to pick the one used in Attention.
+        For example, this is the case for ViT:
+
+        SkipLayerNormalization --> Attention --> MatMul --> Add --> SkipLayerNormalization
+        |                                                           |
+        |                                                           |
+        +-----------------------------------------------------------+
+        """
+        parent_node = output_name_to_node[root_input]
+        if parent_node.op_type == "SkipLayerNormalization" and len(parent_node.output) == 4:
+            root_input = parent_node.output[0]
+
         children = input_name_to_nodes[root_input]
         children_types = [child.op_type for child in children]
         if children_types.count("MatMul") != 3:
@@ -505,11 +522,13 @@ class FusionAttention(Fusion):
         is_distill = False
         is_distill_add = False
+        is_no_mask_attention = False
         qk_paths = {
             "path1": (["Softmax", "Add", "Div", "MatMul"], [0, 0, None, 0]),
             "path2": (["Softmax", "Add", "Mul", "MatMul"], [0, 0, None, 0]),
             "path3": (["Softmax", "Where", "MatMul", "Div"], [0, 0, 2, 0]),
             "path4": (["Softmax", "Add", "Where", "MatMul"], [0, 0, 0, 2]),
+            "path5": (["Softmax", "Div", "MatMul"], [0, 0, 0]),
         }
 
         qk_nodes = None
@@ -521,6 +540,8 @@ class FusionAttention(Fusion):
                     is_distill = True
                 if k == "path4":
                     is_distill_add = True
+                if k == "path5":
+                    is_no_mask_attention = True
                 break
 
         if qk_nodes is None:
@@ -534,6 +555,8 @@ class FusionAttention(Fusion):
             (_, where_qk, matmul_qk, _) = qk_nodes
         elif is_distill_add:
             (_, add_qk, where_qk, matmul_qk) = qk_nodes
+        elif is_no_mask_attention:
+            (_, _, matmul_qk) = qk_nodes
         else:
             (_, add_qk, _, matmul_qk) = qk_nodes
@@ -591,6 +614,8 @@ class FusionAttention(Fusion):
             if add_qk_str is None:
                 logger.debug(f"fuse_attention: failed to verify shape inference of {add_qk}")
                 return
+        elif is_no_mask_attention:
+            pass
         else:
             _, mask_nodes, _ = self.model.match_parent_paths(
                 add_qk,
@@ -603,17 +628,17 @@ class FusionAttention(Fusion):
                 ],
                 output_name_to_node,
             )
-        if mask_nodes is None:
+        if not is_no_mask_attention and mask_nodes is None:
             logger.debug("fuse_attention: failed to match mask path")
             return
 
-        if len(mask_nodes) > 1 and mask_nodes[0].op_type == "Mul":
+        if not is_no_mask_attention and len(mask_nodes) > 1 and mask_nodes[0].op_type == "Mul":
             _, mul_val = self.model.get_constant_input(mask_nodes[0])
             if mul_val != -10000:
                 self.mask_filter_value = mul_val
 
         if matmul_v.input[0] == root_input and matmul_q.input[0] == root_input and matmul_k.input[0] == root_input:
-            mask_index = self.attention_mask.process_mask(mask_nodes[-1].input[0])
+            mask_index = self.attention_mask.process_mask(mask_nodes[-1].input[0]) if not is_no_mask_attention else None
 
             attention_last_node = reshape_qkv if einsum_node is None else transpose_qkv
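
To make the new "path5" concrete: with no attention mask there is no Add between the scaling Div and the Softmax, so the probability path is just Softmax <- Div <- MatMul. A self-contained sketch of that subgraph (illustrative names; shapes are ViT-base, 12 heads of size 64 over 197 patch tokens):

# Illustrative subgraph matching qk_paths["path5"]; not code from this PR.
import onnx
from onnx import TensorProto, helper

nodes = [
    helper.make_node("MatMul", ["q", "k_t"], ["qk"]),          # Q x K^T
    helper.make_node("Div", ["qk", "sqrt_d"], ["qk_scaled"]),  # scale; no mask Add follows
    helper.make_node("Softmax", ["qk_scaled"], ["probs"], axis=-1),
]
graph = helper.make_graph(
    nodes,
    "vit_qk_path5",
    inputs=[
        helper.make_tensor_value_info("q", TensorProto.FLOAT, [1, 12, 197, 64]),
        helper.make_tensor_value_info("k_t", TensorProto.FLOAT, [1, 12, 64, 197]),
        helper.make_tensor_value_info("sqrt_d", TensorProto.FLOAT, []),
    ],
    outputs=[helper.make_tensor_value_info("probs", TensorProto.FLOAT, [1, 12, 197, 197])],
)
onnx.checker.check_model(helper.make_model(graph))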

fusion_options.py

@@ -46,9 +46,11 @@ class FusionOptions:
 
         # Set default to sequence length for BERT model to use fused attention to speed up.
         # Note that embed layer normalization will convert 2D mask to 1D when mask type is MaskIndexEnd.
-        self.attention_mask_format = (
-            AttentionMaskFormat.MaskIndexEnd if model_type == "bert" else AttentionMaskFormat.AttentionMask
-        )
+        self.attention_mask_format = AttentionMaskFormat.AttentionMask
+        if model_type == "bert":
+            self.attention_mask_format = AttentionMaskFormat.MaskIndexEnd
+        elif model_type == "vit":
+            self.attention_mask_format = AttentionMaskFormat.NoMask
 
         # options for stable diffusion
         if model_type in ["unet", "vae", "clip"]:
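
A quick way to see the new default in action (a sketch; assumes both names are importable from onnxruntime.transformers.fusion_options, as in this file):

# Sketch: the "vit" model type now defaults to NoMask.
from onnxruntime.transformers.fusion_options import AttentionMaskFormat, FusionOptions

assert FusionOptions("bert").attention_mask_format == AttentionMaskFormat.MaskIndexEnd
assert FusionOptions("vit").attention_mask_format == AttentionMaskFormat.NoMask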

huggingface_models.py

@@ -156,6 +156,7 @@ MODELS = {
         False,
         "bert",
     ),
+    "google/vit-base-patch16-224": (["pixel_values"], 12, False, "vit"),
     # "google/pegasus-xsum": (["input_ids"], 11, False, "bert"),
     # "google/pegasus-large": (["input_ids"], 11, False, "bert"),
 }
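
The field meaning of the new entry is inferred from the neighboring entries rather than stated in the diff; a commented reading:

# Inferred layout of a MODELS entry: (input names, opset, external-data flag, model type).
"google/vit-base-patch16-224": (
    ["pixel_values"],  # ViT consumes images rather than input_ids
    12,                # ONNX opset version used for export
    False,             # use_external_data_format
    "vit",             # optimizer model type added by this change
),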

onnx_exporter.py

@@ -16,7 +16,7 @@ from benchmark_helper import OptimizerInfo, Precision, create_onnxruntime_sessio
 from huggingface_models import MODEL_CLASSES
 from quantize_helper import QuantizeHelper
 from torch_onnx_export_helper import torch_onnx_export
-from transformers import AutoConfig, AutoTokenizer, LxmertConfig, TransfoXLConfig
+from transformers import AutoConfig, AutoFeatureExtractor, AutoTokenizer, LxmertConfig, TransfoXLConfig
 
 sys.path.append(os.path.join(os.path.dirname(__file__), "models", "gpt2"))
 from gpt2_helper import PRETRAINED_GPT2_MODELS, GPT2ModelNoPastState, TFGPT2ModelNoPastState  # noqa: E402
@@ -439,7 +439,7 @@ def validate_and_optimize_onnx(
             model_fusion_statistics,
         )
 
-    return onnx_model_path, is_valid_onnx_model, config.vocab_size
+    return onnx_model_path, is_valid_onnx_model, None if model_type == "vit" else config.vocab_size
 
 
 def export_onnx_model_from_pt(
@@ -465,12 +465,21 @@ def export_onnx_model_from_pt(
     # config, model = load_pt_model_from_tf(model_name)
     model.cpu()
 
-    tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
-    max_input_size = (
-        tokenizer.max_model_input_sizes[model_name] if model_name in tokenizer.max_model_input_sizes else 1024
-    )
+    example_inputs = None
+    max_input_size = None
 
-    example_inputs = tokenizer.encode_plus("This is a sample input", return_tensors="pt")
+    if model_type == "vit":
+        image_processor = AutoFeatureExtractor.from_pretrained(model_name, cache_dir=cache_dir)
+        data = numpy.random.randint(low=0, high=256, size=224 * 224 * 3, dtype=numpy.uint8).reshape(224, 224, 3)
+        example_inputs = image_processor(data, return_tensors="pt")
+    else:
+        tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
+        max_input_size = (
+            tokenizer.max_model_input_sizes[model_name] if model_name in tokenizer.max_model_input_sizes else 1024
+        )
+        example_inputs = tokenizer.encode_plus("This is a sample input", return_tensors="pt")
 
     example_inputs = filter_inputs(example_inputs, input_names)
@@ -497,7 +506,13 @@ def export_onnx_model_from_pt(
         logger.info(f"Exporting ONNX model to {onnx_model_path}")
         Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
 
-        dynamic_axes, output_names = build_dynamic_axes(example_inputs, example_outputs_flatten)
+        dynamic_axes = None
+        output_names = None
+        if model_type == "vit":
+            dynamic_axes, output_names = {key: {0: "pixel_values"} for key in example_inputs}, ["logits"]
+        else:
+            dynamic_axes, output_names = build_dynamic_axes(example_inputs, example_outputs_flatten)
 
         replace_torch_functions()
         torch_onnx_export(
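
The ViT branch above, as a standalone snippet (mirrors the diff; the print line and its expected shape are added for illustration):

# Standalone version of the new ViT example-input path.
import numpy
from transformers import AutoFeatureExtractor

image_processor = AutoFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
data = numpy.random.randint(low=0, high=256, size=224 * 224 * 3, dtype=numpy.uint8).reshape(224, 224, 3)
example_inputs = image_processor(data, return_tensors="pt")
print(example_inputs["pixel_values"].shape)  # torch.Size([1, 3, 224, 224])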

optimizer.py

@@ -56,6 +56,7 @@ MODEL_TYPES = {
     "unet": (UnetOnnxModel, "pytorch", 1),
     "vae": (VaeOnnxModel, "pytorch", 1),
     "clip": (ClipOnnxModel, "pytorch", 1),
+    "vit": (BertOnnxModel, "pytorch", 1),
 }
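
Since ViT reuses the BERT fusion passes wholesale, the mapping can be sanity-checked like this (a sketch; assumes MODEL_TYPES and BertOnnxModel are importable as in this diff):

# Sketch: "vit" maps to the same ONNX model class as "bert".
from onnxruntime.transformers.onnx_model_bert import BertOnnxModel
from onnxruntime.transformers.optimizer import MODEL_TYPES

model_class, framework, default_opt_level = MODEL_TYPES["vit"]
assert model_class is BertOnnxModel and framework == "pytorch"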

requirements.txt

@@ -8,6 +8,7 @@
 packaging
 transformers >= 4.18.0
 scipy
 sentencepiece
+pillow
 # please follow https://pytorch.org/ to install PyTorch for your OS
 torch >= 1.13.1

test_optimizer.py

@@ -304,6 +304,10 @@ class TestModelOptimization(unittest.TestCase):
     def test_huggingface_bart_fusion(self):
         self._test_optimizer_on_huggingface_model("facebook/bart-base", [0, 0, 0, 0, 12, 2, 30])
 
+    @pytest.mark.slow
+    def test_huggingface_vit_fusion(self):
+        self._test_optimizer_on_huggingface_model("google/vit-base-patch16-224", [0, 11, 0, 0, 12, 1, 24])
+
 
 @unittest.skipUnless(is_tf_available(), "skip TestBertOptimizationTF since tensorflow is not available")
 class TestTensorflowModelOptimization(unittest.TestCase):
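
For readers decoding the expected list [0, 11, 0, 0, 12, 1, 24]: assuming _test_optimizer_on_huggingface_model compares fused-op counts in its usual order (worth verifying in the test helper), it reads roughly as:

# Assumed mapping of expected_fusion_result_list positions to fused op types:
expected = {
    "EmbedLayerNormalization": 0,   # ViT has no token-embedding layer to fuse
    "Attention": 11,
    "Gelu": 0,
    "FastGelu": 0,
    "BiasGelu": 12,                 # one per encoder layer
    "LayerNormalization": 1,
    "SkipLayerNormalization": 24,   # two per encoder layer
}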