mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-06 00:03:22 +00:00
Support HuggingFace Models Converted From tf2onnx in Python Script (#6985)
Support tf2onnx huggingface models in python script
This commit is contained in:
parent
335edaa2c4
commit
4fd9fef9ee
13 changed files with 448 additions and 126 deletions
|
|
@ -65,7 +65,7 @@ if "OMP_NUM_THREADS" not in os.environ:
|
|||
os.environ["OMP_NUM_THREADS"] = str(cpu_count)
|
||||
|
||||
import torch
|
||||
from transformers import (AutoConfig, AutoTokenizer, AutoModel, GPT2Model)
|
||||
from transformers import (AutoConfig, AutoTokenizer, AutoModel, GPT2Model, LxmertConfig)
|
||||
|
||||
|
||||
def run_onnxruntime(use_gpu, model_names, model_class, precision, num_threads, batch_sizes, sequence_lengths,
|
||||
|
|
@ -128,8 +128,7 @@ def run_onnxruntime(use_gpu, model_names, model_class, precision, num_threads, b
|
|||
|
||||
input_value_type = numpy.int64 if 'pt' in model_source else numpy.int32
|
||||
ort_inputs = create_onnxruntime_input(vocab_size, batch_size, sequence_length, input_names,
|
||||
input_value_type)
|
||||
|
||||
config, input_value_type)
|
||||
result_template = {
|
||||
"engine": "onnxruntime",
|
||||
"version": onnxruntime.__version__,
|
||||
|
|
@ -334,8 +333,18 @@ def run_tensorflow(use_gpu, model_names, model_class, precision, num_threads, ba
|
|||
@run_with_tf_optimizations(do_eager_mode=False, use_xla=False)
|
||||
def encoder_decoder_forward():
|
||||
return model(input_ids, decoder_input_ids=input_ids, training=False)
|
||||
|
||||
@run_with_tf_optimizations(do_eager_mode=False, use_xla=False)
|
||||
def lxmert_forward():
|
||||
feats = tf.random.normal([1, 1, config.visual_feat_dim])
|
||||
pos = tf.random.normal([1, 1, config.visual_pos_dim])
|
||||
return model(input_ids, visual_feats=feats, visual_pos=pos, training=False)
|
||||
|
||||
inference = encoder_decoder_forward if config.is_encoder_decoder else encoder_forward
|
||||
inference = encoder_forward
|
||||
if config.is_encoder_decoder:
|
||||
inference = encoder_decoder_forward
|
||||
elif isinstance(config, LxmertConfig):
|
||||
inference = lxmert_forward
|
||||
|
||||
inference()
|
||||
|
||||
|
|
|
|||
|
|
@ -30,7 +30,10 @@ class Precision(Enum):
|
|||
def __str__(self):
|
||||
return self.value
|
||||
|
||||
|
||||
IO_BINDING_DATA_TYPE_MAP = {
|
||||
"float32": numpy.float32,
|
||||
# TODO: Add more.
|
||||
}
|
||||
def create_onnxruntime_session(onnx_model_path,
|
||||
use_gpu,
|
||||
enable_all_optimization=True,
|
||||
|
|
@ -214,7 +217,8 @@ def inference_ort_with_io_binding(ort_session,
|
|||
# Bind inputs to device
|
||||
for name in ort_inputs.keys():
|
||||
np_input = torch.from_numpy(ort_inputs[name]).to(device)
|
||||
io_binding.bind_input(name, np_input.device.type, 0, data_type, np_input.shape, np_input.data_ptr())
|
||||
input_type = IO_BINDING_DATA_TYPE_MAP[str(ort_inputs[name].dtype)] if str(ort_inputs[name].dtype) in IO_BINDING_DATA_TYPE_MAP else data_type
|
||||
io_binding.bind_input(name, np_input.device.type, 0, input_type, np_input.shape, np_input.data_ptr())
|
||||
# Bind outputs buffers with the sizes needed if not allocated already
|
||||
if len(output_buffers) == 0:
|
||||
allocateOutputBuffers(output_buffers, output_buffer_max_sizes, device)
|
||||
|
|
|
|||
|
|
@ -166,31 +166,27 @@ class FusionAttention(Fusion):
|
|||
# Check if all matrices have the same shape
|
||||
assert qw.shape == kw.shape == vw.shape
|
||||
|
||||
if len(qw.shape) != 2:
|
||||
logger.debug(f"weights for Q is expected to be 2D.")
|
||||
return None
|
||||
# All the matrices have the same shape. For 2d weights, the shapes would be [in_size, out_size].
|
||||
# For 3d weights, shape would be [in_size, a, b] where a*b = out_size
|
||||
in_size = qw.shape[0]
|
||||
out_size = np.prod(qw.shape[1:])
|
||||
|
||||
# All the matrices have the same shape (in_size, out_size)
|
||||
in_size, out_size = qw.shape
|
||||
qkv_weight = np.stack((qw, kw, vw), axis=1)
|
||||
|
||||
qb = numpy_helper.to_array(q_bias)
|
||||
kb = numpy_helper.to_array(k_bias)
|
||||
vb = numpy_helper.to_array(v_bias)
|
||||
|
||||
# 1d bias shape: [outsize,]. 2d bias shape: [a, b] where a*b = out_size
|
||||
assert qb.shape == kb.shape == vb.shape
|
||||
assert np.prod(qb.shape) == out_size
|
||||
|
||||
if out_size != hidden_size:
|
||||
logger.debug(
|
||||
f"Shape for weights of Q is {in_size, out_size}, which does not match hidden_size={hidden_size}")
|
||||
return None
|
||||
|
||||
qkv_weight = np.stack((qw, kw, vw), axis=-2)
|
||||
|
||||
qb = numpy_helper.to_array(q_bias)
|
||||
assert qb.shape == (out_size, )
|
||||
|
||||
kb = numpy_helper.to_array(k_bias)
|
||||
assert kb.shape == (out_size, )
|
||||
|
||||
vb = numpy_helper.to_array(v_bias)
|
||||
assert vb.shape == (out_size, )
|
||||
|
||||
qkv_bias = np.stack((qb, kb, vb), axis=-2)
|
||||
|
||||
qkv_bias = np.stack((qb, kb, vb), axis=0)
|
||||
attention_node_name = self.model.create_node_name('Attention')
|
||||
|
||||
weight = helper.make_tensor(name=attention_node_name + '_qkv_weight',
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ import time
|
|||
import re
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Tuple, Union
|
||||
from transformers import GPT2Model, GPT2LMHeadModel, GPT2Config
|
||||
from transformers import GPT2Model, GPT2LMHeadModel, GPT2Config, TFGPT2Model
|
||||
from benchmark_helper import Precision
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -33,6 +33,15 @@ class GPT2ModelNoPastState(GPT2Model):
|
|||
def forward(self, input_ids):
|
||||
return super().forward(input_ids, use_cache=False, return_dict=False)
|
||||
|
||||
class TFGPT2ModelNoPastState(TFGPT2Model):
|
||||
""" Here we wrap a class to disable past state output.
|
||||
"""
|
||||
def __init__(self, config):
|
||||
config.use_cache = False
|
||||
super().__init__(config)
|
||||
|
||||
def forward(self, input_ids):
|
||||
return super().call(input_ids, use_cache=False)
|
||||
|
||||
class MyGPT2Model(GPT2Model):
|
||||
""" Here we wrap a class for Onnx model conversion for GPT2Model with past state.
|
||||
|
|
|
|||
|
|
@ -13,13 +13,24 @@ MODEL_CLASSES = [
|
|||
# Pretrained model name to a tuple of input names, opset_version, use_external_data_format, optimization model type
|
||||
MODELS = {
|
||||
# BERT
|
||||
"bert-base-uncased": (["input_ids", "attention_mask", "token_type_ids"], 11, False, "bert"),
|
||||
"bert-large-uncased": (["input_ids", "attention_mask", "token_type_ids"], 11, False, "bert"),
|
||||
"bert-base-cased": (["input_ids", "attention_mask", "token_type_ids"], 11, False, "bert"),
|
||||
"bert-large-uncased-whole-word-masking-finetuned-squad": (["input_ids", "attention_mask",
|
||||
"token_type_ids"], 11, False, "bert"),
|
||||
"bert-base-cased-finetuned-mrpc": (["input_ids", "attention_mask", "token_type_ids"], 11, False, "bert"),
|
||||
|
||||
"bert-base-uncased": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"),
|
||||
"bert-large-uncased": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"),
|
||||
"bert-base-cased": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"),
|
||||
# "bert-large-cased": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"),
|
||||
# "bert-base-multilingual-uncased": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"),
|
||||
# "bert-base-multilingual-cased": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"),
|
||||
# "bert-base-chinese": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"),
|
||||
# "bert-base-german-cased": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"),
|
||||
# "bert-large-uncased-whole-word-masking": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"),
|
||||
# "bert-large-cased-whole-word-masking": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"),
|
||||
# "bert-large-uncased-whole-word-masking-finetuned-squad": (["input_ids", "attention_mask",
|
||||
# "token_type_ids"], 12, False, "bert"),
|
||||
# "bert-large-cased-whole-word-masking-finetuned-squad": (["input_ids", "attention_mask",
|
||||
# "token_type_ids"], 12, False, "bert"),
|
||||
# "bert-base-cased-finetuned-mrpc": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"),
|
||||
# "bert-base-german-dbmdz-cased": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"),
|
||||
# "bert-base-german-dbmdz-uncased": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"),
|
||||
# todo: more models to add
|
||||
# GPT (no past state)
|
||||
"openai-gpt": (["input_ids"], 11, False, "gpt2"),
|
||||
# GPT-2 (no past state, use benchmark_gpt2.py for past_key_values)
|
||||
|
|
@ -29,7 +40,8 @@ MODELS = {
|
|||
"gpt2-xl": (["input_ids"], 11, True, "gpt2"),
|
||||
"distilgpt2": (["input_ids"], 11, False, "gpt2"),
|
||||
# Transformer-XL
|
||||
#"transfo-xl-wt103": (["input_ids"], 11, False, "bert"),
|
||||
"transfo-xl-wt103":
|
||||
(["input_ids", "mems"], 12, False, "bert"), # Models uses Einsum, which need opset version 12 and PyTorch 1.5.0 or above.
|
||||
# XLNet
|
||||
"xlnet-base-cased": (["input_ids"], 12, False, "bert"),
|
||||
"xlnet-large-cased": (["input_ids"], 12, False, "bert"),
|
||||
|
|
@ -37,14 +49,12 @@ MODELS = {
|
|||
"xlm-mlm-en-2048": (["input_ids"], 11, True, "bert"),
|
||||
"xlm-mlm-ende-1024": (["input_ids"], 11, False, "bert"),
|
||||
"xlm-mlm-enfr-1024": (["input_ids"], 11, False, "bert"),
|
||||
# XML Roberta
|
||||
"xlm-roberta-base": (["input_ids"], 12, False, "bert"),
|
||||
# RoBERTa
|
||||
"roberta-base": (["input_ids", "attention_mask"], 11, False, "bert"),
|
||||
"roberta-large": (["input_ids", "attention_mask"], 11, False, "bert"),
|
||||
"roberta-large-mnli": (["input_ids", "attention_mask"], 11, False, "bert"),
|
||||
"roberta-base": (["input_ids", "attention_mask"], 12, False, "bert"),
|
||||
"roberta-large": (["input_ids", "attention_mask"], 12, False, "bert"),
|
||||
"roberta-large-mnli": (["input_ids", "attention_mask"], 12, False, "bert"),
|
||||
"deepset/roberta-base-squad2": (["input_ids", "attention_mask"], 11, False, "bert"),
|
||||
"distilroberta-base": (["input_ids", "attention_mask"], 11, False, "bert"),
|
||||
"distilroberta-base": (["input_ids", "attention_mask"], 12, False, "bert"),
|
||||
|
||||
# DistilBERT
|
||||
"distilbert-base-uncased": (["input_ids", "attention_mask"], 11, False, "bert"),
|
||||
|
|
@ -63,11 +73,11 @@ MODELS = {
|
|||
"albert-xlarge-v2": (["input_ids"], 12, True, "bert"),
|
||||
#"albert-xxlarge-v2": (["input_ids"], 12, True, "bert"),
|
||||
# T5 (use benchmark_t5.py instead)
|
||||
#"t5-small": (["input_ids"], 12, False, "bert"),
|
||||
#"t5-base": (["input_ids"], 12, False, "bert"),
|
||||
#"t5-large": (["input_ids"], 12, True, "bert"),
|
||||
#"t5-3b": (["input_ids"], 12, True, "bert"),
|
||||
#"t5-11b": (["input_ids"], 12, True, "bert"),
|
||||
# "t5-small": (["input_ids", "decoder_input_ids"], 12, False, "bert"),
|
||||
# "t5-base": (["input_ids", "decoder_input_ids"], 12, False, "bert"),
|
||||
# "t5-large": (["input_ids", "decoder_input_ids"], 12, True, "bert"),
|
||||
# "t5-3b": (["input_ids", "decoder_input_ids"], 12, True, "bert"),
|
||||
# "t5-11b": (["input_ids", "decoder_input_ids"], 12, True, "bert"),
|
||||
#"valhalla/t5-small-qa-qg-hl": (["input_ids"], 12, True, "bert"),
|
||||
# XLM-RoBERTa
|
||||
"xlm-roberta-base": (["input_ids"], 11, False, "bert"),
|
||||
|
|
@ -98,6 +108,20 @@ MODELS = {
|
|||
# MBart
|
||||
"facebook/mbart-large-cc25": (["input_ids"], 11, True, "bert"),
|
||||
"facebook/mbart-large-en-ro": (["input_ids"], 11, True, "bert"),
|
||||
# "Helsinki-NLP/opus-mt-ROMANCE-en": (["input_ids"], 12, False, "bert"),
|
||||
# # Longformer
|
||||
# "allenai/longformer-base-4096": (["input_ids"], 12, False, "bert"),
|
||||
# "allenai/longformer-large-4096": (["input_ids"], 12, True, "bert"),
|
||||
# "funnel-transformer/small": (["input_ids"], 12, False, "bert"),
|
||||
# "funnel-transformer/small-base": (["input_ids"], 12, False, "bert"),
|
||||
# "funnel-transformer/medium": (["input_ids"], 12, False, "bert"),
|
||||
# "funnel-transformer/medium-base": (["input_ids"], 12, False, "bert"),
|
||||
# "funnel-transformer/intermediate": (["input_ids"], 12, False, "bert"),
|
||||
# "funnel-transformer/intermediate-base": (["input_ids"], 12, False, "bert"),
|
||||
# "funnel-transformer/large": (["input_ids"], 12, True, "bert"),
|
||||
# "funnel-transformer/large-base": (["input_ids"], 12, True, "bert"),
|
||||
# "funnel-transformer/xlarge": (["input_ids"], 12, True, "bert"),
|
||||
# "funnel-transformer/xlarge-base": (["input_ids"], 12, True, "bert"),
|
||||
# Layoutlm
|
||||
"microsoft/layoutlm-base-uncased": (["input_ids"], 11, False, "bert"),
|
||||
"microsoft/layoutlm-large-uncased": (["input_ids"], 11, False, "bert"),
|
||||
|
|
@ -105,4 +129,7 @@ MODELS = {
|
|||
"squeezebert/squeezebert-uncased": (["input_ids"], 11, False, "bert"),
|
||||
"squeezebert/squeezebert-mnli": (["input_ids"], 11, False, "bert"),
|
||||
"squeezebert/squeezebert-mnli-headless": (["input_ids"], 11, False, "bert"),
|
||||
"unc-nlp/lxmert-base-uncased": (["input_ids", "visual_feats", "visual_pos"], 11, False, "bert"),
|
||||
# "google/pegasus-xsum": (["input_ids"], 11, False, "bert"),
|
||||
# "google/pegasus-large": (["input_ids"], 11, False, "bert"),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9,11 +9,12 @@ import numpy
|
|||
import os
|
||||
import torch
|
||||
from pathlib import Path
|
||||
from transformers import AutoConfig, AutoTokenizer, AutoModel
|
||||
from transformers import AutoConfig, AutoTokenizer, AutoModel, LxmertConfig, TransfoXLConfig
|
||||
from benchmark_helper import create_onnxruntime_session, Precision
|
||||
from gpt2_helper import GPT2ModelNoPastState, PRETRAINED_GPT2_MODELS
|
||||
from gpt2_helper import GPT2ModelNoPastState, PRETRAINED_GPT2_MODELS, TFGPT2ModelNoPastState
|
||||
from quantize_helper import QuantizeHelper
|
||||
from huggingface_models import MODEL_CLASSES
|
||||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -40,7 +41,7 @@ def restore_torch_functions():
|
|||
torch.triu = torch_func["triu"]
|
||||
|
||||
|
||||
def create_onnxruntime_input(vocab_size, batch_size, sequence_length, input_names, data_type=numpy.int64):
|
||||
def create_onnxruntime_input(vocab_size, batch_size, sequence_length, input_names, config, data_type=numpy.int64):
|
||||
input_ids = numpy.random.randint(low=0, high=vocab_size - 1, size=(batch_size, sequence_length), dtype=data_type)
|
||||
|
||||
inputs = {'input_ids': input_ids}
|
||||
|
|
@ -53,13 +54,23 @@ def create_onnxruntime_input(vocab_size, batch_size, sequence_length, input_name
|
|||
segment_ids = numpy.zeros([batch_size, sequence_length], dtype=data_type)
|
||||
inputs['token_type_ids'] = segment_ids
|
||||
|
||||
if config.is_encoder_decoder:
|
||||
inputs['decoder_input_ids'] = input_ids
|
||||
|
||||
if isinstance(config, LxmertConfig):
|
||||
inputs["visual_feats"] = numpy.random.randn(1, 1, config.visual_feat_dim).astype(numpy.float32)
|
||||
inputs["visual_pos"] = numpy.random.randn(1, 1, config.visual_pos_dim).astype(numpy.float32)
|
||||
if isinstance(config, TransfoXLConfig):
|
||||
inputs["tf_transfo_xl_model/transformer/pos_emb/einsum/Einsum/inputs_1:0"] = numpy.zeros([config.hidden_size],
|
||||
dtype=numpy.float32)
|
||||
return inputs
|
||||
|
||||
|
||||
def filter_inputs(inputs, input_names):
|
||||
remaining_model_inputs = {}
|
||||
for input_name in input_names:
|
||||
remaining_model_inputs[input_name] = inputs[input_name]
|
||||
if input_name in inputs:
|
||||
remaining_model_inputs[input_name] = inputs[input_name]
|
||||
return remaining_model_inputs
|
||||
|
||||
|
||||
|
|
@ -88,7 +99,7 @@ def build_dynamic_axes(example_inputs, outputs_flatten):
|
|||
return dynamic_axes, output_names
|
||||
|
||||
|
||||
def validate_onnx_model(onnx_model_path, example_inputs, example_outputs_flatten, use_gpu, fp16):
|
||||
def validate_onnx_model(onnx_model_path, example_inputs, example_outputs_flatten, use_gpu, fp16, output_names=None):
|
||||
test_session = create_onnxruntime_session(onnx_model_path, use_gpu, enable_all_optimization=False)
|
||||
if test_session is None:
|
||||
logger.error(f"{onnx_model_path} is an invalid ONNX model")
|
||||
|
|
@ -97,8 +108,8 @@ def validate_onnx_model(onnx_model_path, example_inputs, example_outputs_flatten
|
|||
logger.info(f"{onnx_model_path} is a valid ONNX model")
|
||||
|
||||
# Compare the inference result with PyTorch or Tensorflow
|
||||
example_ort_inputs = {k: t.cpu().numpy() for k, t in example_inputs.items()}
|
||||
example_ort_outputs = test_session.run(None, example_ort_inputs)
|
||||
example_ort_inputs = {k: t.numpy() for k, t in example_inputs.items()}
|
||||
example_ort_outputs = test_session.run(output_names, example_ort_inputs)
|
||||
if len(example_outputs_flatten) != len(example_ort_outputs):
|
||||
logger.error(
|
||||
f"Number of output tensors expected {len(example_outputs_flatten)}, got {len(example_ort_outputs)}")
|
||||
|
|
@ -111,7 +122,7 @@ def validate_onnx_model(onnx_model_path, example_inputs, example_outputs_flatten
|
|||
|
||||
rtol = 5e-02 if fp16 else 1e-4
|
||||
atol = 1e-01 if fp16 else 1e-4
|
||||
if not numpy.allclose(example_ort_outputs[i], example_outputs_flatten[i].cpu(), rtol=rtol, atol=atol):
|
||||
if not numpy.allclose(example_ort_outputs[i], example_outputs_flatten[i].cpu().numpy(), rtol=rtol, atol=atol):
|
||||
logger.error(f"Output tensor {i} is not close: rtol={rtol}, atol={atol}")
|
||||
return False
|
||||
|
||||
|
|
@ -195,7 +206,7 @@ def optimize_onnx_model(model_name, onnx_model_path, optimized_model_path, model
|
|||
optimization_options=optimization_options,
|
||||
use_gpu=use_gpu,
|
||||
only_onnxruntime=False)
|
||||
if model_type == 'bert_keras':
|
||||
if model_type == 'bert_keras' or model_type == "bert_tf":
|
||||
opt_model.use_dynamic_axes()
|
||||
|
||||
model_fusion_statistics[optimized_model_path] = opt_model.get_fused_operator_statistics()
|
||||
|
|
@ -234,7 +245,7 @@ def load_pretrained_model(model_name, config, cache_dir, custom_model_class, is_
|
|||
|
||||
if model_class_name == "GPT2ModelNoPastState":
|
||||
if is_tf_model:
|
||||
raise NotImplementedError("TFGPT2ModelNoPastState is currently not supported.")
|
||||
return TFGPT2ModelNoPastState.from_pretrained(model_name, config=config, cache_dir=cache_dir)
|
||||
else:
|
||||
return GPT2ModelNoPastState.from_pretrained(model_name, config=config, cache_dir=cache_dir)
|
||||
|
||||
|
|
@ -279,13 +290,27 @@ def load_pt_model_from_tf(model_name):
|
|||
return config, model
|
||||
|
||||
|
||||
def validate_and_optimize_onnx(model_name, use_external_data_format, model_type, onnx_dir, input_names, use_gpu,
|
||||
precision, optimize_onnx, validate_onnx, use_raw_attention_mask, overwrite, config,
|
||||
model_fusion_statistics, onnx_model_path, example_inputs, example_outputs_flatten):
|
||||
def validate_and_optimize_onnx(model_name,
|
||||
use_external_data_format,
|
||||
model_type,
|
||||
onnx_dir,
|
||||
input_names,
|
||||
use_gpu,
|
||||
precision,
|
||||
optimize_onnx,
|
||||
validate_onnx,
|
||||
use_raw_attention_mask,
|
||||
overwrite,
|
||||
config,
|
||||
model_fusion_statistics,
|
||||
onnx_model_path,
|
||||
example_inputs,
|
||||
example_outputs_flatten,
|
||||
output_names=None):
|
||||
is_valid_onnx_model = True
|
||||
if validate_onnx:
|
||||
is_valid_onnx_model = validate_onnx_model(onnx_model_path, example_inputs, example_outputs_flatten, use_gpu,
|
||||
False)
|
||||
False, output_names)
|
||||
|
||||
if optimize_onnx or precision == Precision.FLOAT16 or precision == Precision.INT8: # Use script (optimizer.py) to optimize
|
||||
optimized_model_path = get_onnx_file_path(onnx_dir, model_name, len(input_names), True, use_gpu, precision,
|
||||
|
|
@ -297,7 +322,7 @@ def validate_and_optimize_onnx(model_name, use_external_data_format, model_type,
|
|||
onnx_model_path = optimized_model_path
|
||||
if validate_onnx:
|
||||
is_valid_onnx_model = validate_onnx_model(onnx_model_path, example_inputs, example_outputs_flatten, use_gpu,
|
||||
precision == Precision.FLOAT16)
|
||||
precision == Precision.FLOAT16, output_names)
|
||||
|
||||
if precision == Precision.INT8:
|
||||
logger.info(f"Quantizing model: {onnx_model_path}")
|
||||
|
|
@ -375,46 +400,88 @@ def export_onnx_model_from_tf(model_name, opset_version, use_external_data_forma
|
|||
import tensorflow as tf
|
||||
tf.config.set_visible_devices([], 'GPU')
|
||||
|
||||
config, model = load_tf_model(model_name, model_class, cache_dir)
|
||||
|
||||
model._saved_model_inputs_spec = None
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
# Fix "Using pad_token, but it is not set yet" error.
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
|
||||
max_input_size = tokenizer.max_model_input_sizes[
|
||||
model_name] if model_name in tokenizer.max_model_input_sizes else 1024
|
||||
|
||||
config, model = load_tf_model(model_name, model_class, cache_dir)
|
||||
model.resize_token_embeddings(len(tokenizer))
|
||||
|
||||
example_inputs = tokenizer.encode_plus("This is a sample input",
|
||||
return_tensors="tf",
|
||||
max_length=max_input_size,
|
||||
pad_to_max_length=True,
|
||||
padding="max_length",
|
||||
truncation=True)
|
||||
|
||||
example_inputs = filter_inputs(example_inputs, input_names)
|
||||
|
||||
example_outputs = model(example_inputs, training=False).to_tuple()
|
||||
if config.is_encoder_decoder:
|
||||
example_inputs["decoder_input_ids"] = tokenizer.encode_plus("This is a sample input",
|
||||
return_tensors="tf",
|
||||
max_length=max_input_size,
|
||||
padding="max_length",
|
||||
truncation=True).input_ids
|
||||
if model_name == "unc-nlp/lxmert-base-uncased":
|
||||
example_inputs["visual_feats"] = tf.random.normal([1, 1, config.visual_feat_dim])
|
||||
example_inputs["visual_pos"] = tf.random.normal([1, 1, config.visual_pos_dim])
|
||||
|
||||
# Flatten is needed for gpt2 and distilgpt2.
|
||||
example_outputs_flatten = flatten(example_outputs)
|
||||
example_outputs_flatten = update_flatten_list(example_outputs_flatten, [])
|
||||
try:
|
||||
# Use no past state for these models
|
||||
if config.use_cache:
|
||||
config.use_cache = False
|
||||
except:
|
||||
pass
|
||||
|
||||
example_outputs = model(example_inputs, training=False)
|
||||
output_names = None
|
||||
|
||||
# For xlnet models, only compare the last_hidden_state output.
|
||||
if model_name == "xlnet-base-cased" or model_name == "xlnet-large-cased":
|
||||
output_names = ["last_hidden_state"]
|
||||
example_outputs = example_outputs["last_hidden_state"]
|
||||
|
||||
# Flatten is needed for gpt2 and distilgpt2. Output name sorting is needed for tf2onnx outputs to match onnx outputs.
|
||||
from tensorflow.python.util import nest
|
||||
example_outputs_flatten = nest.flatten(example_outputs)
|
||||
|
||||
onnx_model_path = get_onnx_file_path(onnx_dir, model_name, len(input_names), False, use_gpu, precision, False,
|
||||
use_external_data_format)
|
||||
tf_internal_model_path = onnx_model_path[:-5] if use_external_data_format else onnx_model_path
|
||||
|
||||
if overwrite or not os.path.exists(onnx_model_path):
|
||||
if overwrite or not os.path.exists(tf_internal_model_path):
|
||||
logger.info("Exporting ONNX model to {}".format(onnx_model_path))
|
||||
Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
if not use_external_data_format:
|
||||
Path(tf_internal_model_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
import tf2onnx, zipfile
|
||||
tf2onnx.logging.set_level(tf2onnx.logging.ERROR)
|
||||
specs = []
|
||||
for name, value in example_inputs.items():
|
||||
dims = [None] * len(value.shape)
|
||||
specs.append(tf.TensorSpec(tuple(dims), value.dtype, name=name))
|
||||
_, _ = tf2onnx.convert.from_keras(model,
|
||||
input_signature=tuple(specs),
|
||||
opset=opset_version,
|
||||
large_model=use_external_data_format,
|
||||
output_path=tf_internal_model_path)
|
||||
if use_external_data_format:
|
||||
# need to unpack the zip for run_onnxruntime()
|
||||
with zipfile.ZipFile(tf_internal_model_path, 'r') as z:
|
||||
z.extractall(os.path.dirname(tf_internal_model_path))
|
||||
tf_internal_model_path = os.path.join(os.path.dirname(tf_internal_model_path), "__MODEL_PROTO.onnx")
|
||||
if os.path.exists(onnx_model_path):
|
||||
os.remove(onnx_model_path)
|
||||
os.rename(tf_internal_model_path, onnx_model_path)
|
||||
|
||||
import keras2onnx
|
||||
onnx_model = keras2onnx.convert_keras(model, model.name, target_opset=opset_version)
|
||||
keras2onnx.save_model(onnx_model, onnx_model_path)
|
||||
else:
|
||||
logger.info(f"Skip export since model existed: {onnx_model_path}")
|
||||
|
||||
model_type = model_type + '_keras'
|
||||
|
||||
model_type = model_type + '_tf'
|
||||
onnx_model_file, is_valid_onnx_model, vocab_size = validate_and_optimize_onnx(
|
||||
model_name, use_external_data_format, model_type, onnx_dir, input_names, use_gpu, precision, optimize_onnx,
|
||||
validate_onnx, use_raw_attention_mask, overwrite, config, model_fusion_statistics, onnx_model_path,
|
||||
example_inputs, example_outputs_flatten)
|
||||
example_inputs, example_outputs_flatten, output_names)
|
||||
|
||||
return onnx_model_file, is_valid_onnx_model, vocab_size, max_input_size
|
||||
|
|
|
|||
|
|
@ -138,7 +138,8 @@ class BertOnnxModelKeras(BertOnnxModelTF):
|
|||
mask_index = self.attention_mask.process_mask(mask_nodes[-1].input[0])
|
||||
logger.debug("Create an Attention node.")
|
||||
attention_node = self.attention_fusion.create_attention_node(mask_index, matmul_q, matmul_k, matmul_v,
|
||||
add_q, add_k, add_v, parent.output[0],
|
||||
add_q, add_k, add_v, self.num_heads,
|
||||
self.hidden_size, parent.output[0],
|
||||
reshape_qkv.output[0])
|
||||
if attention_node is None:
|
||||
continue
|
||||
|
|
|
|||
|
|
@ -29,6 +29,21 @@ class BertOnnxModelTF(BertOnnxModel):
|
|||
self.remove_nodes(nodes_to_remove)
|
||||
logger.info(f"Removed Identity count: {len(nodes_to_remove)}")
|
||||
|
||||
def match_mask_path(self, add_or_sub_before_softmax):
|
||||
mask_nodes = self.match_parent_path(add_or_sub_before_softmax, ['Mul', 'Sub', 'Reshape', 'Cast'],
|
||||
[1, None, 1, 0])
|
||||
if mask_nodes is not None:
|
||||
return mask_nodes
|
||||
|
||||
mask_nodes = self.match_parent_path(add_or_sub_before_softmax, ['Mul', 'Sub', 'Cast', 'Slice', 'Unsqueeze'],
|
||||
[1, 0, 1, 0, 0])
|
||||
if mask_nodes is not None:
|
||||
return mask_nodes
|
||||
|
||||
mask_nodes = self.match_parent_path(add_or_sub_before_softmax, ['Mul', 'Sub', 'Cast', 'Unsqueeze', 'Unsqueeze'],
|
||||
[1, None, 1, 0, 0])
|
||||
return mask_nodes
|
||||
|
||||
def fuse_mask(self):
|
||||
nodes_to_remove = []
|
||||
for node in self.nodes():
|
||||
|
|
@ -85,16 +100,14 @@ class BertOnnxModelTF(BertOnnxModel):
|
|||
nodes_to_remove = []
|
||||
for node in self.nodes():
|
||||
if node.op_type == 'Mul' and self.has_constant_input(node, -10000):
|
||||
mask_path = self.match_parent_path(node, ['Sub', 'Unsqueeze', 'Mul', 'Cast', 'Reshape', 'Cast'],
|
||||
[0, 1, 0, 1, 0, 0])
|
||||
mask_path = self.match_parent_path(node, ['Sub', 'Cast', 'Slice', 'Unsqueeze'], [0, 1, 0, 0])
|
||||
if mask_path is None:
|
||||
continue
|
||||
sub_node, unsqueeze_node, mul_node, cast_node_0, reshape_node_0, cast_node_1 = mask_path
|
||||
sub_node, cast_node, slice_node, unsqueeze_node = mask_path
|
||||
|
||||
mask_input_name = self.attention_mask.get_first_mask()
|
||||
|
||||
if cast_node_1.input[0] != mask_input_name:
|
||||
print("Cast input {} is not mask input{}".format(cast_node_1.input[0], mask_input_name))
|
||||
if unsqueeze_node.input[0] != mask_input_name:
|
||||
print("Cast input {} is not mask input {}".format(unsqueeze_node.input[0], mask_input_name))
|
||||
continue
|
||||
|
||||
unsqueeze_added_1 = onnx.helper.make_node('Unsqueeze',
|
||||
|
|
@ -109,13 +122,14 @@ class BertOnnxModelTF(BertOnnxModel):
|
|||
name='Mask_UnSqueeze_2',
|
||||
axes=[2])
|
||||
|
||||
#self.replace_node_input(cast_node, cast_node.input[0], 'mask_fuse_unsqueeze2_output')
|
||||
cast_node_2 = onnx.helper.make_node('Cast',
|
||||
inputs=['mask_fuse_unsqueeze2_output'],
|
||||
outputs=['mask_fuse_cast_output'])
|
||||
cast_node_2.attribute.extend([onnx.helper.make_attribute("to", 1)])
|
||||
self.replace_node_input(sub_node, sub_node.input[1], 'mask_fuse_cast_output')
|
||||
|
||||
nodes_to_remove.extend([unsqueeze_node, mul_node, cast_node_0, reshape_node_0, cast_node_1])
|
||||
nodes_to_remove.extend([slice_node, unsqueeze_node, cast_node])
|
||||
self.add_node(unsqueeze_added_1)
|
||||
self.add_node(unsqueeze_added_2)
|
||||
self.add_node(cast_node_2)
|
||||
|
|
@ -360,9 +374,22 @@ class BertOnnxModelTF(BertOnnxModel):
|
|||
nodes_to_remove = []
|
||||
attention_count = 0
|
||||
|
||||
start_nodes = []
|
||||
skip_layer_norm_nodes = self.get_nodes_by_op_type("SkipLayerNormalization")
|
||||
for normalize_node in skip_layer_norm_nodes:
|
||||
layer_norm_nodes = self.get_nodes_by_op_type("LayerNormalization")
|
||||
# Sometimes we can not fuse skiplayernormalization since the add before layernorm has an output that used by nodes outside skiplayernorm
|
||||
# Conceptually we treat add before layernorm as skiplayernorm node since they share the same pattern
|
||||
start_nodes.extend(skip_layer_norm_nodes)
|
||||
start_nodes.extend(layer_norm_nodes)
|
||||
|
||||
for normalize_node in start_nodes:
|
||||
# SkipLayerNormalization has two inputs, and one of them is the root input for attention.
|
||||
if normalize_node.op_type == 'LayerNormalization':
|
||||
add_before_layernorm = self.match_parent(normalize_node, 'Add', 0)
|
||||
if add_before_layernorm is not None:
|
||||
normalize_node = add_before_layernorm
|
||||
else:
|
||||
continue
|
||||
parent = self.get_parent(normalize_node, 1)
|
||||
if parent is None or parent.op_type not in ["SkipLayerNormalization", "LayerNormalization", "Reshape"]:
|
||||
parent = self.get_parent(normalize_node, 0)
|
||||
|
|
@ -376,50 +403,63 @@ class BertOnnxModelTF(BertOnnxModel):
|
|||
qkv_nodes = self.match_parent_path(normalize_node, ['MatMul', 'Reshape', 'Transpose', 'MatMul'],
|
||||
[1, 0, 0, 0])
|
||||
if qkv_nodes is None:
|
||||
logger.debug("Failed to match qkv nodes")
|
||||
continue
|
||||
qkv_nodes = self.match_parent_path(normalize_node, ['Add', 'Einsum', 'Einsum'], [0, 0, 0])
|
||||
if qkv_nodes is None:
|
||||
logger.debug("Failed to match qkv nodes")
|
||||
continue
|
||||
|
||||
(reshape_qkv, transpose_qkv, matmul_qkv) = qkv_nodes[-3:]
|
||||
matmul_qkv = qkv_nodes[-1]
|
||||
v_nodes = self.match_parent_path(matmul_qkv, ['Transpose', 'Reshape', 'Add', 'MatMul'], [1, 0, 0, 0])
|
||||
if v_nodes is None:
|
||||
logger.debug("Failed to match v path")
|
||||
continue
|
||||
v_nodes = self.match_parent_path(matmul_qkv, ['Add', 'Einsum'], [1, 0])
|
||||
if v_nodes is None:
|
||||
logger.debug("Failed to match v path")
|
||||
continue
|
||||
|
||||
(transpose_v, reshape_v, add_v, matmul_v) = v_nodes
|
||||
add_v = v_nodes[-2]
|
||||
matmul_v = v_nodes[-1]
|
||||
qk_nodes = self.match_parent_path(matmul_qkv, ['Softmax', 'Add', "Mul", 'MatMul'], [0, 0, 0, 0])
|
||||
if qk_nodes is None:
|
||||
logger.debug("Failed to match qk_paths")
|
||||
continue
|
||||
(softmax_qk, add_qk, mul_qk, matmul_qk) = qk_nodes
|
||||
qk_nodes = self.match_parent_path(matmul_qkv, ['Softmax', 'Add', 'Einsum'], [0, 0, 0])
|
||||
if qk_nodes is None:
|
||||
logger.debug("Failed to match qk_paths")
|
||||
continue
|
||||
matmul_qk = qk_nodes[-1]
|
||||
|
||||
q_nodes = self.match_parent_path(matmul_qk, ['Transpose', 'Reshape', 'Add', 'MatMul'], [0, 0, 0, 0])
|
||||
if q_nodes is None:
|
||||
logger.debug("Failed to match q path")
|
||||
continue
|
||||
(transpose_q, reshape_q, add_q, matmul_q) = q_nodes
|
||||
q_nodes = self.match_parent_path(matmul_qk, ['Add', 'Einsum'], [0, 0])
|
||||
if q_nodes is None:
|
||||
logger.debug("Failed to match q path")
|
||||
continue
|
||||
add_q = q_nodes[-2]
|
||||
matmul_q = q_nodes[-1]
|
||||
|
||||
k_nodes = self.match_parent_path(matmul_qk, ['Transpose', 'Reshape', 'Add', 'MatMul'], [1, 0, 0, 0])
|
||||
if k_nodes is None:
|
||||
logger.debug("Failed to match k path")
|
||||
continue
|
||||
(transpose_k, reshape_k, add_k, matmul_k) = k_nodes
|
||||
|
||||
mask_nodes = self.match_parent_path(add_qk, ['Mul', 'Sub', 'Unsqueeze'], [1, 0, 1])
|
||||
if mask_nodes is None:
|
||||
mask_nodes = self.match_parent_path(add_qk, ['Mul', 'Sub', 'Cast', 'Unsqueeze', 'Mul'], [1, 0, 1, 0, 0])
|
||||
if mask_nodes is None:
|
||||
logger.debug("Failed to match mask path")
|
||||
k_nodes = self.match_parent_path(matmul_qk, ['Mul', 'Add', 'Einsum'], [1, 0, 0])
|
||||
if k_nodes is None:
|
||||
logger.debug("Failed to match k path")
|
||||
continue
|
||||
add_k = k_nodes[-2]
|
||||
matmul_k = k_nodes[-1]
|
||||
|
||||
mask_nodes = self.match_mask_path(qk_nodes[1])
|
||||
|
||||
if mask_nodes is None:
|
||||
logger.debug("Cannot find mask_nodes.")
|
||||
continue
|
||||
|
||||
if not self.has_constant_input(mask_nodes[1], 1):
|
||||
logger.debug("Sub node expected to have an input with constant value 1.0.")
|
||||
continue
|
||||
|
||||
# add a squeeze node to convert a 3-d mask to 2-d
|
||||
squeeze_node = self.match_parent_path(mask_nodes[-1], ['Squeeze'], [0])
|
||||
squeeze_node = self.match_parent_path(mask_nodes[-1], ['Squeeze'], [0]) or self.match_parent_path(
|
||||
mask_nodes[-1], ['Expand'], [0])
|
||||
squeeze_node_name = "Squeeze_3d_to_2d_mask"
|
||||
squeeze_output_name = squeeze_node_name + "_output"
|
||||
if squeeze_node is None and len(mask_nodes) == 5:
|
||||
if squeeze_node is None and len(mask_nodes) == 5 and self.find_graph_input(mask_nodes[-1].input[0]) is None:
|
||||
mask_input = mask_nodes[-1].input[1]
|
||||
self.add_node(
|
||||
helper.make_node("Squeeze", [mask_input], [squeeze_output_name], squeeze_node_name, axes=[1]))
|
||||
|
|
@ -427,11 +467,32 @@ class BertOnnxModelTF(BertOnnxModel):
|
|||
|
||||
is_same_root = self.check_attention_input(matmul_q, matmul_k, matmul_v, parent, output_name_to_node)
|
||||
if is_same_root:
|
||||
mask_index = self.attention_mask.process_mask(squeeze_output_name)
|
||||
mask_index = self.attention_mask.process_mask(mask_nodes[-1].input[0])
|
||||
logger.debug("Create an Attention node.")
|
||||
attention_node = self.attention_fusion.create_attention_node(mask_index, matmul_q, matmul_k, matmul_v,
|
||||
add_q, add_k, add_v, parent.output[0],
|
||||
reshape_qkv.output[0])
|
||||
# For tf models, q and v are flipped.
|
||||
attention_node = self.attention_fusion.create_attention_node(mask_index, matmul_k, matmul_q, matmul_v,
|
||||
add_k, add_q, add_v, self.num_heads,
|
||||
self.hidden_size, parent.output[0],
|
||||
qkv_nodes[2].output[0])
|
||||
if attention_node is None:
|
||||
continue
|
||||
|
||||
if qkv_nodes[1].op_type == 'Einsum':
|
||||
# add reshape before einsum
|
||||
tensor = helper.make_tensor(name=qkv_nodes[1].name + "_newshape",
|
||||
data_type=TensorProto.INT64,
|
||||
dims=[4],
|
||||
vals=np.int64(
|
||||
[[0, 0, self.num_heads,
|
||||
int(self.hidden_size / self.num_heads)]]).tobytes(),
|
||||
raw=True)
|
||||
self.add_initializer(tensor)
|
||||
reshape_ = helper.make_node("Reshape",
|
||||
inputs=[attention_node.output[0], qkv_nodes[1].name + "_newshape"],
|
||||
outputs=[qkv_nodes[1].name + "_reshape_output"],
|
||||
name=qkv_nodes[1].name + "_reshape")
|
||||
qkv_nodes[1].input[0] = qkv_nodes[1].name + "_reshape_output"
|
||||
self.add_node(reshape_)
|
||||
if parent.op_type == 'Reshape':
|
||||
# Temporary work around: we require the skiplayernorm and attention op be fed with 3-d input
|
||||
hidden_size = numpy_helper.to_array(self.get_initializer(parent.input[1]))[1]
|
||||
|
|
@ -443,13 +504,11 @@ class BertOnnxModelTF(BertOnnxModel):
|
|||
self.add_initializer(tensor)
|
||||
parent.input[1] = parent.name + "_modified"
|
||||
|
||||
if attention_node is None:
|
||||
continue
|
||||
|
||||
|
||||
self.add_node(attention_node)
|
||||
attention_count += 1
|
||||
|
||||
nodes_to_remove.extend([reshape_qkv, transpose_qkv, matmul_qkv])
|
||||
nodes_to_remove.extend(qkv_nodes[2:])
|
||||
nodes_to_remove.extend(qk_nodes)
|
||||
nodes_to_remove.extend(q_nodes)
|
||||
nodes_to_remove.extend(k_nodes)
|
||||
|
|
@ -466,7 +525,20 @@ class BertOnnxModelTF(BertOnnxModel):
|
|||
self.remove_identity()
|
||||
self.process_embedding()
|
||||
#TODO: remove fuse mask since we have embedding fused so fuse_attention shall handle the mask nodes.
|
||||
self.fuse_mask()
|
||||
# self.fuse_mask()
|
||||
self.skip_reshape()
|
||||
|
||||
def skip_reshape(self):
|
||||
count = 0
|
||||
reshape_nodes = self.get_nodes_by_op_type("Reshape")
|
||||
for reshape_node in reshape_nodes:
|
||||
parent = self.get_parent(reshape_node, 0)
|
||||
if parent is not None and parent.op_type == "Reshape":
|
||||
reshape_node.input[0] = parent.input[0]
|
||||
count += 1
|
||||
|
||||
if count > 0:
|
||||
logger.info(f"Skip consequent Reshape count: {count}")
|
||||
|
||||
def remove_reshape_before_first_attention(self):
|
||||
attention_nodes = self.get_nodes_by_op_type("Attention")
|
||||
|
|
@ -475,7 +547,7 @@ class BertOnnxModelTF(BertOnnxModel):
|
|||
if path is None:
|
||||
continue
|
||||
logger.info("Remove Reshape before first Attention node.")
|
||||
reshape, embed = path
|
||||
reshape, _ = path
|
||||
self.replace_input_of_all_nodes(reshape.output[0], reshape.input[0])
|
||||
self.remove_node(reshape)
|
||||
break
|
||||
|
|
|
|||
|
|
@ -39,7 +39,8 @@ MODEL_CLASSES = {
|
|||
"bert": (BertOnnxModel, "pytorch", True),
|
||||
"bert_tf": (BertOnnxModelTF, "tf2onnx", False),
|
||||
"bert_keras": (BertOnnxModelKeras, "keras2onnx", False),
|
||||
"gpt2": (Gpt2OnnxModel, "pytorch", True)
|
||||
"gpt2": (Gpt2OnnxModel, "pytorch", True),
|
||||
"gpt2_tf": (Gpt2OnnxModel, 'tf2onnx', False) # might add a class for GPT2OnnxModel for TF later.
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -127,7 +127,107 @@ def create_bert_attention(input_hidden_size=16, pruned_num_heads=2, pruned_head_
|
|||
model = helper.make_model(graph)
|
||||
return model
|
||||
|
||||
def create_tf2onnx_attention_3d(input_hidden_size=16, num_heads=4, head_size=4, use_float_mask=False):
|
||||
# unsqueeze in opset version 13 has two inputs (axis is moved from attribute to input).
|
||||
has_unsqueeze_two_inputs = (version.parse(onnx.__version__) >= version.parse('1.8.0'))
|
||||
|
||||
# nodes in attention subgraph
|
||||
nodes = [
|
||||
helper.make_node("Add", ["input_1", "input_2"], ["layernorm_input"], "add_layernorm"),
|
||||
helper.make_node("LayerNormalization", ["layernorm_input", "layer_norm_weight", "layer_norm_bias"],
|
||||
["layernorm_out"],
|
||||
"layernorm",
|
||||
axis=-1,
|
||||
epsion=0.000009999999747378752),
|
||||
|
||||
# q nodes
|
||||
helper.make_node("Einsum", ["layernorm_out", "einsum_q_weight"], ["einsum_q_out"], "einsum_q", equation="abc,cde->abde"),
|
||||
helper.make_node("Add", ["einsum_q_out", "add_q_weight"], ["add_q_out"], "add_q"),
|
||||
|
||||
# k nodes
|
||||
helper.make_node("Einsum", ["layernorm_out", "einsum_k_weight"], ["einsum_k_out"], "einsum_k", equation="abc,cde->abde"),
|
||||
helper.make_node("Add", ["einsum_k_out", "add_k_weight"], ["add_k_out"], "add_k"),
|
||||
helper.make_node("Mul", ["add_k_out", "mul_weight_1"], ["mul_k_out"], "mul_k"),
|
||||
|
||||
# mask nodes
|
||||
helper.make_node("Unsqueeze", ["input_mask", "axes_1"], ["unsqueeze0_out"], "unsqueeze0") if has_unsqueeze_two_inputs \
|
||||
else helper.make_node("Unsqueeze", ["input_mask"], ["unsqueeze0_out"], "unsqueeze0", axes=[1, 2]),
|
||||
helper.make_node("Slice", ["unsqueeze0_out", "slice_start", "slice_end", "slice_axes", "slice_steps"], ["slice_out"], "slice"),
|
||||
|
||||
# when attention_mask is float type, no need to cast
|
||||
helper.make_node("Cast", ["slice_out"], ["cast_out"], "cast", to=1) if not use_float_mask else None,
|
||||
helper.make_node("Sub", ["sub_weight", "unsqueeze1_out" if use_float_mask else "cast_out"], ["sub_out"], "sub"),
|
||||
helper.make_node("Mul", ["sub_out", "mul_weight_2"], ["mul_mask_out"], "mul_mask"),
|
||||
|
||||
# qk nodes
|
||||
helper.make_node("Einsum", ["add_q_out", "mul_k_out"], ["einsum_qk_out"], "einsum_qk", equation="aecd,abcd->acbe"),
|
||||
helper.make_node("Add", ["einsum_qk_out", "mul_mask_out"], ["add_qk_out"], "add_qk"),
|
||||
helper.make_node("Softmax", ["add_qk_out"], ["softmax_qk_out"], "softmax_qk", axis=3),
|
||||
|
||||
# v nodes
|
||||
helper.make_node("Einsum", ["layernorm_out", "einsum_v_weight"], ["einsum_v_out"], "einsum_v", equation="abc,cde->abde"),
|
||||
helper.make_node("Add", ["einsum_v_out", "add_v_weight"], ["add_v_out"], "add_v"),
|
||||
|
||||
# qkv nodes
|
||||
helper.make_node("Einsum", ["softmax_qk_out", "add_v_out"], ["einsum_qkv_1_out"], "einsum_qkv_1", equation="acbe,aecd->abcd"),
|
||||
helper.make_node("Einsum", ["einsum_qkv_1_out", "einsum_qkv_weight"], ["einsum_qkv_2_out"], "einsum_qkv_2", equation="abcd,cde->abe"),
|
||||
helper.make_node("Add", ["einsum_qkv_2_out", "add_qkv_weight"], ["add_qkv_out"], "add_qkv"),
|
||||
helper.make_node("Add", ["add_qkv_out", "layernorm_out"], ["skip_output"], "add_skip"),
|
||||
helper.make_node("LayerNormalization", ["skip_output", "layer_norm_weight", "layer_norm_bias"], ["output"],
|
||||
"layernorm2",
|
||||
axis=-1,
|
||||
epsion=0.000009999999747378752),
|
||||
]
|
||||
|
||||
initializers = [ # initializers
|
||||
float_tensor('layer_norm_weight', [input_hidden_size]),
|
||||
float_tensor('layer_norm_bias', [input_hidden_size]),
|
||||
float_tensor('einsum_q_weight', [input_hidden_size, num_heads, head_size]),
|
||||
float_tensor('einsum_k_weight', [input_hidden_size, num_heads, head_size]),
|
||||
float_tensor('einsum_v_weight', [input_hidden_size, num_heads, head_size]),
|
||||
float_tensor('einsum_qkv_weight', [num_heads, head_size, input_hidden_size]),
|
||||
float_tensor('add_q_weight', [num_heads, head_size]),
|
||||
float_tensor('add_k_weight', [num_heads, head_size]),
|
||||
float_tensor('add_v_weight', [num_heads, head_size]),
|
||||
float_tensor('add_qkv_weight', [input_hidden_size]),
|
||||
helper.make_tensor('sub_weight', TensorProto.FLOAT, [1], [1.0]),
|
||||
helper.make_tensor('mul_weight_1', TensorProto.FLOAT, [1], [-10000]),
|
||||
helper.make_tensor('mul_weight_2', TensorProto.FLOAT, [1], [0.125]),
|
||||
helper.make_tensor('reshape_weight_1', TensorProto.INT64, [4], [0, 0, num_heads, head_size]),
|
||||
helper.make_tensor('slice_start', TensorProto.INT32, [4], [0, 0, 0, 0]),
|
||||
helper.make_tensor('slice_end', TensorProto.INT32, [4], [1000000000, 1000000000, 1000000000, 1000000000]),
|
||||
helper.make_tensor('slice_axes', TensorProto.INT32, [4], [0, 1, 2, 3]),
|
||||
helper.make_tensor('slice_steps', TensorProto.INT32, [4], [1, 1, 1, 1])
|
||||
]
|
||||
|
||||
if has_unsqueeze_two_inputs:
|
||||
initializers.append(helper.make_tensor('axes_1', TensorProto.INT64, [2], [1, 2]))
|
||||
|
||||
batch_size = 1
|
||||
sequence_length = 3
|
||||
graph = helper.make_graph(
|
||||
[node for node in nodes if node],
|
||||
"AttentionFusionPrunedModel", #name
|
||||
[ # inputs
|
||||
helper.make_tensor_value_info('input_1', TensorProto.FLOAT,
|
||||
[batch_size, sequence_length, input_hidden_size]),
|
||||
helper.make_tensor_value_info('input_2', TensorProto.FLOAT,
|
||||
[batch_size, sequence_length, input_hidden_size]),
|
||||
helper.make_tensor_value_info('input_mask', TensorProto.FLOAT if use_float_mask else TensorProto.INT64,
|
||||
[batch_size, sequence_length])
|
||||
],
|
||||
[ # outputs
|
||||
helper.make_tensor_value_info('output', TensorProto.FLOAT,
|
||||
[batch_size, sequence_length, input_hidden_size]),
|
||||
],
|
||||
initializers)
|
||||
|
||||
model = helper.make_model(graph)
|
||||
return model
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
model = create_bert_attention()
|
||||
onnx.save(model, "pruned_bert_attention.onnx")
|
||||
onnx.save(model, "pruned_bert_attention.onnx")
|
||||
model = create_tf2onnx_attention_3d()
|
||||
onnx.save(model, "bert_3d_attention.onnx")
|
||||
|
|
@ -8,7 +8,7 @@ import unittest
|
|||
import os
|
||||
import sys
|
||||
import onnx
|
||||
from bert_model_generator import create_bert_attention
|
||||
from bert_model_generator import create_bert_attention, create_tf2onnx_attention_3d
|
||||
|
||||
# set path so that we could import from parent directory
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
|
||||
|
|
@ -28,6 +28,19 @@ class TestFusion(unittest.TestCase):
|
|||
'pruned_attention_opt.onnx')
|
||||
expected = onnx.load(expected_model_path)
|
||||
self.assertEqual(str(optimized_model.model.graph), str(expected.graph))
|
||||
|
||||
def test_3d_attention_fusion_tf2onnx_model(self):
|
||||
model = create_tf2onnx_attention_3d()
|
||||
dir = '.'
|
||||
model_path = os.path.join(dir, 'bert_3d_attention.onnx')
|
||||
onnx.save(model, model_path)
|
||||
optimized_model = optimize_model(model_path, model_type='bert_tf', num_heads=4, hidden_size=16)
|
||||
os.remove(model_path)
|
||||
|
||||
expected_model_path = os.path.join(os.path.dirname(__file__), 'test_data', 'fusion',
|
||||
'bert_3d_attention_opt.onnx')
|
||||
expected = onnx.load(expected_model_path)
|
||||
self.assertEqual(str(optimized_model.model.graph), str(expected.graph))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
|||
Binary file not shown.
|
|
@ -340,11 +340,34 @@ class TestBertOptimization(unittest.TestCase):
|
|||
self._test_optimizer_on_huggingface_model("facebook/bart-base", [0, 0, 0, 0, 12, 2, 30])
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_bert_base_cased_from_tf(self):
|
||||
self._test_optimizer_on_tf_model("bert-base-cased", [1, 12, 0, 0, 12, 0, 24], 1)
|
||||
self._test_optimizer_on_tf_model("bert-base-cased", [1, 12, 0, 0, 12, 0, 24], 2)
|
||||
self._test_optimizer_on_tf_model("bert-base-cased", [1, 12, 0, 0, 12, 0, 24], 3)
|
||||
def test_huggingface_bert_base_cased_from_tf2onnx(self):
|
||||
self._test_optimizer_on_tf_model("bert-base-cased", [0, 12, 0, 0, 0, 0, 25], 1)
|
||||
self._test_optimizer_on_tf_model("bert-base-cased", [0, 12, 0, 0, 0, 0, 25], 2)
|
||||
self._test_optimizer_on_tf_model("bert-base-cased", [0, 12, 0, 0, 0, 0, 25], 3)
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_huggingface_distilgpt2_from_tf2onnx(self):
|
||||
self._test_optimizer_on_tf_model("distilgpt2", [0, 0, 0, 0, 0, 12, 1], 1)
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_huggingface_albert_from_tf2onnx(self):
|
||||
self._test_optimizer_on_tf_model("albert-base-v1", [0, 0, 0, 0, 0, 0, 25], 1)
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_huggingface_gpt2_from_tf2onnx(self):
|
||||
self._test_optimizer_on_tf_model("gpt2", [0, 0, 0, 0, 0, 24, 1], 1, validate_model=False)
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_huggingface_roberta_from_tf2onnx(self):
|
||||
self._test_optimizer_on_tf_model("roberta-base", [0, 12, 0, 0, 0, 0, 25], 1, validate_model=False)
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_huggingface_distilbert_from_tf2onnx(self):
|
||||
self._test_optimizer_on_tf_model("distilbert-base-uncased", [0, 0, 0, 0, 0, 0, 13], 1, validate_model=False)
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_huggingface_xlm_from_tf2onnx(self):
|
||||
self._test_optimizer_on_tf_model("xlm-mlm-ende-1024", [0, 0, 0, 0, 0, 1, 12], 1, validate_model=False)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
|||
Loading…
Reference in a new issue