diff --git a/onnxruntime/python/tools/transformers/benchmark.py b/onnxruntime/python/tools/transformers/benchmark.py index 20f7a52113..3cccb40c7e 100644 --- a/onnxruntime/python/tools/transformers/benchmark.py +++ b/onnxruntime/python/tools/transformers/benchmark.py @@ -65,7 +65,7 @@ if "OMP_NUM_THREADS" not in os.environ: os.environ["OMP_NUM_THREADS"] = str(cpu_count) import torch -from transformers import (AutoConfig, AutoTokenizer, AutoModel, GPT2Model) +from transformers import (AutoConfig, AutoTokenizer, AutoModel, GPT2Model, LxmertConfig) def run_onnxruntime(use_gpu, model_names, model_class, precision, num_threads, batch_sizes, sequence_lengths, @@ -128,8 +128,7 @@ def run_onnxruntime(use_gpu, model_names, model_class, precision, num_threads, b input_value_type = numpy.int64 if 'pt' in model_source else numpy.int32 ort_inputs = create_onnxruntime_input(vocab_size, batch_size, sequence_length, input_names, - input_value_type) - + config, input_value_type) result_template = { "engine": "onnxruntime", "version": onnxruntime.__version__, @@ -334,8 +333,18 @@ def run_tensorflow(use_gpu, model_names, model_class, precision, num_threads, ba @run_with_tf_optimizations(do_eager_mode=False, use_xla=False) def encoder_decoder_forward(): return model(input_ids, decoder_input_ids=input_ids, training=False) + + @run_with_tf_optimizations(do_eager_mode=False, use_xla=False) + def lxmert_forward(): + feats = tf.random.normal([1, 1, config.visual_feat_dim]) + pos = tf.random.normal([1, 1, config.visual_pos_dim]) + return model(input_ids, visual_feats=feats, visual_pos=pos, training=False) - inference = encoder_decoder_forward if config.is_encoder_decoder else encoder_forward + inference = encoder_forward + if config.is_encoder_decoder: + inference = encoder_decoder_forward + elif isinstance(config, LxmertConfig): + inference = lxmert_forward inference() diff --git a/onnxruntime/python/tools/transformers/benchmark_helper.py b/onnxruntime/python/tools/transformers/benchmark_helper.py index 7fe05df333..46a5f0dd90 100644 --- a/onnxruntime/python/tools/transformers/benchmark_helper.py +++ b/onnxruntime/python/tools/transformers/benchmark_helper.py @@ -30,7 +30,10 @@ class Precision(Enum): def __str__(self): return self.value - +IO_BINDING_DATA_TYPE_MAP = { + "float32": numpy.float32, + # TODO: Add more. +} def create_onnxruntime_session(onnx_model_path, use_gpu, enable_all_optimization=True, @@ -214,7 +217,8 @@ def inference_ort_with_io_binding(ort_session, # Bind inputs to device for name in ort_inputs.keys(): np_input = torch.from_numpy(ort_inputs[name]).to(device) - io_binding.bind_input(name, np_input.device.type, 0, data_type, np_input.shape, np_input.data_ptr()) + input_type = IO_BINDING_DATA_TYPE_MAP[str(ort_inputs[name].dtype)] if str(ort_inputs[name].dtype) in IO_BINDING_DATA_TYPE_MAP else data_type + io_binding.bind_input(name, np_input.device.type, 0, input_type, np_input.shape, np_input.data_ptr()) # Bind outputs buffers with the sizes needed if not allocated already if len(output_buffers) == 0: allocateOutputBuffers(output_buffers, output_buffer_max_sizes, device) diff --git a/onnxruntime/python/tools/transformers/fusion_attention.py b/onnxruntime/python/tools/transformers/fusion_attention.py index 0024e9d15f..7b866fc6bc 100644 --- a/onnxruntime/python/tools/transformers/fusion_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_attention.py @@ -166,31 +166,27 @@ class FusionAttention(Fusion): # Check if all matrices have the same shape assert qw.shape == kw.shape == vw.shape - if len(qw.shape) != 2: - logger.debug(f"weights for Q is expected to be 2D.") - return None + # All the matrices have the same shape. For 2d weights, the shapes would be [in_size, out_size]. + # For 3d weights, shape would be [in_size, a, b] where a*b = out_size + in_size = qw.shape[0] + out_size = np.prod(qw.shape[1:]) - # All the matrices have the same shape (in_size, out_size) - in_size, out_size = qw.shape + qkv_weight = np.stack((qw, kw, vw), axis=1) + + qb = numpy_helper.to_array(q_bias) + kb = numpy_helper.to_array(k_bias) + vb = numpy_helper.to_array(v_bias) + + # 1d bias shape: [outsize,]. 2d bias shape: [a, b] where a*b = out_size + assert qb.shape == kb.shape == vb.shape + assert np.prod(qb.shape) == out_size if out_size != hidden_size: logger.debug( f"Shape for weights of Q is {in_size, out_size}, which does not match hidden_size={hidden_size}") return None - qkv_weight = np.stack((qw, kw, vw), axis=-2) - - qb = numpy_helper.to_array(q_bias) - assert qb.shape == (out_size, ) - - kb = numpy_helper.to_array(k_bias) - assert kb.shape == (out_size, ) - - vb = numpy_helper.to_array(v_bias) - assert vb.shape == (out_size, ) - - qkv_bias = np.stack((qb, kb, vb), axis=-2) - + qkv_bias = np.stack((qb, kb, vb), axis=0) attention_node_name = self.model.create_node_name('Attention') weight = helper.make_tensor(name=attention_node_name + '_qkv_weight', diff --git a/onnxruntime/python/tools/transformers/gpt2_helper.py b/onnxruntime/python/tools/transformers/gpt2_helper.py index 8079d6277b..83142d5e54 100644 --- a/onnxruntime/python/tools/transformers/gpt2_helper.py +++ b/onnxruntime/python/tools/transformers/gpt2_helper.py @@ -14,7 +14,7 @@ import time import re from pathlib import Path from typing import List, Dict, Tuple, Union -from transformers import GPT2Model, GPT2LMHeadModel, GPT2Config +from transformers import GPT2Model, GPT2LMHeadModel, GPT2Config, TFGPT2Model from benchmark_helper import Precision logger = logging.getLogger(__name__) @@ -33,6 +33,15 @@ class GPT2ModelNoPastState(GPT2Model): def forward(self, input_ids): return super().forward(input_ids, use_cache=False, return_dict=False) +class TFGPT2ModelNoPastState(TFGPT2Model): + """ Here we wrap a class to disable past state output. + """ + def __init__(self, config): + config.use_cache = False + super().__init__(config) + + def forward(self, input_ids): + return super().call(input_ids, use_cache=False) class MyGPT2Model(GPT2Model): """ Here we wrap a class for Onnx model conversion for GPT2Model with past state. diff --git a/onnxruntime/python/tools/transformers/huggingface_models.py b/onnxruntime/python/tools/transformers/huggingface_models.py index 6cbdf194ea..a81cdb82da 100644 --- a/onnxruntime/python/tools/transformers/huggingface_models.py +++ b/onnxruntime/python/tools/transformers/huggingface_models.py @@ -13,13 +13,24 @@ MODEL_CLASSES = [ # Pretrained model name to a tuple of input names, opset_version, use_external_data_format, optimization model type MODELS = { # BERT - "bert-base-uncased": (["input_ids", "attention_mask", "token_type_ids"], 11, False, "bert"), - "bert-large-uncased": (["input_ids", "attention_mask", "token_type_ids"], 11, False, "bert"), - "bert-base-cased": (["input_ids", "attention_mask", "token_type_ids"], 11, False, "bert"), - "bert-large-uncased-whole-word-masking-finetuned-squad": (["input_ids", "attention_mask", - "token_type_ids"], 11, False, "bert"), - "bert-base-cased-finetuned-mrpc": (["input_ids", "attention_mask", "token_type_ids"], 11, False, "bert"), - + "bert-base-uncased": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"), + "bert-large-uncased": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"), + "bert-base-cased": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"), + # "bert-large-cased": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"), + # "bert-base-multilingual-uncased": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"), + # "bert-base-multilingual-cased": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"), + # "bert-base-chinese": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"), + # "bert-base-german-cased": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"), + # "bert-large-uncased-whole-word-masking": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"), + # "bert-large-cased-whole-word-masking": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"), + # "bert-large-uncased-whole-word-masking-finetuned-squad": (["input_ids", "attention_mask", + # "token_type_ids"], 12, False, "bert"), + # "bert-large-cased-whole-word-masking-finetuned-squad": (["input_ids", "attention_mask", + # "token_type_ids"], 12, False, "bert"), + # "bert-base-cased-finetuned-mrpc": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"), + # "bert-base-german-dbmdz-cased": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"), + # "bert-base-german-dbmdz-uncased": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"), + # todo: more models to add # GPT (no past state) "openai-gpt": (["input_ids"], 11, False, "gpt2"), # GPT-2 (no past state, use benchmark_gpt2.py for past_key_values) @@ -29,7 +40,8 @@ MODELS = { "gpt2-xl": (["input_ids"], 11, True, "gpt2"), "distilgpt2": (["input_ids"], 11, False, "gpt2"), # Transformer-XL - #"transfo-xl-wt103": (["input_ids"], 11, False, "bert"), + "transfo-xl-wt103": + (["input_ids", "mems"], 12, False, "bert"), # Models uses Einsum, which need opset version 12 and PyTorch 1.5.0 or above. # XLNet "xlnet-base-cased": (["input_ids"], 12, False, "bert"), "xlnet-large-cased": (["input_ids"], 12, False, "bert"), @@ -37,14 +49,12 @@ MODELS = { "xlm-mlm-en-2048": (["input_ids"], 11, True, "bert"), "xlm-mlm-ende-1024": (["input_ids"], 11, False, "bert"), "xlm-mlm-enfr-1024": (["input_ids"], 11, False, "bert"), - # XML Roberta - "xlm-roberta-base": (["input_ids"], 12, False, "bert"), # RoBERTa - "roberta-base": (["input_ids", "attention_mask"], 11, False, "bert"), - "roberta-large": (["input_ids", "attention_mask"], 11, False, "bert"), - "roberta-large-mnli": (["input_ids", "attention_mask"], 11, False, "bert"), + "roberta-base": (["input_ids", "attention_mask"], 12, False, "bert"), + "roberta-large": (["input_ids", "attention_mask"], 12, False, "bert"), + "roberta-large-mnli": (["input_ids", "attention_mask"], 12, False, "bert"), "deepset/roberta-base-squad2": (["input_ids", "attention_mask"], 11, False, "bert"), - "distilroberta-base": (["input_ids", "attention_mask"], 11, False, "bert"), + "distilroberta-base": (["input_ids", "attention_mask"], 12, False, "bert"), # DistilBERT "distilbert-base-uncased": (["input_ids", "attention_mask"], 11, False, "bert"), @@ -63,11 +73,11 @@ MODELS = { "albert-xlarge-v2": (["input_ids"], 12, True, "bert"), #"albert-xxlarge-v2": (["input_ids"], 12, True, "bert"), # T5 (use benchmark_t5.py instead) - #"t5-small": (["input_ids"], 12, False, "bert"), - #"t5-base": (["input_ids"], 12, False, "bert"), - #"t5-large": (["input_ids"], 12, True, "bert"), - #"t5-3b": (["input_ids"], 12, True, "bert"), - #"t5-11b": (["input_ids"], 12, True, "bert"), + # "t5-small": (["input_ids", "decoder_input_ids"], 12, False, "bert"), + # "t5-base": (["input_ids", "decoder_input_ids"], 12, False, "bert"), + # "t5-large": (["input_ids", "decoder_input_ids"], 12, True, "bert"), + # "t5-3b": (["input_ids", "decoder_input_ids"], 12, True, "bert"), + # "t5-11b": (["input_ids", "decoder_input_ids"], 12, True, "bert"), #"valhalla/t5-small-qa-qg-hl": (["input_ids"], 12, True, "bert"), # XLM-RoBERTa "xlm-roberta-base": (["input_ids"], 11, False, "bert"), @@ -98,6 +108,20 @@ MODELS = { # MBart "facebook/mbart-large-cc25": (["input_ids"], 11, True, "bert"), "facebook/mbart-large-en-ro": (["input_ids"], 11, True, "bert"), + # "Helsinki-NLP/opus-mt-ROMANCE-en": (["input_ids"], 12, False, "bert"), + # # Longformer + # "allenai/longformer-base-4096": (["input_ids"], 12, False, "bert"), + # "allenai/longformer-large-4096": (["input_ids"], 12, True, "bert"), + # "funnel-transformer/small": (["input_ids"], 12, False, "bert"), + # "funnel-transformer/small-base": (["input_ids"], 12, False, "bert"), + # "funnel-transformer/medium": (["input_ids"], 12, False, "bert"), + # "funnel-transformer/medium-base": (["input_ids"], 12, False, "bert"), + # "funnel-transformer/intermediate": (["input_ids"], 12, False, "bert"), + # "funnel-transformer/intermediate-base": (["input_ids"], 12, False, "bert"), + # "funnel-transformer/large": (["input_ids"], 12, True, "bert"), + # "funnel-transformer/large-base": (["input_ids"], 12, True, "bert"), + # "funnel-transformer/xlarge": (["input_ids"], 12, True, "bert"), + # "funnel-transformer/xlarge-base": (["input_ids"], 12, True, "bert"), # Layoutlm "microsoft/layoutlm-base-uncased": (["input_ids"], 11, False, "bert"), "microsoft/layoutlm-large-uncased": (["input_ids"], 11, False, "bert"), @@ -105,4 +129,7 @@ MODELS = { "squeezebert/squeezebert-uncased": (["input_ids"], 11, False, "bert"), "squeezebert/squeezebert-mnli": (["input_ids"], 11, False, "bert"), "squeezebert/squeezebert-mnli-headless": (["input_ids"], 11, False, "bert"), + "unc-nlp/lxmert-base-uncased": (["input_ids", "visual_feats", "visual_pos"], 11, False, "bert"), + # "google/pegasus-xsum": (["input_ids"], 11, False, "bert"), + # "google/pegasus-large": (["input_ids"], 11, False, "bert"), } diff --git a/onnxruntime/python/tools/transformers/onnx_exporter.py b/onnxruntime/python/tools/transformers/onnx_exporter.py index 71d08705fd..1aaa5144f2 100644 --- a/onnxruntime/python/tools/transformers/onnx_exporter.py +++ b/onnxruntime/python/tools/transformers/onnx_exporter.py @@ -9,11 +9,12 @@ import numpy import os import torch from pathlib import Path -from transformers import AutoConfig, AutoTokenizer, AutoModel +from transformers import AutoConfig, AutoTokenizer, AutoModel, LxmertConfig, TransfoXLConfig from benchmark_helper import create_onnxruntime_session, Precision -from gpt2_helper import GPT2ModelNoPastState, PRETRAINED_GPT2_MODELS +from gpt2_helper import GPT2ModelNoPastState, PRETRAINED_GPT2_MODELS, TFGPT2ModelNoPastState from quantize_helper import QuantizeHelper from huggingface_models import MODEL_CLASSES +os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' logger = logging.getLogger(__name__) @@ -40,7 +41,7 @@ def restore_torch_functions(): torch.triu = torch_func["triu"] -def create_onnxruntime_input(vocab_size, batch_size, sequence_length, input_names, data_type=numpy.int64): +def create_onnxruntime_input(vocab_size, batch_size, sequence_length, input_names, config, data_type=numpy.int64): input_ids = numpy.random.randint(low=0, high=vocab_size - 1, size=(batch_size, sequence_length), dtype=data_type) inputs = {'input_ids': input_ids} @@ -53,13 +54,23 @@ def create_onnxruntime_input(vocab_size, batch_size, sequence_length, input_name segment_ids = numpy.zeros([batch_size, sequence_length], dtype=data_type) inputs['token_type_ids'] = segment_ids + if config.is_encoder_decoder: + inputs['decoder_input_ids'] = input_ids + + if isinstance(config, LxmertConfig): + inputs["visual_feats"] = numpy.random.randn(1, 1, config.visual_feat_dim).astype(numpy.float32) + inputs["visual_pos"] = numpy.random.randn(1, 1, config.visual_pos_dim).astype(numpy.float32) + if isinstance(config, TransfoXLConfig): + inputs["tf_transfo_xl_model/transformer/pos_emb/einsum/Einsum/inputs_1:0"] = numpy.zeros([config.hidden_size], + dtype=numpy.float32) return inputs def filter_inputs(inputs, input_names): remaining_model_inputs = {} for input_name in input_names: - remaining_model_inputs[input_name] = inputs[input_name] + if input_name in inputs: + remaining_model_inputs[input_name] = inputs[input_name] return remaining_model_inputs @@ -88,7 +99,7 @@ def build_dynamic_axes(example_inputs, outputs_flatten): return dynamic_axes, output_names -def validate_onnx_model(onnx_model_path, example_inputs, example_outputs_flatten, use_gpu, fp16): +def validate_onnx_model(onnx_model_path, example_inputs, example_outputs_flatten, use_gpu, fp16, output_names=None): test_session = create_onnxruntime_session(onnx_model_path, use_gpu, enable_all_optimization=False) if test_session is None: logger.error(f"{onnx_model_path} is an invalid ONNX model") @@ -97,8 +108,8 @@ def validate_onnx_model(onnx_model_path, example_inputs, example_outputs_flatten logger.info(f"{onnx_model_path} is a valid ONNX model") # Compare the inference result with PyTorch or Tensorflow - example_ort_inputs = {k: t.cpu().numpy() for k, t in example_inputs.items()} - example_ort_outputs = test_session.run(None, example_ort_inputs) + example_ort_inputs = {k: t.numpy() for k, t in example_inputs.items()} + example_ort_outputs = test_session.run(output_names, example_ort_inputs) if len(example_outputs_flatten) != len(example_ort_outputs): logger.error( f"Number of output tensors expected {len(example_outputs_flatten)}, got {len(example_ort_outputs)}") @@ -111,7 +122,7 @@ def validate_onnx_model(onnx_model_path, example_inputs, example_outputs_flatten rtol = 5e-02 if fp16 else 1e-4 atol = 1e-01 if fp16 else 1e-4 - if not numpy.allclose(example_ort_outputs[i], example_outputs_flatten[i].cpu(), rtol=rtol, atol=atol): + if not numpy.allclose(example_ort_outputs[i], example_outputs_flatten[i].cpu().numpy(), rtol=rtol, atol=atol): logger.error(f"Output tensor {i} is not close: rtol={rtol}, atol={atol}") return False @@ -195,7 +206,7 @@ def optimize_onnx_model(model_name, onnx_model_path, optimized_model_path, model optimization_options=optimization_options, use_gpu=use_gpu, only_onnxruntime=False) - if model_type == 'bert_keras': + if model_type == 'bert_keras' or model_type == "bert_tf": opt_model.use_dynamic_axes() model_fusion_statistics[optimized_model_path] = opt_model.get_fused_operator_statistics() @@ -234,7 +245,7 @@ def load_pretrained_model(model_name, config, cache_dir, custom_model_class, is_ if model_class_name == "GPT2ModelNoPastState": if is_tf_model: - raise NotImplementedError("TFGPT2ModelNoPastState is currently not supported.") + return TFGPT2ModelNoPastState.from_pretrained(model_name, config=config, cache_dir=cache_dir) else: return GPT2ModelNoPastState.from_pretrained(model_name, config=config, cache_dir=cache_dir) @@ -279,13 +290,27 @@ def load_pt_model_from_tf(model_name): return config, model -def validate_and_optimize_onnx(model_name, use_external_data_format, model_type, onnx_dir, input_names, use_gpu, - precision, optimize_onnx, validate_onnx, use_raw_attention_mask, overwrite, config, - model_fusion_statistics, onnx_model_path, example_inputs, example_outputs_flatten): +def validate_and_optimize_onnx(model_name, + use_external_data_format, + model_type, + onnx_dir, + input_names, + use_gpu, + precision, + optimize_onnx, + validate_onnx, + use_raw_attention_mask, + overwrite, + config, + model_fusion_statistics, + onnx_model_path, + example_inputs, + example_outputs_flatten, + output_names=None): is_valid_onnx_model = True if validate_onnx: is_valid_onnx_model = validate_onnx_model(onnx_model_path, example_inputs, example_outputs_flatten, use_gpu, - False) + False, output_names) if optimize_onnx or precision == Precision.FLOAT16 or precision == Precision.INT8: # Use script (optimizer.py) to optimize optimized_model_path = get_onnx_file_path(onnx_dir, model_name, len(input_names), True, use_gpu, precision, @@ -297,7 +322,7 @@ def validate_and_optimize_onnx(model_name, use_external_data_format, model_type, onnx_model_path = optimized_model_path if validate_onnx: is_valid_onnx_model = validate_onnx_model(onnx_model_path, example_inputs, example_outputs_flatten, use_gpu, - precision == Precision.FLOAT16) + precision == Precision.FLOAT16, output_names) if precision == Precision.INT8: logger.info(f"Quantizing model: {onnx_model_path}") @@ -375,46 +400,88 @@ def export_onnx_model_from_tf(model_name, opset_version, use_external_data_forma import tensorflow as tf tf.config.set_visible_devices([], 'GPU') - config, model = load_tf_model(model_name, model_class, cache_dir) - - model._saved_model_inputs_spec = None - tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir) + # Fix "Using pad_token, but it is not set yet" error. + if tokenizer.pad_token is None: + tokenizer.add_special_tokens({'pad_token': '[PAD]'}) max_input_size = tokenizer.max_model_input_sizes[ model_name] if model_name in tokenizer.max_model_input_sizes else 1024 + config, model = load_tf_model(model_name, model_class, cache_dir) + model.resize_token_embeddings(len(tokenizer)) + example_inputs = tokenizer.encode_plus("This is a sample input", return_tensors="tf", max_length=max_input_size, - pad_to_max_length=True, + padding="max_length", truncation=True) - example_inputs = filter_inputs(example_inputs, input_names) - example_outputs = model(example_inputs, training=False).to_tuple() + if config.is_encoder_decoder: + example_inputs["decoder_input_ids"] = tokenizer.encode_plus("This is a sample input", + return_tensors="tf", + max_length=max_input_size, + padding="max_length", + truncation=True).input_ids + if model_name == "unc-nlp/lxmert-base-uncased": + example_inputs["visual_feats"] = tf.random.normal([1, 1, config.visual_feat_dim]) + example_inputs["visual_pos"] = tf.random.normal([1, 1, config.visual_pos_dim]) - # Flatten is needed for gpt2 and distilgpt2. - example_outputs_flatten = flatten(example_outputs) - example_outputs_flatten = update_flatten_list(example_outputs_flatten, []) + try: + # Use no past state for these models + if config.use_cache: + config.use_cache = False + except: + pass + + example_outputs = model(example_inputs, training=False) + output_names = None + + # For xlnet models, only compare the last_hidden_state output. + if model_name == "xlnet-base-cased" or model_name == "xlnet-large-cased": + output_names = ["last_hidden_state"] + example_outputs = example_outputs["last_hidden_state"] + + # Flatten is needed for gpt2 and distilgpt2. Output name sorting is needed for tf2onnx outputs to match onnx outputs. + from tensorflow.python.util import nest + example_outputs_flatten = nest.flatten(example_outputs) onnx_model_path = get_onnx_file_path(onnx_dir, model_name, len(input_names), False, use_gpu, precision, False, use_external_data_format) + tf_internal_model_path = onnx_model_path[:-5] if use_external_data_format else onnx_model_path - if overwrite or not os.path.exists(onnx_model_path): + if overwrite or not os.path.exists(tf_internal_model_path): logger.info("Exporting ONNX model to {}".format(onnx_model_path)) - Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True) + if not use_external_data_format: + Path(tf_internal_model_path).parent.mkdir(parents=True, exist_ok=True) + + import tf2onnx, zipfile + tf2onnx.logging.set_level(tf2onnx.logging.ERROR) + specs = [] + for name, value in example_inputs.items(): + dims = [None] * len(value.shape) + specs.append(tf.TensorSpec(tuple(dims), value.dtype, name=name)) + _, _ = tf2onnx.convert.from_keras(model, + input_signature=tuple(specs), + opset=opset_version, + large_model=use_external_data_format, + output_path=tf_internal_model_path) + if use_external_data_format: + # need to unpack the zip for run_onnxruntime() + with zipfile.ZipFile(tf_internal_model_path, 'r') as z: + z.extractall(os.path.dirname(tf_internal_model_path)) + tf_internal_model_path = os.path.join(os.path.dirname(tf_internal_model_path), "__MODEL_PROTO.onnx") + if os.path.exists(onnx_model_path): + os.remove(onnx_model_path) + os.rename(tf_internal_model_path, onnx_model_path) - import keras2onnx - onnx_model = keras2onnx.convert_keras(model, model.name, target_opset=opset_version) - keras2onnx.save_model(onnx_model, onnx_model_path) else: logger.info(f"Skip export since model existed: {onnx_model_path}") - model_type = model_type + '_keras' - + model_type = model_type + '_tf' onnx_model_file, is_valid_onnx_model, vocab_size = validate_and_optimize_onnx( model_name, use_external_data_format, model_type, onnx_dir, input_names, use_gpu, precision, optimize_onnx, validate_onnx, use_raw_attention_mask, overwrite, config, model_fusion_statistics, onnx_model_path, - example_inputs, example_outputs_flatten) + example_inputs, example_outputs_flatten, output_names) return onnx_model_file, is_valid_onnx_model, vocab_size, max_input_size diff --git a/onnxruntime/python/tools/transformers/onnx_model_bert_keras.py b/onnxruntime/python/tools/transformers/onnx_model_bert_keras.py index 3aa7d30db2..6c29585c41 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_bert_keras.py +++ b/onnxruntime/python/tools/transformers/onnx_model_bert_keras.py @@ -138,7 +138,8 @@ class BertOnnxModelKeras(BertOnnxModelTF): mask_index = self.attention_mask.process_mask(mask_nodes[-1].input[0]) logger.debug("Create an Attention node.") attention_node = self.attention_fusion.create_attention_node(mask_index, matmul_q, matmul_k, matmul_v, - add_q, add_k, add_v, parent.output[0], + add_q, add_k, add_v, self.num_heads, + self.hidden_size, parent.output[0], reshape_qkv.output[0]) if attention_node is None: continue diff --git a/onnxruntime/python/tools/transformers/onnx_model_bert_tf.py b/onnxruntime/python/tools/transformers/onnx_model_bert_tf.py index 0d88eee3ef..2d63dc7c0a 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_bert_tf.py +++ b/onnxruntime/python/tools/transformers/onnx_model_bert_tf.py @@ -29,6 +29,21 @@ class BertOnnxModelTF(BertOnnxModel): self.remove_nodes(nodes_to_remove) logger.info(f"Removed Identity count: {len(nodes_to_remove)}") + def match_mask_path(self, add_or_sub_before_softmax): + mask_nodes = self.match_parent_path(add_or_sub_before_softmax, ['Mul', 'Sub', 'Reshape', 'Cast'], + [1, None, 1, 0]) + if mask_nodes is not None: + return mask_nodes + + mask_nodes = self.match_parent_path(add_or_sub_before_softmax, ['Mul', 'Sub', 'Cast', 'Slice', 'Unsqueeze'], + [1, 0, 1, 0, 0]) + if mask_nodes is not None: + return mask_nodes + + mask_nodes = self.match_parent_path(add_or_sub_before_softmax, ['Mul', 'Sub', 'Cast', 'Unsqueeze', 'Unsqueeze'], + [1, None, 1, 0, 0]) + return mask_nodes + def fuse_mask(self): nodes_to_remove = [] for node in self.nodes(): @@ -85,16 +100,14 @@ class BertOnnxModelTF(BertOnnxModel): nodes_to_remove = [] for node in self.nodes(): if node.op_type == 'Mul' and self.has_constant_input(node, -10000): - mask_path = self.match_parent_path(node, ['Sub', 'Unsqueeze', 'Mul', 'Cast', 'Reshape', 'Cast'], - [0, 1, 0, 1, 0, 0]) + mask_path = self.match_parent_path(node, ['Sub', 'Cast', 'Slice', 'Unsqueeze'], [0, 1, 0, 0]) if mask_path is None: continue - sub_node, unsqueeze_node, mul_node, cast_node_0, reshape_node_0, cast_node_1 = mask_path + sub_node, cast_node, slice_node, unsqueeze_node = mask_path mask_input_name = self.attention_mask.get_first_mask() - - if cast_node_1.input[0] != mask_input_name: - print("Cast input {} is not mask input{}".format(cast_node_1.input[0], mask_input_name)) + if unsqueeze_node.input[0] != mask_input_name: + print("Cast input {} is not mask input {}".format(unsqueeze_node.input[0], mask_input_name)) continue unsqueeze_added_1 = onnx.helper.make_node('Unsqueeze', @@ -109,13 +122,14 @@ class BertOnnxModelTF(BertOnnxModel): name='Mask_UnSqueeze_2', axes=[2]) + #self.replace_node_input(cast_node, cast_node.input[0], 'mask_fuse_unsqueeze2_output') cast_node_2 = onnx.helper.make_node('Cast', inputs=['mask_fuse_unsqueeze2_output'], outputs=['mask_fuse_cast_output']) cast_node_2.attribute.extend([onnx.helper.make_attribute("to", 1)]) self.replace_node_input(sub_node, sub_node.input[1], 'mask_fuse_cast_output') - nodes_to_remove.extend([unsqueeze_node, mul_node, cast_node_0, reshape_node_0, cast_node_1]) + nodes_to_remove.extend([slice_node, unsqueeze_node, cast_node]) self.add_node(unsqueeze_added_1) self.add_node(unsqueeze_added_2) self.add_node(cast_node_2) @@ -360,9 +374,22 @@ class BertOnnxModelTF(BertOnnxModel): nodes_to_remove = [] attention_count = 0 + start_nodes = [] skip_layer_norm_nodes = self.get_nodes_by_op_type("SkipLayerNormalization") - for normalize_node in skip_layer_norm_nodes: + layer_norm_nodes = self.get_nodes_by_op_type("LayerNormalization") + # Sometimes we can not fuse skiplayernormalization since the add before layernorm has an output that used by nodes outside skiplayernorm + # Conceptually we treat add before layernorm as skiplayernorm node since they share the same pattern + start_nodes.extend(skip_layer_norm_nodes) + start_nodes.extend(layer_norm_nodes) + + for normalize_node in start_nodes: # SkipLayerNormalization has two inputs, and one of them is the root input for attention. + if normalize_node.op_type == 'LayerNormalization': + add_before_layernorm = self.match_parent(normalize_node, 'Add', 0) + if add_before_layernorm is not None: + normalize_node = add_before_layernorm + else: + continue parent = self.get_parent(normalize_node, 1) if parent is None or parent.op_type not in ["SkipLayerNormalization", "LayerNormalization", "Reshape"]: parent = self.get_parent(normalize_node, 0) @@ -376,50 +403,63 @@ class BertOnnxModelTF(BertOnnxModel): qkv_nodes = self.match_parent_path(normalize_node, ['MatMul', 'Reshape', 'Transpose', 'MatMul'], [1, 0, 0, 0]) if qkv_nodes is None: - logger.debug("Failed to match qkv nodes") - continue + qkv_nodes = self.match_parent_path(normalize_node, ['Add', 'Einsum', 'Einsum'], [0, 0, 0]) + if qkv_nodes is None: + logger.debug("Failed to match qkv nodes") + continue - (reshape_qkv, transpose_qkv, matmul_qkv) = qkv_nodes[-3:] + matmul_qkv = qkv_nodes[-1] v_nodes = self.match_parent_path(matmul_qkv, ['Transpose', 'Reshape', 'Add', 'MatMul'], [1, 0, 0, 0]) if v_nodes is None: - logger.debug("Failed to match v path") - continue + v_nodes = self.match_parent_path(matmul_qkv, ['Add', 'Einsum'], [1, 0]) + if v_nodes is None: + logger.debug("Failed to match v path") + continue - (transpose_v, reshape_v, add_v, matmul_v) = v_nodes + add_v = v_nodes[-2] + matmul_v = v_nodes[-1] qk_nodes = self.match_parent_path(matmul_qkv, ['Softmax', 'Add', "Mul", 'MatMul'], [0, 0, 0, 0]) if qk_nodes is None: - logger.debug("Failed to match qk_paths") - continue - (softmax_qk, add_qk, mul_qk, matmul_qk) = qk_nodes + qk_nodes = self.match_parent_path(matmul_qkv, ['Softmax', 'Add', 'Einsum'], [0, 0, 0]) + if qk_nodes is None: + logger.debug("Failed to match qk_paths") + continue + matmul_qk = qk_nodes[-1] q_nodes = self.match_parent_path(matmul_qk, ['Transpose', 'Reshape', 'Add', 'MatMul'], [0, 0, 0, 0]) if q_nodes is None: - logger.debug("Failed to match q path") - continue - (transpose_q, reshape_q, add_q, matmul_q) = q_nodes + q_nodes = self.match_parent_path(matmul_qk, ['Add', 'Einsum'], [0, 0]) + if q_nodes is None: + logger.debug("Failed to match q path") + continue + add_q = q_nodes[-2] + matmul_q = q_nodes[-1] k_nodes = self.match_parent_path(matmul_qk, ['Transpose', 'Reshape', 'Add', 'MatMul'], [1, 0, 0, 0]) if k_nodes is None: - logger.debug("Failed to match k path") - continue - (transpose_k, reshape_k, add_k, matmul_k) = k_nodes - - mask_nodes = self.match_parent_path(add_qk, ['Mul', 'Sub', 'Unsqueeze'], [1, 0, 1]) - if mask_nodes is None: - mask_nodes = self.match_parent_path(add_qk, ['Mul', 'Sub', 'Cast', 'Unsqueeze', 'Mul'], [1, 0, 1, 0, 0]) - if mask_nodes is None: - logger.debug("Failed to match mask path") + k_nodes = self.match_parent_path(matmul_qk, ['Mul', 'Add', 'Einsum'], [1, 0, 0]) + if k_nodes is None: + logger.debug("Failed to match k path") continue + add_k = k_nodes[-2] + matmul_k = k_nodes[-1] + + mask_nodes = self.match_mask_path(qk_nodes[1]) + + if mask_nodes is None: + logger.debug("Cannot find mask_nodes.") + continue if not self.has_constant_input(mask_nodes[1], 1): logger.debug("Sub node expected to have an input with constant value 1.0.") continue # add a squeeze node to convert a 3-d mask to 2-d - squeeze_node = self.match_parent_path(mask_nodes[-1], ['Squeeze'], [0]) + squeeze_node = self.match_parent_path(mask_nodes[-1], ['Squeeze'], [0]) or self.match_parent_path( + mask_nodes[-1], ['Expand'], [0]) squeeze_node_name = "Squeeze_3d_to_2d_mask" squeeze_output_name = squeeze_node_name + "_output" - if squeeze_node is None and len(mask_nodes) == 5: + if squeeze_node is None and len(mask_nodes) == 5 and self.find_graph_input(mask_nodes[-1].input[0]) is None: mask_input = mask_nodes[-1].input[1] self.add_node( helper.make_node("Squeeze", [mask_input], [squeeze_output_name], squeeze_node_name, axes=[1])) @@ -427,11 +467,32 @@ class BertOnnxModelTF(BertOnnxModel): is_same_root = self.check_attention_input(matmul_q, matmul_k, matmul_v, parent, output_name_to_node) if is_same_root: - mask_index = self.attention_mask.process_mask(squeeze_output_name) + mask_index = self.attention_mask.process_mask(mask_nodes[-1].input[0]) logger.debug("Create an Attention node.") - attention_node = self.attention_fusion.create_attention_node(mask_index, matmul_q, matmul_k, matmul_v, - add_q, add_k, add_v, parent.output[0], - reshape_qkv.output[0]) + # For tf models, q and v are flipped. + attention_node = self.attention_fusion.create_attention_node(mask_index, matmul_k, matmul_q, matmul_v, + add_k, add_q, add_v, self.num_heads, + self.hidden_size, parent.output[0], + qkv_nodes[2].output[0]) + if attention_node is None: + continue + + if qkv_nodes[1].op_type == 'Einsum': + # add reshape before einsum + tensor = helper.make_tensor(name=qkv_nodes[1].name + "_newshape", + data_type=TensorProto.INT64, + dims=[4], + vals=np.int64( + [[0, 0, self.num_heads, + int(self.hidden_size / self.num_heads)]]).tobytes(), + raw=True) + self.add_initializer(tensor) + reshape_ = helper.make_node("Reshape", + inputs=[attention_node.output[0], qkv_nodes[1].name + "_newshape"], + outputs=[qkv_nodes[1].name + "_reshape_output"], + name=qkv_nodes[1].name + "_reshape") + qkv_nodes[1].input[0] = qkv_nodes[1].name + "_reshape_output" + self.add_node(reshape_) if parent.op_type == 'Reshape': # Temporary work around: we require the skiplayernorm and attention op be fed with 3-d input hidden_size = numpy_helper.to_array(self.get_initializer(parent.input[1]))[1] @@ -443,13 +504,11 @@ class BertOnnxModelTF(BertOnnxModel): self.add_initializer(tensor) parent.input[1] = parent.name + "_modified" - if attention_node is None: - continue - + self.add_node(attention_node) attention_count += 1 - nodes_to_remove.extend([reshape_qkv, transpose_qkv, matmul_qkv]) + nodes_to_remove.extend(qkv_nodes[2:]) nodes_to_remove.extend(qk_nodes) nodes_to_remove.extend(q_nodes) nodes_to_remove.extend(k_nodes) @@ -466,7 +525,20 @@ class BertOnnxModelTF(BertOnnxModel): self.remove_identity() self.process_embedding() #TODO: remove fuse mask since we have embedding fused so fuse_attention shall handle the mask nodes. - self.fuse_mask() + # self.fuse_mask() + self.skip_reshape() + + def skip_reshape(self): + count = 0 + reshape_nodes = self.get_nodes_by_op_type("Reshape") + for reshape_node in reshape_nodes: + parent = self.get_parent(reshape_node, 0) + if parent is not None and parent.op_type == "Reshape": + reshape_node.input[0] = parent.input[0] + count += 1 + + if count > 0: + logger.info(f"Skip consequent Reshape count: {count}") def remove_reshape_before_first_attention(self): attention_nodes = self.get_nodes_by_op_type("Attention") @@ -475,7 +547,7 @@ class BertOnnxModelTF(BertOnnxModel): if path is None: continue logger.info("Remove Reshape before first Attention node.") - reshape, embed = path + reshape, _ = path self.replace_input_of_all_nodes(reshape.output[0], reshape.input[0]) self.remove_node(reshape) break diff --git a/onnxruntime/python/tools/transformers/optimizer.py b/onnxruntime/python/tools/transformers/optimizer.py index b46fc5de75..d120dfe7dd 100644 --- a/onnxruntime/python/tools/transformers/optimizer.py +++ b/onnxruntime/python/tools/transformers/optimizer.py @@ -39,7 +39,8 @@ MODEL_CLASSES = { "bert": (BertOnnxModel, "pytorch", True), "bert_tf": (BertOnnxModelTF, "tf2onnx", False), "bert_keras": (BertOnnxModelKeras, "keras2onnx", False), - "gpt2": (Gpt2OnnxModel, "pytorch", True) + "gpt2": (Gpt2OnnxModel, "pytorch", True), + "gpt2_tf": (Gpt2OnnxModel, 'tf2onnx', False) # might add a class for GPT2OnnxModel for TF later. } diff --git a/onnxruntime/python/tools/transformers/test/bert_model_generator.py b/onnxruntime/python/tools/transformers/test/bert_model_generator.py index 75d5e4796f..5d1a65f281 100644 --- a/onnxruntime/python/tools/transformers/test/bert_model_generator.py +++ b/onnxruntime/python/tools/transformers/test/bert_model_generator.py @@ -127,7 +127,107 @@ def create_bert_attention(input_hidden_size=16, pruned_num_heads=2, pruned_head_ model = helper.make_model(graph) return model +def create_tf2onnx_attention_3d(input_hidden_size=16, num_heads=4, head_size=4, use_float_mask=False): + # unsqueeze in opset version 13 has two inputs (axis is moved from attribute to input). + has_unsqueeze_two_inputs = (version.parse(onnx.__version__) >= version.parse('1.8.0')) + + # nodes in attention subgraph + nodes = [ + helper.make_node("Add", ["input_1", "input_2"], ["layernorm_input"], "add_layernorm"), + helper.make_node("LayerNormalization", ["layernorm_input", "layer_norm_weight", "layer_norm_bias"], + ["layernorm_out"], + "layernorm", + axis=-1, + epsion=0.000009999999747378752), + + # q nodes + helper.make_node("Einsum", ["layernorm_out", "einsum_q_weight"], ["einsum_q_out"], "einsum_q", equation="abc,cde->abde"), + helper.make_node("Add", ["einsum_q_out", "add_q_weight"], ["add_q_out"], "add_q"), + + # k nodes + helper.make_node("Einsum", ["layernorm_out", "einsum_k_weight"], ["einsum_k_out"], "einsum_k", equation="abc,cde->abde"), + helper.make_node("Add", ["einsum_k_out", "add_k_weight"], ["add_k_out"], "add_k"), + helper.make_node("Mul", ["add_k_out", "mul_weight_1"], ["mul_k_out"], "mul_k"), + + # mask nodes + helper.make_node("Unsqueeze", ["input_mask", "axes_1"], ["unsqueeze0_out"], "unsqueeze0") if has_unsqueeze_two_inputs \ + else helper.make_node("Unsqueeze", ["input_mask"], ["unsqueeze0_out"], "unsqueeze0", axes=[1, 2]), + helper.make_node("Slice", ["unsqueeze0_out", "slice_start", "slice_end", "slice_axes", "slice_steps"], ["slice_out"], "slice"), + + # when attention_mask is float type, no need to cast + helper.make_node("Cast", ["slice_out"], ["cast_out"], "cast", to=1) if not use_float_mask else None, + helper.make_node("Sub", ["sub_weight", "unsqueeze1_out" if use_float_mask else "cast_out"], ["sub_out"], "sub"), + helper.make_node("Mul", ["sub_out", "mul_weight_2"], ["mul_mask_out"], "mul_mask"), + + # qk nodes + helper.make_node("Einsum", ["add_q_out", "mul_k_out"], ["einsum_qk_out"], "einsum_qk", equation="aecd,abcd->acbe"), + helper.make_node("Add", ["einsum_qk_out", "mul_mask_out"], ["add_qk_out"], "add_qk"), + helper.make_node("Softmax", ["add_qk_out"], ["softmax_qk_out"], "softmax_qk", axis=3), + + # v nodes + helper.make_node("Einsum", ["layernorm_out", "einsum_v_weight"], ["einsum_v_out"], "einsum_v", equation="abc,cde->abde"), + helper.make_node("Add", ["einsum_v_out", "add_v_weight"], ["add_v_out"], "add_v"), + + # qkv nodes + helper.make_node("Einsum", ["softmax_qk_out", "add_v_out"], ["einsum_qkv_1_out"], "einsum_qkv_1", equation="acbe,aecd->abcd"), + helper.make_node("Einsum", ["einsum_qkv_1_out", "einsum_qkv_weight"], ["einsum_qkv_2_out"], "einsum_qkv_2", equation="abcd,cde->abe"), + helper.make_node("Add", ["einsum_qkv_2_out", "add_qkv_weight"], ["add_qkv_out"], "add_qkv"), + helper.make_node("Add", ["add_qkv_out", "layernorm_out"], ["skip_output"], "add_skip"), + helper.make_node("LayerNormalization", ["skip_output", "layer_norm_weight", "layer_norm_bias"], ["output"], + "layernorm2", + axis=-1, + epsion=0.000009999999747378752), + ] + + initializers = [ # initializers + float_tensor('layer_norm_weight', [input_hidden_size]), + float_tensor('layer_norm_bias', [input_hidden_size]), + float_tensor('einsum_q_weight', [input_hidden_size, num_heads, head_size]), + float_tensor('einsum_k_weight', [input_hidden_size, num_heads, head_size]), + float_tensor('einsum_v_weight', [input_hidden_size, num_heads, head_size]), + float_tensor('einsum_qkv_weight', [num_heads, head_size, input_hidden_size]), + float_tensor('add_q_weight', [num_heads, head_size]), + float_tensor('add_k_weight', [num_heads, head_size]), + float_tensor('add_v_weight', [num_heads, head_size]), + float_tensor('add_qkv_weight', [input_hidden_size]), + helper.make_tensor('sub_weight', TensorProto.FLOAT, [1], [1.0]), + helper.make_tensor('mul_weight_1', TensorProto.FLOAT, [1], [-10000]), + helper.make_tensor('mul_weight_2', TensorProto.FLOAT, [1], [0.125]), + helper.make_tensor('reshape_weight_1', TensorProto.INT64, [4], [0, 0, num_heads, head_size]), + helper.make_tensor('slice_start', TensorProto.INT32, [4], [0, 0, 0, 0]), + helper.make_tensor('slice_end', TensorProto.INT32, [4], [1000000000, 1000000000, 1000000000, 1000000000]), + helper.make_tensor('slice_axes', TensorProto.INT32, [4], [0, 1, 2, 3]), + helper.make_tensor('slice_steps', TensorProto.INT32, [4], [1, 1, 1, 1]) + ] + + if has_unsqueeze_two_inputs: + initializers.append(helper.make_tensor('axes_1', TensorProto.INT64, [2], [1, 2])) + + batch_size = 1 + sequence_length = 3 + graph = helper.make_graph( + [node for node in nodes if node], + "AttentionFusionPrunedModel", #name + [ # inputs + helper.make_tensor_value_info('input_1', TensorProto.FLOAT, + [batch_size, sequence_length, input_hidden_size]), + helper.make_tensor_value_info('input_2', TensorProto.FLOAT, + [batch_size, sequence_length, input_hidden_size]), + helper.make_tensor_value_info('input_mask', TensorProto.FLOAT if use_float_mask else TensorProto.INT64, + [batch_size, sequence_length]) + ], + [ # outputs + helper.make_tensor_value_info('output', TensorProto.FLOAT, + [batch_size, sequence_length, input_hidden_size]), + ], + initializers) + + model = helper.make_model(graph) + return model + if __name__ == "__main__": model = create_bert_attention() - onnx.save(model, "pruned_bert_attention.onnx") \ No newline at end of file + onnx.save(model, "pruned_bert_attention.onnx") + model = create_tf2onnx_attention_3d() + onnx.save(model, "bert_3d_attention.onnx") \ No newline at end of file diff --git a/onnxruntime/python/tools/transformers/test/test_attention_fusion.py b/onnxruntime/python/tools/transformers/test/test_attention_fusion.py index 105afe4d7c..2543ce06ce 100644 --- a/onnxruntime/python/tools/transformers/test/test_attention_fusion.py +++ b/onnxruntime/python/tools/transformers/test/test_attention_fusion.py @@ -8,7 +8,7 @@ import unittest import os import sys import onnx -from bert_model_generator import create_bert_attention +from bert_model_generator import create_bert_attention, create_tf2onnx_attention_3d # set path so that we could import from parent directory sys.path.append(os.path.join(os.path.dirname(__file__), '..')) @@ -28,6 +28,19 @@ class TestFusion(unittest.TestCase): 'pruned_attention_opt.onnx') expected = onnx.load(expected_model_path) self.assertEqual(str(optimized_model.model.graph), str(expected.graph)) + + def test_3d_attention_fusion_tf2onnx_model(self): + model = create_tf2onnx_attention_3d() + dir = '.' + model_path = os.path.join(dir, 'bert_3d_attention.onnx') + onnx.save(model, model_path) + optimized_model = optimize_model(model_path, model_type='bert_tf', num_heads=4, hidden_size=16) + os.remove(model_path) + + expected_model_path = os.path.join(os.path.dirname(__file__), 'test_data', 'fusion', + 'bert_3d_attention_opt.onnx') + expected = onnx.load(expected_model_path) + self.assertEqual(str(optimized_model.model.graph), str(expected.graph)) if __name__ == '__main__': diff --git a/onnxruntime/python/tools/transformers/test/test_data/fusion/bert_3d_attention_opt.onnx b/onnxruntime/python/tools/transformers/test/test_data/fusion/bert_3d_attention_opt.onnx new file mode 100644 index 0000000000..5553edfec3 Binary files /dev/null and b/onnxruntime/python/tools/transformers/test/test_data/fusion/bert_3d_attention_opt.onnx differ diff --git a/onnxruntime/python/tools/transformers/test/test_optimizer.py b/onnxruntime/python/tools/transformers/test/test_optimizer.py index b5649d6b55..d298cbbc33 100644 --- a/onnxruntime/python/tools/transformers/test/test_optimizer.py +++ b/onnxruntime/python/tools/transformers/test/test_optimizer.py @@ -340,11 +340,34 @@ class TestBertOptimization(unittest.TestCase): self._test_optimizer_on_huggingface_model("facebook/bart-base", [0, 0, 0, 0, 12, 2, 30]) @pytest.mark.slow - def test_bert_base_cased_from_tf(self): - self._test_optimizer_on_tf_model("bert-base-cased", [1, 12, 0, 0, 12, 0, 24], 1) - self._test_optimizer_on_tf_model("bert-base-cased", [1, 12, 0, 0, 12, 0, 24], 2) - self._test_optimizer_on_tf_model("bert-base-cased", [1, 12, 0, 0, 12, 0, 24], 3) + def test_huggingface_bert_base_cased_from_tf2onnx(self): + self._test_optimizer_on_tf_model("bert-base-cased", [0, 12, 0, 0, 0, 0, 25], 1) + self._test_optimizer_on_tf_model("bert-base-cased", [0, 12, 0, 0, 0, 0, 25], 2) + self._test_optimizer_on_tf_model("bert-base-cased", [0, 12, 0, 0, 0, 0, 25], 3) + @pytest.mark.slow + def test_huggingface_distilgpt2_from_tf2onnx(self): + self._test_optimizer_on_tf_model("distilgpt2", [0, 0, 0, 0, 0, 12, 1], 1) + + @pytest.mark.slow + def test_huggingface_albert_from_tf2onnx(self): + self._test_optimizer_on_tf_model("albert-base-v1", [0, 0, 0, 0, 0, 0, 25], 1) + + @pytest.mark.slow + def test_huggingface_gpt2_from_tf2onnx(self): + self._test_optimizer_on_tf_model("gpt2", [0, 0, 0, 0, 0, 24, 1], 1, validate_model=False) + + @pytest.mark.slow + def test_huggingface_roberta_from_tf2onnx(self): + self._test_optimizer_on_tf_model("roberta-base", [0, 12, 0, 0, 0, 0, 25], 1, validate_model=False) + + @pytest.mark.slow + def test_huggingface_distilbert_from_tf2onnx(self): + self._test_optimizer_on_tf_model("distilbert-base-uncased", [0, 0, 0, 0, 0, 0, 13], 1, validate_model=False) + + @pytest.mark.slow + def test_huggingface_xlm_from_tf2onnx(self): + self._test_optimizer_on_tf_model("xlm-mlm-ende-1024", [0, 0, 0, 0, 0, 1, 12], 1, validate_model=False) if __name__ == '__main__': unittest.main()