Support HuggingFace Models Converted From tf2onnx in Python Script (#6985)

Support tf2onnx huggingface models in python script
This commit is contained in:
Cecilia Liu 2021-03-17 15:33:57 -07:00 committed by GitHub
parent 335edaa2c4
commit 4fd9fef9ee
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 448 additions and 126 deletions

View file

@ -65,7 +65,7 @@ if "OMP_NUM_THREADS" not in os.environ:
os.environ["OMP_NUM_THREADS"] = str(cpu_count)
import torch
from transformers import (AutoConfig, AutoTokenizer, AutoModel, GPT2Model)
from transformers import (AutoConfig, AutoTokenizer, AutoModel, GPT2Model, LxmertConfig)
def run_onnxruntime(use_gpu, model_names, model_class, precision, num_threads, batch_sizes, sequence_lengths,
@ -128,8 +128,7 @@ def run_onnxruntime(use_gpu, model_names, model_class, precision, num_threads, b
input_value_type = numpy.int64 if 'pt' in model_source else numpy.int32
ort_inputs = create_onnxruntime_input(vocab_size, batch_size, sequence_length, input_names,
input_value_type)
config, input_value_type)
result_template = {
"engine": "onnxruntime",
"version": onnxruntime.__version__,
@ -334,8 +333,18 @@ def run_tensorflow(use_gpu, model_names, model_class, precision, num_threads, ba
@run_with_tf_optimizations(do_eager_mode=False, use_xla=False)
def encoder_decoder_forward():
return model(input_ids, decoder_input_ids=input_ids, training=False)
@run_with_tf_optimizations(do_eager_mode=False, use_xla=False)
def lxmert_forward():
feats = tf.random.normal([1, 1, config.visual_feat_dim])
pos = tf.random.normal([1, 1, config.visual_pos_dim])
return model(input_ids, visual_feats=feats, visual_pos=pos, training=False)
inference = encoder_decoder_forward if config.is_encoder_decoder else encoder_forward
inference = encoder_forward
if config.is_encoder_decoder:
inference = encoder_decoder_forward
elif isinstance(config, LxmertConfig):
inference = lxmert_forward
inference()

View file

@ -30,7 +30,10 @@ class Precision(Enum):
def __str__(self):
return self.value
# Lookup table keyed by the string form of a numpy input dtype (str(array.dtype)).
# The value is the element type passed to io_binding.bind_input for that input;
# inputs whose dtype string is not listed here fall back to the caller-supplied data_type.
IO_BINDING_DATA_TYPE_MAP = {
"float32": numpy.float32,
# TODO: Add more.
}
def create_onnxruntime_session(onnx_model_path,
use_gpu,
enable_all_optimization=True,
@ -214,7 +217,8 @@ def inference_ort_with_io_binding(ort_session,
# Bind inputs to device
for name in ort_inputs.keys():
np_input = torch.from_numpy(ort_inputs[name]).to(device)
io_binding.bind_input(name, np_input.device.type, 0, data_type, np_input.shape, np_input.data_ptr())
input_type = IO_BINDING_DATA_TYPE_MAP[str(ort_inputs[name].dtype)] if str(ort_inputs[name].dtype) in IO_BINDING_DATA_TYPE_MAP else data_type
io_binding.bind_input(name, np_input.device.type, 0, input_type, np_input.shape, np_input.data_ptr())
# Bind outputs buffers with the sizes needed if not allocated already
if len(output_buffers) == 0:
allocateOutputBuffers(output_buffers, output_buffer_max_sizes, device)

View file

@ -166,31 +166,27 @@ class FusionAttention(Fusion):
# Check if all matrices have the same shape
assert qw.shape == kw.shape == vw.shape
if len(qw.shape) != 2:
logger.debug(f"weights for Q is expected to be 2D.")
return None
# All the matrices have the same shape. For 2d weights, the shapes would be [in_size, out_size].
# For 3d weights, shape would be [in_size, a, b] where a*b = out_size
in_size = qw.shape[0]
out_size = np.prod(qw.shape[1:])
# All the matrices have the same shape (in_size, out_size)
in_size, out_size = qw.shape
qkv_weight = np.stack((qw, kw, vw), axis=1)
qb = numpy_helper.to_array(q_bias)
kb = numpy_helper.to_array(k_bias)
vb = numpy_helper.to_array(v_bias)
# 1d bias shape: [outsize,]. 2d bias shape: [a, b] where a*b = out_size
assert qb.shape == kb.shape == vb.shape
assert np.prod(qb.shape) == out_size
if out_size != hidden_size:
logger.debug(
f"Shape for weights of Q is {in_size, out_size}, which does not match hidden_size={hidden_size}")
return None
qkv_weight = np.stack((qw, kw, vw), axis=-2)
qb = numpy_helper.to_array(q_bias)
assert qb.shape == (out_size, )
kb = numpy_helper.to_array(k_bias)
assert kb.shape == (out_size, )
vb = numpy_helper.to_array(v_bias)
assert vb.shape == (out_size, )
qkv_bias = np.stack((qb, kb, vb), axis=-2)
qkv_bias = np.stack((qb, kb, vb), axis=0)
attention_node_name = self.model.create_node_name('Attention')
weight = helper.make_tensor(name=attention_node_name + '_qkv_weight',

View file

@ -14,7 +14,7 @@ import time
import re
from pathlib import Path
from typing import List, Dict, Tuple, Union
from transformers import GPT2Model, GPT2LMHeadModel, GPT2Config
from transformers import GPT2Model, GPT2LMHeadModel, GPT2Config, TFGPT2Model
from benchmark_helper import Precision
logger = logging.getLogger(__name__)
@ -33,6 +33,15 @@ class GPT2ModelNoPastState(GPT2Model):
def forward(self, input_ids):
return super().forward(input_ids, use_cache=False, return_dict=False)
class TFGPT2ModelNoPastState(TFGPT2Model):
    """TFGPT2Model wrapper that disables the past-state (cache) output."""

    def __init__(self, config):
        # Clear the cache flag on the config before handing it to the base constructor.
        config.use_cache = False
        super().__init__(config)

    def forward(self, input_ids):
        """Run the parent ``call`` on ``input_ids`` with ``use_cache=False``."""
        # NOTE(review): Keras models are normally dispatched through call(), not
        # forward() -- confirm that something actually invokes this method.
        outputs = super().call(input_ids, use_cache=False)
        return outputs
class MyGPT2Model(GPT2Model):
""" Here we wrap a class for Onnx model conversion for GPT2Model with past state.

View file

@ -13,13 +13,24 @@ MODEL_CLASSES = [
# Pretrained model name to a tuple of input names, opset_version, use_external_data_format, optimization model type
MODELS = {
# BERT
"bert-base-uncased": (["input_ids", "attention_mask", "token_type_ids"], 11, False, "bert"),
"bert-large-uncased": (["input_ids", "attention_mask", "token_type_ids"], 11, False, "bert"),
"bert-base-cased": (["input_ids", "attention_mask", "token_type_ids"], 11, False, "bert"),
"bert-large-uncased-whole-word-masking-finetuned-squad": (["input_ids", "attention_mask",
"token_type_ids"], 11, False, "bert"),
"bert-base-cased-finetuned-mrpc": (["input_ids", "attention_mask", "token_type_ids"], 11, False, "bert"),
"bert-base-uncased": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"),
"bert-large-uncased": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"),
"bert-base-cased": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"),
# "bert-large-cased": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"),
# "bert-base-multilingual-uncased": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"),
# "bert-base-multilingual-cased": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"),
# "bert-base-chinese": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"),
# "bert-base-german-cased": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"),
# "bert-large-uncased-whole-word-masking": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"),
# "bert-large-cased-whole-word-masking": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"),
# "bert-large-uncased-whole-word-masking-finetuned-squad": (["input_ids", "attention_mask",
# "token_type_ids"], 12, False, "bert"),
# "bert-large-cased-whole-word-masking-finetuned-squad": (["input_ids", "attention_mask",
# "token_type_ids"], 12, False, "bert"),
# "bert-base-cased-finetuned-mrpc": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"),
# "bert-base-german-dbmdz-cased": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"),
# "bert-base-german-dbmdz-uncased": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"),
# todo: more models to add
# GPT (no past state)
"openai-gpt": (["input_ids"], 11, False, "gpt2"),
# GPT-2 (no past state, use benchmark_gpt2.py for past_key_values)
@ -29,7 +40,8 @@ MODELS = {
"gpt2-xl": (["input_ids"], 11, True, "gpt2"),
"distilgpt2": (["input_ids"], 11, False, "gpt2"),
# Transformer-XL
#"transfo-xl-wt103": (["input_ids"], 11, False, "bert"),
"transfo-xl-wt103":
(["input_ids", "mems"], 12, False, "bert"), # Models uses Einsum, which need opset version 12 and PyTorch 1.5.0 or above.
# XLNet
"xlnet-base-cased": (["input_ids"], 12, False, "bert"),
"xlnet-large-cased": (["input_ids"], 12, False, "bert"),
@ -37,14 +49,12 @@ MODELS = {
"xlm-mlm-en-2048": (["input_ids"], 11, True, "bert"),
"xlm-mlm-ende-1024": (["input_ids"], 11, False, "bert"),
"xlm-mlm-enfr-1024": (["input_ids"], 11, False, "bert"),
# XML Roberta
"xlm-roberta-base": (["input_ids"], 12, False, "bert"),
# RoBERTa
"roberta-base": (["input_ids", "attention_mask"], 11, False, "bert"),
"roberta-large": (["input_ids", "attention_mask"], 11, False, "bert"),
"roberta-large-mnli": (["input_ids", "attention_mask"], 11, False, "bert"),
"roberta-base": (["input_ids", "attention_mask"], 12, False, "bert"),
"roberta-large": (["input_ids", "attention_mask"], 12, False, "bert"),
"roberta-large-mnli": (["input_ids", "attention_mask"], 12, False, "bert"),
"deepset/roberta-base-squad2": (["input_ids", "attention_mask"], 11, False, "bert"),
"distilroberta-base": (["input_ids", "attention_mask"], 11, False, "bert"),
"distilroberta-base": (["input_ids", "attention_mask"], 12, False, "bert"),
# DistilBERT
"distilbert-base-uncased": (["input_ids", "attention_mask"], 11, False, "bert"),
@ -63,11 +73,11 @@ MODELS = {
"albert-xlarge-v2": (["input_ids"], 12, True, "bert"),
#"albert-xxlarge-v2": (["input_ids"], 12, True, "bert"),
# T5 (use benchmark_t5.py instead)
#"t5-small": (["input_ids"], 12, False, "bert"),
#"t5-base": (["input_ids"], 12, False, "bert"),
#"t5-large": (["input_ids"], 12, True, "bert"),
#"t5-3b": (["input_ids"], 12, True, "bert"),
#"t5-11b": (["input_ids"], 12, True, "bert"),
# "t5-small": (["input_ids", "decoder_input_ids"], 12, False, "bert"),
# "t5-base": (["input_ids", "decoder_input_ids"], 12, False, "bert"),
# "t5-large": (["input_ids", "decoder_input_ids"], 12, True, "bert"),
# "t5-3b": (["input_ids", "decoder_input_ids"], 12, True, "bert"),
# "t5-11b": (["input_ids", "decoder_input_ids"], 12, True, "bert"),
#"valhalla/t5-small-qa-qg-hl": (["input_ids"], 12, True, "bert"),
# XLM-RoBERTa
"xlm-roberta-base": (["input_ids"], 11, False, "bert"),
@ -98,6 +108,20 @@ MODELS = {
# MBart
"facebook/mbart-large-cc25": (["input_ids"], 11, True, "bert"),
"facebook/mbart-large-en-ro": (["input_ids"], 11, True, "bert"),
# "Helsinki-NLP/opus-mt-ROMANCE-en": (["input_ids"], 12, False, "bert"),
# # Longformer
# "allenai/longformer-base-4096": (["input_ids"], 12, False, "bert"),
# "allenai/longformer-large-4096": (["input_ids"], 12, True, "bert"),
# "funnel-transformer/small": (["input_ids"], 12, False, "bert"),
# "funnel-transformer/small-base": (["input_ids"], 12, False, "bert"),
# "funnel-transformer/medium": (["input_ids"], 12, False, "bert"),
# "funnel-transformer/medium-base": (["input_ids"], 12, False, "bert"),
# "funnel-transformer/intermediate": (["input_ids"], 12, False, "bert"),
# "funnel-transformer/intermediate-base": (["input_ids"], 12, False, "bert"),
# "funnel-transformer/large": (["input_ids"], 12, True, "bert"),
# "funnel-transformer/large-base": (["input_ids"], 12, True, "bert"),
# "funnel-transformer/xlarge": (["input_ids"], 12, True, "bert"),
# "funnel-transformer/xlarge-base": (["input_ids"], 12, True, "bert"),
# Layoutlm
"microsoft/layoutlm-base-uncased": (["input_ids"], 11, False, "bert"),
"microsoft/layoutlm-large-uncased": (["input_ids"], 11, False, "bert"),
@ -105,4 +129,7 @@ MODELS = {
"squeezebert/squeezebert-uncased": (["input_ids"], 11, False, "bert"),
"squeezebert/squeezebert-mnli": (["input_ids"], 11, False, "bert"),
"squeezebert/squeezebert-mnli-headless": (["input_ids"], 11, False, "bert"),
"unc-nlp/lxmert-base-uncased": (["input_ids", "visual_feats", "visual_pos"], 11, False, "bert"),
# "google/pegasus-xsum": (["input_ids"], 11, False, "bert"),
# "google/pegasus-large": (["input_ids"], 11, False, "bert"),
}

View file

@ -9,11 +9,12 @@ import numpy
import os
import torch
from pathlib import Path
from transformers import AutoConfig, AutoTokenizer, AutoModel
from transformers import AutoConfig, AutoTokenizer, AutoModel, LxmertConfig, TransfoXLConfig
from benchmark_helper import create_onnxruntime_session, Precision
from gpt2_helper import GPT2ModelNoPastState, PRETRAINED_GPT2_MODELS
from gpt2_helper import GPT2ModelNoPastState, PRETRAINED_GPT2_MODELS, TFGPT2ModelNoPastState
from quantize_helper import QuantizeHelper
from huggingface_models import MODEL_CLASSES
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
logger = logging.getLogger(__name__)
@ -40,7 +41,7 @@ def restore_torch_functions():
torch.triu = torch_func["triu"]
def create_onnxruntime_input(vocab_size, batch_size, sequence_length, input_names, data_type=numpy.int64):
def create_onnxruntime_input(vocab_size, batch_size, sequence_length, input_names, config, data_type=numpy.int64):
input_ids = numpy.random.randint(low=0, high=vocab_size - 1, size=(batch_size, sequence_length), dtype=data_type)
inputs = {'input_ids': input_ids}
@ -53,13 +54,23 @@ def create_onnxruntime_input(vocab_size, batch_size, sequence_length, input_name
segment_ids = numpy.zeros([batch_size, sequence_length], dtype=data_type)
inputs['token_type_ids'] = segment_ids
if config.is_encoder_decoder:
inputs['decoder_input_ids'] = input_ids
if isinstance(config, LxmertConfig):
inputs["visual_feats"] = numpy.random.randn(1, 1, config.visual_feat_dim).astype(numpy.float32)
inputs["visual_pos"] = numpy.random.randn(1, 1, config.visual_pos_dim).astype(numpy.float32)
if isinstance(config, TransfoXLConfig):
inputs["tf_transfo_xl_model/transformer/pos_emb/einsum/Einsum/inputs_1:0"] = numpy.zeros([config.hidden_size],
dtype=numpy.float32)
return inputs
def filter_inputs(inputs, input_names):
    """Return the entries of ``inputs`` whose keys are listed in ``input_names``.

    Args:
        inputs: Mapping from input name to value (anything supporting ``in``
            and item lookup by name).
        input_names: Iterable of the input names the model accepts.

    Returns:
        A new dict holding ``inputs[name]`` for every name in ``input_names``
        that is present in ``inputs``, in the order of ``input_names``.
        Names that are missing from ``inputs`` are skipped rather than
        raising ``KeyError``.
    """
    # Only the guarded copy is kept: an unconditional inputs[name] lookup would
    # fail for names that are filled in later by the caller.
    return {name: inputs[name] for name in input_names if name in inputs}
@ -88,7 +99,7 @@ def build_dynamic_axes(example_inputs, outputs_flatten):
return dynamic_axes, output_names
def validate_onnx_model(onnx_model_path, example_inputs, example_outputs_flatten, use_gpu, fp16):
def validate_onnx_model(onnx_model_path, example_inputs, example_outputs_flatten, use_gpu, fp16, output_names=None):
test_session = create_onnxruntime_session(onnx_model_path, use_gpu, enable_all_optimization=False)
if test_session is None:
logger.error(f"{onnx_model_path} is an invalid ONNX model")
@ -97,8 +108,8 @@ def validate_onnx_model(onnx_model_path, example_inputs, example_outputs_flatten
logger.info(f"{onnx_model_path} is a valid ONNX model")
# Compare the inference result with PyTorch or Tensorflow
example_ort_inputs = {k: t.cpu().numpy() for k, t in example_inputs.items()}
example_ort_outputs = test_session.run(None, example_ort_inputs)
example_ort_inputs = {k: t.numpy() for k, t in example_inputs.items()}
example_ort_outputs = test_session.run(output_names, example_ort_inputs)
if len(example_outputs_flatten) != len(example_ort_outputs):
logger.error(
f"Number of output tensors expected {len(example_outputs_flatten)}, got {len(example_ort_outputs)}")
@ -111,7 +122,7 @@ def validate_onnx_model(onnx_model_path, example_inputs, example_outputs_flatten
rtol = 5e-02 if fp16 else 1e-4
atol = 1e-01 if fp16 else 1e-4
if not numpy.allclose(example_ort_outputs[i], example_outputs_flatten[i].cpu(), rtol=rtol, atol=atol):
if not numpy.allclose(example_ort_outputs[i], example_outputs_flatten[i].cpu().numpy(), rtol=rtol, atol=atol):
logger.error(f"Output tensor {i} is not close: rtol={rtol}, atol={atol}")
return False
@ -195,7 +206,7 @@ def optimize_onnx_model(model_name, onnx_model_path, optimized_model_path, model
optimization_options=optimization_options,
use_gpu=use_gpu,
only_onnxruntime=False)
if model_type == 'bert_keras':
if model_type == 'bert_keras' or model_type == "bert_tf":
opt_model.use_dynamic_axes()
model_fusion_statistics[optimized_model_path] = opt_model.get_fused_operator_statistics()
@ -234,7 +245,7 @@ def load_pretrained_model(model_name, config, cache_dir, custom_model_class, is_
if model_class_name == "GPT2ModelNoPastState":
if is_tf_model:
raise NotImplementedError("TFGPT2ModelNoPastState is currently not supported.")
return TFGPT2ModelNoPastState.from_pretrained(model_name, config=config, cache_dir=cache_dir)
else:
return GPT2ModelNoPastState.from_pretrained(model_name, config=config, cache_dir=cache_dir)
@ -279,13 +290,27 @@ def load_pt_model_from_tf(model_name):
return config, model
def validate_and_optimize_onnx(model_name, use_external_data_format, model_type, onnx_dir, input_names, use_gpu,
precision, optimize_onnx, validate_onnx, use_raw_attention_mask, overwrite, config,
model_fusion_statistics, onnx_model_path, example_inputs, example_outputs_flatten):
def validate_and_optimize_onnx(model_name,
use_external_data_format,
model_type,
onnx_dir,
input_names,
use_gpu,
precision,
optimize_onnx,
validate_onnx,
use_raw_attention_mask,
overwrite,
config,
model_fusion_statistics,
onnx_model_path,
example_inputs,
example_outputs_flatten,
output_names=None):
is_valid_onnx_model = True
if validate_onnx:
is_valid_onnx_model = validate_onnx_model(onnx_model_path, example_inputs, example_outputs_flatten, use_gpu,
False)
False, output_names)
if optimize_onnx or precision == Precision.FLOAT16 or precision == Precision.INT8: # Use script (optimizer.py) to optimize
optimized_model_path = get_onnx_file_path(onnx_dir, model_name, len(input_names), True, use_gpu, precision,
@ -297,7 +322,7 @@ def validate_and_optimize_onnx(model_name, use_external_data_format, model_type,
onnx_model_path = optimized_model_path
if validate_onnx:
is_valid_onnx_model = validate_onnx_model(onnx_model_path, example_inputs, example_outputs_flatten, use_gpu,
precision == Precision.FLOAT16)
precision == Precision.FLOAT16, output_names)
if precision == Precision.INT8:
logger.info(f"Quantizing model: {onnx_model_path}")
@ -375,46 +400,88 @@ def export_onnx_model_from_tf(model_name, opset_version, use_external_data_forma
import tensorflow as tf
tf.config.set_visible_devices([], 'GPU')
config, model = load_tf_model(model_name, model_class, cache_dir)
model._saved_model_inputs_spec = None
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
# Fix "Using pad_token, but it is not set yet" error.
if tokenizer.pad_token is None:
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
max_input_size = tokenizer.max_model_input_sizes[
model_name] if model_name in tokenizer.max_model_input_sizes else 1024
config, model = load_tf_model(model_name, model_class, cache_dir)
model.resize_token_embeddings(len(tokenizer))
example_inputs = tokenizer.encode_plus("This is a sample input",
return_tensors="tf",
max_length=max_input_size,
pad_to_max_length=True,
padding="max_length",
truncation=True)
example_inputs = filter_inputs(example_inputs, input_names)
example_outputs = model(example_inputs, training=False).to_tuple()
if config.is_encoder_decoder:
example_inputs["decoder_input_ids"] = tokenizer.encode_plus("This is a sample input",
return_tensors="tf",
max_length=max_input_size,
padding="max_length",
truncation=True).input_ids
if model_name == "unc-nlp/lxmert-base-uncased":
example_inputs["visual_feats"] = tf.random.normal([1, 1, config.visual_feat_dim])
example_inputs["visual_pos"] = tf.random.normal([1, 1, config.visual_pos_dim])
# Flatten is needed for gpt2 and distilgpt2.
example_outputs_flatten = flatten(example_outputs)
example_outputs_flatten = update_flatten_list(example_outputs_flatten, [])
try:
# Use no past state for these models
if config.use_cache:
config.use_cache = False
except:
pass
example_outputs = model(example_inputs, training=False)
output_names = None
# For xlnet models, only compare the last_hidden_state output.
if model_name == "xlnet-base-cased" or model_name == "xlnet-large-cased":
output_names = ["last_hidden_state"]
example_outputs = example_outputs["last_hidden_state"]
# Flatten is needed for gpt2 and distilgpt2. Output name sorting is needed for tf2onnx outputs to match onnx outputs.
from tensorflow.python.util import nest
example_outputs_flatten = nest.flatten(example_outputs)
onnx_model_path = get_onnx_file_path(onnx_dir, model_name, len(input_names), False, use_gpu, precision, False,
use_external_data_format)
tf_internal_model_path = onnx_model_path[:-5] if use_external_data_format else onnx_model_path
if overwrite or not os.path.exists(onnx_model_path):
if overwrite or not os.path.exists(tf_internal_model_path):
logger.info("Exporting ONNX model to {}".format(onnx_model_path))
Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
if not use_external_data_format:
Path(tf_internal_model_path).parent.mkdir(parents=True, exist_ok=True)
import tf2onnx, zipfile
tf2onnx.logging.set_level(tf2onnx.logging.ERROR)
specs = []
for name, value in example_inputs.items():
dims = [None] * len(value.shape)
specs.append(tf.TensorSpec(tuple(dims), value.dtype, name=name))
_, _ = tf2onnx.convert.from_keras(model,
input_signature=tuple(specs),
opset=opset_version,
large_model=use_external_data_format,
output_path=tf_internal_model_path)
if use_external_data_format:
# need to unpack the zip for run_onnxruntime()
with zipfile.ZipFile(tf_internal_model_path, 'r') as z:
z.extractall(os.path.dirname(tf_internal_model_path))
tf_internal_model_path = os.path.join(os.path.dirname(tf_internal_model_path), "__MODEL_PROTO.onnx")
if os.path.exists(onnx_model_path):
os.remove(onnx_model_path)
os.rename(tf_internal_model_path, onnx_model_path)
import keras2onnx
onnx_model = keras2onnx.convert_keras(model, model.name, target_opset=opset_version)
keras2onnx.save_model(onnx_model, onnx_model_path)
else:
logger.info(f"Skip export since model existed: {onnx_model_path}")
model_type = model_type + '_keras'
model_type = model_type + '_tf'
onnx_model_file, is_valid_onnx_model, vocab_size = validate_and_optimize_onnx(
model_name, use_external_data_format, model_type, onnx_dir, input_names, use_gpu, precision, optimize_onnx,
validate_onnx, use_raw_attention_mask, overwrite, config, model_fusion_statistics, onnx_model_path,
example_inputs, example_outputs_flatten)
example_inputs, example_outputs_flatten, output_names)
return onnx_model_file, is_valid_onnx_model, vocab_size, max_input_size

View file

@ -138,7 +138,8 @@ class BertOnnxModelKeras(BertOnnxModelTF):
mask_index = self.attention_mask.process_mask(mask_nodes[-1].input[0])
logger.debug("Create an Attention node.")
attention_node = self.attention_fusion.create_attention_node(mask_index, matmul_q, matmul_k, matmul_v,
add_q, add_k, add_v, parent.output[0],
add_q, add_k, add_v, self.num_heads,
self.hidden_size, parent.output[0],
reshape_qkv.output[0])
if attention_node is None:
continue

View file

@ -29,6 +29,21 @@ class BertOnnxModelTF(BertOnnxModel):
self.remove_nodes(nodes_to_remove)
logger.info(f"Removed Identity count: {len(nodes_to_remove)}")
def match_mask_path(self, add_or_sub_before_softmax):
    """Find the attention-mask parent path feeding the node before Softmax.

    Three candidate parent patterns are tried in a fixed order through
    ``self.match_parent_path``; the first result that is not ``None`` wins.

    Args:
        add_or_sub_before_softmax: Node from which the parent paths are matched.

    Returns:
        Whatever ``match_parent_path`` returned for the first matching
        pattern, or ``None`` when none of the patterns matches.
    """
    # (op types along the path, input index to follow at each step)
    candidate_paths = (
        (['Mul', 'Sub', 'Reshape', 'Cast'], [1, None, 1, 0]),
        (['Mul', 'Sub', 'Cast', 'Slice', 'Unsqueeze'], [1, 0, 1, 0, 0]),
        (['Mul', 'Sub', 'Cast', 'Unsqueeze', 'Unsqueeze'], [1, None, 1, 0, 0]),
    )
    for op_types, input_indices in candidate_paths:
        matched = self.match_parent_path(add_or_sub_before_softmax, op_types, input_indices)
        if matched is not None:
            return matched
    return None
def fuse_mask(self):
nodes_to_remove = []
for node in self.nodes():
@ -85,16 +100,14 @@ class BertOnnxModelTF(BertOnnxModel):
nodes_to_remove = []
for node in self.nodes():
if node.op_type == 'Mul' and self.has_constant_input(node, -10000):
mask_path = self.match_parent_path(node, ['Sub', 'Unsqueeze', 'Mul', 'Cast', 'Reshape', 'Cast'],
[0, 1, 0, 1, 0, 0])
mask_path = self.match_parent_path(node, ['Sub', 'Cast', 'Slice', 'Unsqueeze'], [0, 1, 0, 0])
if mask_path is None:
continue
sub_node, unsqueeze_node, mul_node, cast_node_0, reshape_node_0, cast_node_1 = mask_path
sub_node, cast_node, slice_node, unsqueeze_node = mask_path
mask_input_name = self.attention_mask.get_first_mask()
if cast_node_1.input[0] != mask_input_name:
print("Cast input {} is not mask input{}".format(cast_node_1.input[0], mask_input_name))
if unsqueeze_node.input[0] != mask_input_name:
print("Cast input {} is not mask input {}".format(unsqueeze_node.input[0], mask_input_name))
continue
unsqueeze_added_1 = onnx.helper.make_node('Unsqueeze',
@ -109,13 +122,14 @@ class BertOnnxModelTF(BertOnnxModel):
name='Mask_UnSqueeze_2',
axes=[2])
#self.replace_node_input(cast_node, cast_node.input[0], 'mask_fuse_unsqueeze2_output')
cast_node_2 = onnx.helper.make_node('Cast',
inputs=['mask_fuse_unsqueeze2_output'],
outputs=['mask_fuse_cast_output'])
cast_node_2.attribute.extend([onnx.helper.make_attribute("to", 1)])
self.replace_node_input(sub_node, sub_node.input[1], 'mask_fuse_cast_output')
nodes_to_remove.extend([unsqueeze_node, mul_node, cast_node_0, reshape_node_0, cast_node_1])
nodes_to_remove.extend([slice_node, unsqueeze_node, cast_node])
self.add_node(unsqueeze_added_1)
self.add_node(unsqueeze_added_2)
self.add_node(cast_node_2)
@ -360,9 +374,22 @@ class BertOnnxModelTF(BertOnnxModel):
nodes_to_remove = []
attention_count = 0
start_nodes = []
skip_layer_norm_nodes = self.get_nodes_by_op_type("SkipLayerNormalization")
for normalize_node in skip_layer_norm_nodes:
layer_norm_nodes = self.get_nodes_by_op_type("LayerNormalization")
# Sometimes we cannot fuse SkipLayerNormalization because the Add before LayerNormalization has an output that is used by nodes outside the SkipLayerNormalization pattern.
# Conceptually we treat the Add before LayerNormalization as a SkipLayerNormalization node since they share the same pattern.
start_nodes.extend(skip_layer_norm_nodes)
start_nodes.extend(layer_norm_nodes)
for normalize_node in start_nodes:
# SkipLayerNormalization has two inputs, and one of them is the root input for attention.
if normalize_node.op_type == 'LayerNormalization':
add_before_layernorm = self.match_parent(normalize_node, 'Add', 0)
if add_before_layernorm is not None:
normalize_node = add_before_layernorm
else:
continue
parent = self.get_parent(normalize_node, 1)
if parent is None or parent.op_type not in ["SkipLayerNormalization", "LayerNormalization", "Reshape"]:
parent = self.get_parent(normalize_node, 0)
@ -376,50 +403,63 @@ class BertOnnxModelTF(BertOnnxModel):
qkv_nodes = self.match_parent_path(normalize_node, ['MatMul', 'Reshape', 'Transpose', 'MatMul'],
[1, 0, 0, 0])
if qkv_nodes is None:
logger.debug("Failed to match qkv nodes")
continue
qkv_nodes = self.match_parent_path(normalize_node, ['Add', 'Einsum', 'Einsum'], [0, 0, 0])
if qkv_nodes is None:
logger.debug("Failed to match qkv nodes")
continue
(reshape_qkv, transpose_qkv, matmul_qkv) = qkv_nodes[-3:]
matmul_qkv = qkv_nodes[-1]
v_nodes = self.match_parent_path(matmul_qkv, ['Transpose', 'Reshape', 'Add', 'MatMul'], [1, 0, 0, 0])
if v_nodes is None:
logger.debug("Failed to match v path")
continue
v_nodes = self.match_parent_path(matmul_qkv, ['Add', 'Einsum'], [1, 0])
if v_nodes is None:
logger.debug("Failed to match v path")
continue
(transpose_v, reshape_v, add_v, matmul_v) = v_nodes
add_v = v_nodes[-2]
matmul_v = v_nodes[-1]
qk_nodes = self.match_parent_path(matmul_qkv, ['Softmax', 'Add', "Mul", 'MatMul'], [0, 0, 0, 0])
if qk_nodes is None:
logger.debug("Failed to match qk_paths")
continue
(softmax_qk, add_qk, mul_qk, matmul_qk) = qk_nodes
qk_nodes = self.match_parent_path(matmul_qkv, ['Softmax', 'Add', 'Einsum'], [0, 0, 0])
if qk_nodes is None:
logger.debug("Failed to match qk_paths")
continue
matmul_qk = qk_nodes[-1]
q_nodes = self.match_parent_path(matmul_qk, ['Transpose', 'Reshape', 'Add', 'MatMul'], [0, 0, 0, 0])
if q_nodes is None:
logger.debug("Failed to match q path")
continue
(transpose_q, reshape_q, add_q, matmul_q) = q_nodes
q_nodes = self.match_parent_path(matmul_qk, ['Add', 'Einsum'], [0, 0])
if q_nodes is None:
logger.debug("Failed to match q path")
continue
add_q = q_nodes[-2]
matmul_q = q_nodes[-1]
k_nodes = self.match_parent_path(matmul_qk, ['Transpose', 'Reshape', 'Add', 'MatMul'], [1, 0, 0, 0])
if k_nodes is None:
logger.debug("Failed to match k path")
continue
(transpose_k, reshape_k, add_k, matmul_k) = k_nodes
mask_nodes = self.match_parent_path(add_qk, ['Mul', 'Sub', 'Unsqueeze'], [1, 0, 1])
if mask_nodes is None:
mask_nodes = self.match_parent_path(add_qk, ['Mul', 'Sub', 'Cast', 'Unsqueeze', 'Mul'], [1, 0, 1, 0, 0])
if mask_nodes is None:
logger.debug("Failed to match mask path")
k_nodes = self.match_parent_path(matmul_qk, ['Mul', 'Add', 'Einsum'], [1, 0, 0])
if k_nodes is None:
logger.debug("Failed to match k path")
continue
add_k = k_nodes[-2]
matmul_k = k_nodes[-1]
mask_nodes = self.match_mask_path(qk_nodes[1])
if mask_nodes is None:
logger.debug("Cannot find mask_nodes.")
continue
if not self.has_constant_input(mask_nodes[1], 1):
logger.debug("Sub node expected to have an input with constant value 1.0.")
continue
# add a squeeze node to convert a 3-d mask to 2-d
squeeze_node = self.match_parent_path(mask_nodes[-1], ['Squeeze'], [0])
squeeze_node = self.match_parent_path(mask_nodes[-1], ['Squeeze'], [0]) or self.match_parent_path(
mask_nodes[-1], ['Expand'], [0])
squeeze_node_name = "Squeeze_3d_to_2d_mask"
squeeze_output_name = squeeze_node_name + "_output"
if squeeze_node is None and len(mask_nodes) == 5:
if squeeze_node is None and len(mask_nodes) == 5 and self.find_graph_input(mask_nodes[-1].input[0]) is None:
mask_input = mask_nodes[-1].input[1]
self.add_node(
helper.make_node("Squeeze", [mask_input], [squeeze_output_name], squeeze_node_name, axes=[1]))
@ -427,11 +467,32 @@ class BertOnnxModelTF(BertOnnxModel):
is_same_root = self.check_attention_input(matmul_q, matmul_k, matmul_v, parent, output_name_to_node)
if is_same_root:
mask_index = self.attention_mask.process_mask(squeeze_output_name)
mask_index = self.attention_mask.process_mask(mask_nodes[-1].input[0])
logger.debug("Create an Attention node.")
attention_node = self.attention_fusion.create_attention_node(mask_index, matmul_q, matmul_k, matmul_v,
add_q, add_k, add_v, parent.output[0],
reshape_qkv.output[0])
# For tf models, q and k are flipped.
attention_node = self.attention_fusion.create_attention_node(mask_index, matmul_k, matmul_q, matmul_v,
add_k, add_q, add_v, self.num_heads,
self.hidden_size, parent.output[0],
qkv_nodes[2].output[0])
if attention_node is None:
continue
if qkv_nodes[1].op_type == 'Einsum':
# add reshape before einsum
tensor = helper.make_tensor(name=qkv_nodes[1].name + "_newshape",
data_type=TensorProto.INT64,
dims=[4],
vals=np.int64(
[[0, 0, self.num_heads,
int(self.hidden_size / self.num_heads)]]).tobytes(),
raw=True)
self.add_initializer(tensor)
reshape_ = helper.make_node("Reshape",
inputs=[attention_node.output[0], qkv_nodes[1].name + "_newshape"],
outputs=[qkv_nodes[1].name + "_reshape_output"],
name=qkv_nodes[1].name + "_reshape")
qkv_nodes[1].input[0] = qkv_nodes[1].name + "_reshape_output"
self.add_node(reshape_)
if parent.op_type == 'Reshape':
# Temporary work around: we require the skiplayernorm and attention op be fed with 3-d input
hidden_size = numpy_helper.to_array(self.get_initializer(parent.input[1]))[1]
@ -443,13 +504,11 @@ class BertOnnxModelTF(BertOnnxModel):
self.add_initializer(tensor)
parent.input[1] = parent.name + "_modified"
if attention_node is None:
continue
self.add_node(attention_node)
attention_count += 1
nodes_to_remove.extend([reshape_qkv, transpose_qkv, matmul_qkv])
nodes_to_remove.extend(qkv_nodes[2:])
nodes_to_remove.extend(qk_nodes)
nodes_to_remove.extend(q_nodes)
nodes_to_remove.extend(k_nodes)
@ -466,7 +525,20 @@ class BertOnnxModelTF(BertOnnxModel):
self.remove_identity()
self.process_embedding()
#TODO: remove fuse mask since we have embedding fused so fuse_attention shall handle the mask nodes.
self.fuse_mask()
# self.fuse_mask()
self.skip_reshape()
def skip_reshape(self):
    """Bypass a Reshape whose first input is produced by another Reshape.

    For every Reshape node whose parent on input 0 is also a Reshape, the
    node's first input is rewired to the parent's first input. The parent
    node itself is not removed here. The number of rewired nodes is logged
    when it is non-zero.
    """
    rewired = 0
    for node in self.get_nodes_by_op_type("Reshape"):
        upstream = self.get_parent(node, 0)
        if upstream is None or upstream.op_type != "Reshape":
            continue
        # Read directly from whatever feeds the upstream Reshape.
        node.input[0] = upstream.input[0]
        rewired += 1

    if rewired > 0:
        logger.info(f"Skip consequent Reshape count: {rewired}")
def remove_reshape_before_first_attention(self):
attention_nodes = self.get_nodes_by_op_type("Attention")
@ -475,7 +547,7 @@ class BertOnnxModelTF(BertOnnxModel):
if path is None:
continue
logger.info("Remove Reshape before first Attention node.")
reshape, embed = path
reshape, _ = path
self.replace_input_of_all_nodes(reshape.output[0], reshape.input[0])
self.remove_node(reshape)
break

View file

@ -39,7 +39,8 @@ MODEL_CLASSES = {
"bert": (BertOnnxModel, "pytorch", True),
"bert_tf": (BertOnnxModelTF, "tf2onnx", False),
"bert_keras": (BertOnnxModelKeras, "keras2onnx", False),
"gpt2": (Gpt2OnnxModel, "pytorch", True)
"gpt2": (Gpt2OnnxModel, "pytorch", True),
"gpt2_tf": (Gpt2OnnxModel, 'tf2onnx', False) # might add a class for GPT2OnnxModel for TF later.
}

View file

@ -127,7 +127,107 @@ def create_bert_attention(input_hidden_size=16, pruned_num_heads=2, pruned_head_
model = helper.make_model(graph)
return model
def create_tf2onnx_attention_3d(input_hidden_size=16, num_heads=4, head_size=4, use_float_mask=False):
    """Build a small ONNX model holding an Einsum-based attention subgraph.

    The graph is Add -> LayerNormalization -> (Einsum/Add projections for q, k, v,
    a mask branch, Einsum/Softmax attention, Einsum/Add output projection)
    -> Add (skip) -> LayerNormalization. Judging by the function name it imitates
    the layout tf2onnx emits for a 3-D attention block.

    Args:
        input_hidden_size: Last dimension of the two float inputs and of the output.
        num_heads: Second dimension of the q/k/v Einsum weights.
        head_size: Third dimension of the q/k/v Einsum weights.
        use_float_mask: When True, ``input_mask`` is declared FLOAT and no Cast node
            is emitted; otherwise ``input_mask`` is INT64 and goes through a Cast
            (``to=1``) before the Sub node.

    Returns:
        The model returned by ``helper.make_model`` for a graph with inputs
        ``input_1``/``input_2`` of shape [1, 3, input_hidden_size], ``input_mask``
        of shape [1, 3], and a single output ``output``.
    """
    # Unsqueeze in opset version 13 has two inputs (axes moved from attribute to input).
    has_unsqueeze_two_inputs = version.parse(onnx.__version__) >= version.parse('1.8.0')

    if has_unsqueeze_two_inputs:
        unsqueeze_mask = helper.make_node("Unsqueeze", ["input_mask", "axes_1"], ["unsqueeze0_out"], "unsqueeze0")
    else:
        unsqueeze_mask = helper.make_node("Unsqueeze", ["input_mask"], ["unsqueeze0_out"],
                                          "unsqueeze0",
                                          axes=[1, 2])

    # When the mask is already float there is no Cast node, so Sub must read the
    # Slice output directly. (Previously this named "unsqueeze1_out", a tensor that
    # no node in this graph produces, leaving the float-mask graph disconnected.)
    sub_mask_input = "slice_out" if use_float_mask else "cast_out"

    # nodes in attention subgraph
    nodes = [
        helper.make_node("Add", ["input_1", "input_2"], ["layernorm_input"], "add_layernorm"),
        # NOTE(review): "epsion" looks like a typo for "epsilon". It is kept unchanged
        # because it alters the generated node attributes, which stored expected
        # graphs may depend on - confirm before renaming.
        helper.make_node("LayerNormalization", ["layernorm_input", "layer_norm_weight", "layer_norm_bias"],
                         ["layernorm_out"],
                         "layernorm",
                         axis=-1,
                         epsion=0.000009999999747378752),
        # q nodes
        helper.make_node("Einsum", ["layernorm_out", "einsum_q_weight"], ["einsum_q_out"],
                         "einsum_q",
                         equation="abc,cde->abde"),
        helper.make_node("Add", ["einsum_q_out", "add_q_weight"], ["add_q_out"], "add_q"),
        # k nodes
        helper.make_node("Einsum", ["layernorm_out", "einsum_k_weight"], ["einsum_k_out"],
                         "einsum_k",
                         equation="abc,cde->abde"),
        helper.make_node("Add", ["einsum_k_out", "add_k_weight"], ["add_k_out"], "add_k"),
        helper.make_node("Mul", ["add_k_out", "mul_weight_1"], ["mul_k_out"], "mul_k"),
        # mask nodes
        unsqueeze_mask,
        helper.make_node("Slice", ["unsqueeze0_out", "slice_start", "slice_end", "slice_axes", "slice_steps"],
                         ["slice_out"], "slice"),
    ]

    # when attention_mask is float type, no need to cast
    if not use_float_mask:
        nodes.append(helper.make_node("Cast", ["slice_out"], ["cast_out"], "cast", to=1))

    nodes.extend([
        helper.make_node("Sub", ["sub_weight", sub_mask_input], ["sub_out"], "sub"),
        helper.make_node("Mul", ["sub_out", "mul_weight_2"], ["mul_mask_out"], "mul_mask"),
        # qk nodes
        helper.make_node("Einsum", ["add_q_out", "mul_k_out"], ["einsum_qk_out"],
                         "einsum_qk",
                         equation="aecd,abcd->acbe"),
        helper.make_node("Add", ["einsum_qk_out", "mul_mask_out"], ["add_qk_out"], "add_qk"),
        helper.make_node("Softmax", ["add_qk_out"], ["softmax_qk_out"], "softmax_qk", axis=3),
        # v nodes
        helper.make_node("Einsum", ["layernorm_out", "einsum_v_weight"], ["einsum_v_out"],
                         "einsum_v",
                         equation="abc,cde->abde"),
        helper.make_node("Add", ["einsum_v_out", "add_v_weight"], ["add_v_out"], "add_v"),
        # qkv nodes
        helper.make_node("Einsum", ["softmax_qk_out", "add_v_out"], ["einsum_qkv_1_out"],
                         "einsum_qkv_1",
                         equation="acbe,aecd->abcd"),
        helper.make_node("Einsum", ["einsum_qkv_1_out", "einsum_qkv_weight"], ["einsum_qkv_2_out"],
                         "einsum_qkv_2",
                         equation="abcd,cde->abe"),
        helper.make_node("Add", ["einsum_qkv_2_out", "add_qkv_weight"], ["add_qkv_out"], "add_qkv"),
        helper.make_node("Add", ["add_qkv_out", "layernorm_out"], ["skip_output"], "add_skip"),
        helper.make_node("LayerNormalization", ["skip_output", "layer_norm_weight", "layer_norm_bias"], ["output"],
                         "layernorm2",
                         axis=-1,
                         epsion=0.000009999999747378752),
    ])

    initializers = [  # initializers
        float_tensor('layer_norm_weight', [input_hidden_size]),
        float_tensor('layer_norm_bias', [input_hidden_size]),
        float_tensor('einsum_q_weight', [input_hidden_size, num_heads, head_size]),
        float_tensor('einsum_k_weight', [input_hidden_size, num_heads, head_size]),
        float_tensor('einsum_v_weight', [input_hidden_size, num_heads, head_size]),
        float_tensor('einsum_qkv_weight', [num_heads, head_size, input_hidden_size]),
        float_tensor('add_q_weight', [num_heads, head_size]),
        float_tensor('add_k_weight', [num_heads, head_size]),
        float_tensor('add_v_weight', [num_heads, head_size]),
        float_tensor('add_qkv_weight', [input_hidden_size]),
        helper.make_tensor('sub_weight', TensorProto.FLOAT, [1], [1.0]),
        # NOTE(review): mul_k multiplies by mul_weight_1 (-10000) and mul_mask by
        # mul_weight_2 (0.125); the two values look swapped for a key scale vs. a
        # mask penalty - confirm. Left unchanged to keep the generated graph stable.
        helper.make_tensor('mul_weight_1', TensorProto.FLOAT, [1], [-10000]),
        helper.make_tensor('mul_weight_2', TensorProto.FLOAT, [1], [0.125]),
        # reshape_weight_1 is not referenced by any node in this graph.
        helper.make_tensor('reshape_weight_1', TensorProto.INT64, [4], [0, 0, num_heads, head_size]),
        helper.make_tensor('slice_start', TensorProto.INT32, [4], [0, 0, 0, 0]),
        helper.make_tensor('slice_end', TensorProto.INT32, [4], [1000000000, 1000000000, 1000000000, 1000000000]),
        helper.make_tensor('slice_axes', TensorProto.INT32, [4], [0, 1, 2, 3]),
        helper.make_tensor('slice_steps', TensorProto.INT32, [4], [1, 1, 1, 1])
    ]

    if has_unsqueeze_two_inputs:
        initializers.append(helper.make_tensor('axes_1', TensorProto.INT64, [2], [1, 2]))

    batch_size = 1
    sequence_length = 3

    graph = helper.make_graph(
        nodes,
        "AttentionFusionPrunedModel",  #name
        [  # inputs
            helper.make_tensor_value_info('input_1', TensorProto.FLOAT,
                                          [batch_size, sequence_length, input_hidden_size]),
            helper.make_tensor_value_info('input_2', TensorProto.FLOAT,
                                          [batch_size, sequence_length, input_hidden_size]),
            helper.make_tensor_value_info('input_mask', TensorProto.FLOAT if use_float_mask else TensorProto.INT64,
                                          [batch_size, sequence_length])
        ],
        [  # outputs
            helper.make_tensor_value_info('output', TensorProto.FLOAT,
                                          [batch_size, sequence_length, input_hidden_size]),
        ],
        initializers)

    model = helper.make_model(graph)
    return model
if __name__ == "__main__":
    # Generate each sample model and write it to the current working directory.
    # (The first model used to be saved twice in a row; one write is enough.)
    model = create_bert_attention()
    onnx.save(model, "pruned_bert_attention.onnx")

    model = create_tf2onnx_attention_3d()
    onnx.save(model, "bert_3d_attention.onnx")

View file

@ -8,7 +8,7 @@ import unittest
import os
import sys
import onnx
from bert_model_generator import create_bert_attention
from bert_model_generator import create_bert_attention, create_tf2onnx_attention_3d
# set path so that we could import from parent directory
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
@ -28,6 +28,19 @@ class TestFusion(unittest.TestCase):
'pruned_attention_opt.onnx')
expected = onnx.load(expected_model_path)
self.assertEqual(str(optimized_model.model.graph), str(expected.graph))
def test_3d_attention_fusion_tf2onnx_model(self):
    """Optimize the generated 3-D attention model as 'bert_tf' and compare with the stored expected graph."""
    model = create_tf2onnx_attention_3d()
    # The optimizer is given a file path, so the model is written to the working directory first.
    model_path = os.path.join('.', 'bert_3d_attention.onnx')
    onnx.save(model, model_path)
    try:
        optimized_model = optimize_model(model_path, model_type='bert_tf', num_heads=4, hidden_size=16)
    finally:
        # Remove the temporary model even when optimization raises.
        os.remove(model_path)

    expected_model_path = os.path.join(os.path.dirname(__file__), 'test_data', 'fusion',
                                       'bert_3d_attention_opt.onnx')
    expected = onnx.load(expected_model_path)
    self.assertEqual(str(optimized_model.model.graph), str(expected.graph))
if __name__ == '__main__':

View file

@ -340,11 +340,34 @@ class TestBertOptimization(unittest.TestCase):
self._test_optimizer_on_huggingface_model("facebook/bart-base", [0, 0, 0, 0, 12, 2, 30])
@pytest.mark.slow
def test_bert_base_cased_from_tf(self):
    """Run the shared TF-model optimizer check on "bert-base-cased" three times.

    The last positional argument takes the values 1, 2 and 3 in that order;
    its meaning is defined by ``_test_optimizer_on_tf_model``.
    """
    for variant in (1, 2, 3):
        # A fresh list per call, exactly as with three separate literal calls.
        self._test_optimizer_on_tf_model("bert-base-cased", [1, 12, 0, 0, 12, 0, 24], variant)
def test_huggingface_bert_base_cased_from_tf2onnx(self):
self._test_optimizer_on_tf_model("bert-base-cased", [0, 12, 0, 0, 0, 0, 25], 1)
self._test_optimizer_on_tf_model("bert-base-cased", [0, 12, 0, 0, 0, 0, 25], 2)
self._test_optimizer_on_tf_model("bert-base-cased", [0, 12, 0, 0, 0, 0, 25], 3)
@pytest.mark.slow
def test_huggingface_distilgpt2_from_tf2onnx(self):
    """Run the shared TF-model optimizer check on "distilgpt2"."""
    # Presumably the expected per-operator fusion counts; the helper defines their order.
    expected = [0, 0, 0, 0, 0, 12, 1]
    self._test_optimizer_on_tf_model("distilgpt2", expected, 1)
@pytest.mark.slow
def test_huggingface_albert_from_tf2onnx(self):
    """Run the shared TF-model optimizer check on "albert-base-v1"."""
    # Presumably the expected per-operator fusion counts; the helper defines their order.
    expected = [0, 0, 0, 0, 0, 0, 25]
    self._test_optimizer_on_tf_model("albert-base-v1", expected, 1)
@pytest.mark.slow
def test_huggingface_gpt2_from_tf2onnx(self):
    """Run the shared TF-model optimizer check on "gpt2" with ``validate_model=False``."""
    # Presumably the expected per-operator fusion counts; the helper defines their order.
    expected = [0, 0, 0, 0, 0, 24, 1]
    self._test_optimizer_on_tf_model("gpt2", expected, 1, validate_model=False)
@pytest.mark.slow
def test_huggingface_roberta_from_tf2onnx(self):
    """Run the shared TF-model optimizer check on "roberta-base" with ``validate_model=False``."""
    # Presumably the expected per-operator fusion counts; the helper defines their order.
    expected = [0, 12, 0, 0, 0, 0, 25]
    self._test_optimizer_on_tf_model("roberta-base", expected, 1, validate_model=False)
@pytest.mark.slow
def test_huggingface_distilbert_from_tf2onnx(self):
    """Run the shared TF-model optimizer check on "distilbert-base-uncased" with ``validate_model=False``."""
    # Presumably the expected per-operator fusion counts; the helper defines their order.
    expected = [0, 0, 0, 0, 0, 0, 13]
    self._test_optimizer_on_tf_model("distilbert-base-uncased", expected, 1, validate_model=False)
@pytest.mark.slow
def test_huggingface_xlm_from_tf2onnx(self):
    """Run the shared TF-model optimizer check on "xlm-mlm-ende-1024" with ``validate_model=False``."""
    # Presumably the expected per-operator fusion counts; the helper defines their order.
    expected = [0, 0, 0, 0, 0, 1, 12]
    self._test_optimizer_on_tf_model("xlm-mlm-ende-1024", expected, 1, validate_model=False)
# Run every test case in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()