mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-09 00:30:53 +00:00
Add AutoModel selector in transformers tool (#5051)
* Add AutoModel selector in transformers tool * change distilbert-*-squad's pipeline to AutoModelForQuestionAnswering * rule base selector and add model_class as parameter * Update huggingface_models.py * review comments
This commit is contained in:
parent
4553b2eecd
commit
b23e08b85c
3 changed files with 59 additions and 18 deletions
|
|
@ -56,7 +56,7 @@ from onnx_exporter import create_onnxruntime_input, load_pretrained_model, expor
|
|||
|
||||
logger = logging.getLogger('')
|
||||
|
||||
from huggingface_models import MODELS
|
||||
from huggingface_models import MODELS, MODEL_CLASSES
|
||||
|
||||
cpu_count = psutil.cpu_count(logical=True)
|
||||
# Set OMP environment variable before importing onnxruntime or torch.
|
||||
|
|
@ -66,7 +66,7 @@ if "OMP_NUM_THREADS" not in os.environ:
|
|||
import torch
|
||||
from transformers import (AutoConfig, AutoTokenizer, AutoModel, GPT2Model)
|
||||
|
||||
def run_onnxruntime(use_gpu, model_names, precision, batch_sizes, sequence_lengths, repeat_times, input_counts,
|
||||
def run_onnxruntime(use_gpu, model_names, model_class, precision, batch_sizes, sequence_lengths, repeat_times, input_counts,
|
||||
optimize_onnx, validate_onnx, cache_dir, onnx_dir, verbose, overwrite, disable_ort_io_binding,
|
||||
use_raw_attention_mask, thread_num, model_fusion_statistics):
|
||||
import onnxruntime
|
||||
|
|
@ -91,9 +91,9 @@ def run_onnxruntime(use_gpu, model_names, precision, batch_sizes, sequence_lengt
|
|||
|
||||
with torch.no_grad():
|
||||
onnx_model_file, is_valid_onnx_model, vocab_size, max_sequence_length = export_onnx_model(
|
||||
model_name, MODELS[model_name][1], MODELS[model_name][2], MODELS[model_name][3], cache_dir,
|
||||
onnx_dir, input_names, use_gpu, precision, optimize_onnx, validate_onnx, use_raw_attention_mask,
|
||||
overwrite, model_fusion_statistics)
|
||||
model_name, MODELS[model_name][1], MODELS[model_name][2], MODELS[model_name][3], model_class,
|
||||
cache_dir, onnx_dir, input_names, use_gpu, precision, optimize_onnx, validate_onnx,
|
||||
use_raw_attention_mask, overwrite, model_fusion_statistics)
|
||||
if not is_valid_onnx_model:
|
||||
continue
|
||||
|
||||
|
|
@ -154,7 +154,7 @@ def run_onnxruntime(use_gpu, model_names, precision, batch_sizes, sequence_lengt
|
|||
return results
|
||||
|
||||
|
||||
def run_pytorch(use_gpu, model_names, precision, batch_sizes, sequence_lengths, repeat_times, torchscript, cache_dir,
|
||||
def run_pytorch(use_gpu, model_names, model_class, precision, batch_sizes, sequence_lengths, repeat_times, torchscript, cache_dir,
|
||||
verbose):
|
||||
results = []
|
||||
if use_gpu and not torch.cuda.is_available():
|
||||
|
|
@ -165,7 +165,7 @@ def run_pytorch(use_gpu, model_names, precision, batch_sizes, sequence_lengths,
|
|||
|
||||
for model_name in model_names:
|
||||
config = AutoConfig.from_pretrained(model_name, torchscript=torchscript, cache_dir=cache_dir)
|
||||
model = load_pretrained_model(model_name, config=config, cache_dir=cache_dir)
|
||||
model = load_pretrained_model(model_name, config=config, cache_dir=cache_dir, custom_model_class=model_class)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
|
||||
max_input_size = tokenizer.max_model_input_sizes[model_name] if model_name in tokenizer.max_model_input_sizes else 1024
|
||||
|
|
@ -237,6 +237,13 @@ def parse_arguments():
|
|||
choices=list(MODELS.keys()),
|
||||
help="Pre-trained models in the list: " + ", ".join(MODELS.keys()))
|
||||
|
||||
parser.add_argument('--model_class',
|
||||
required=False,
|
||||
type=str,
|
||||
default=None,
|
||||
choices=list(MODEL_CLASSES.keys()),
|
||||
help='Model type selected in the list: ' + ', '.join(MODEL_CLASSES.keys()))
|
||||
|
||||
parser.add_argument("-e",
|
||||
"--engines",
|
||||
required=False,
|
||||
|
|
@ -358,18 +365,18 @@ def main():
|
|||
logger.warning("--input_counts is not implemented for torch or torchscript engine.")
|
||||
|
||||
if enable_torchscript:
|
||||
results += run_pytorch(args.use_gpu, args.models, args.precision, args.batch_sizes, args.sequence_lengths,
|
||||
results += run_pytorch(args.use_gpu, args.models, args.model_class, args.precision, args.batch_sizes, args.sequence_lengths,
|
||||
args.test_times, True, args.cache_dir, args.verbose)
|
||||
|
||||
if enable_torch:
|
||||
results += run_pytorch(args.use_gpu, args.models, args.precision, args.batch_sizes, args.sequence_lengths,
|
||||
results += run_pytorch(args.use_gpu, args.models, args.model_class, args.precision, args.batch_sizes, args.sequence_lengths,
|
||||
args.test_times, False, args.cache_dir, args.verbose)
|
||||
|
||||
model_fusion_statistics = {}
|
||||
if enable_onnxruntime:
|
||||
try:
|
||||
use_raw_attention_mask = True
|
||||
results += run_onnxruntime(args.use_gpu, args.models, args.precision, args.batch_sizes,
|
||||
results += run_onnxruntime(args.use_gpu, args.models, args.model_class, args.precision, args.batch_sizes,
|
||||
args.sequence_lengths, args.test_times, args.input_counts, args.optimize_onnx,
|
||||
args.validate_onnx, args.cache_dir, args.onnx_dir, args.verbose, args.overwrite,
|
||||
args.disable_ort_io_binding, use_raw_attention_mask, args.thread_num,
|
||||
|
|
|
|||
|
|
@ -4,8 +4,21 @@
|
|||
# license information.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
from transformers import AutoModelForQuestionAnswering
|
||||
from transformers import AutoModelForSequenceClassification
|
||||
from transformers import AutoModelWithLMHead
|
||||
from transformers import AutoModel
|
||||
|
||||
# Maps model class name to a tuple of model class
|
||||
MODEL_CLASSES = {
|
||||
'AutoModel': AutoModel,
|
||||
'AutoModelWithLMHead': AutoModelWithLMHead,
|
||||
'AutoModelForSequenceClassification': AutoModelForSequenceClassification,
|
||||
'AutoModelForQuestionAnswering': AutoModelForQuestionAnswering
|
||||
}
|
||||
|
||||
# List of pretrained models: https://huggingface.co/transformers/pretrained_models.html
|
||||
# Pretrained model name to a tuple of input names, opset_version, use_external_data_format and optimization model type
|
||||
# Pretrained model name to a tuple of input names, opset_version, use_external_data_format, optimization model type
|
||||
MODELS = {
|
||||
# BERT
|
||||
"bert-base-uncased": (["input_ids", "attention_mask", "token_type_ids"], 11, False, "bert"),
|
||||
|
|
@ -57,7 +70,7 @@ MODELS = {
|
|||
"roberta-large-openai-detector": (["input_ids", "attention_mask"], 11, False, "bert"),
|
||||
# DistilBERT
|
||||
"distilbert-base-uncased": (["input_ids", "attention_mask"], 11, False, "bert"),
|
||||
"distilbert-base-uncased-distilled-squad": (["input_ids", "attention_mask"], 11, False, "bert"),
|
||||
"distilbert-base-uncased-distilled-squad": (["input_ids", "attention_mask"], 11, False, "bert"),
|
||||
"distilbert-base-cased": (["input_ids", "attention_mask"], 11, False, "bert"),
|
||||
"distilbert-base-cased-distilled-squad": (["input_ids", "attention_mask"], 11, False, "bert"),
|
||||
"distilbert-base-german-cased": (["input_ids", "attention_mask"], 11, False, "bert"),
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from transformers import AutoConfig, AutoTokenizer, AutoModel
|
|||
from benchmark_helper import create_onnxruntime_session, Precision
|
||||
from gpt2_helper import GPT2ModelNoPastState, PRETRAINED_GPT2_MODELS
|
||||
from quantize_helper import QuantizeHelper
|
||||
from huggingface_models import MODEL_CLASSES
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -182,20 +183,40 @@ def optimize_onnx_model(onnx_model_path, optimized_model_path, model_type, num_a
|
|||
logger.info(f"Skip optimization since model existed: {optimized_model_path}")
|
||||
|
||||
|
||||
def load_pretrained_model(model_name, config, cache_dir):
|
||||
def modelclass_dispatcher(model_name, custom_model_class):
|
||||
if (custom_model_class != None):
|
||||
return MODEL_CLASSES[custom_model_class]
|
||||
|
||||
if model_name in PRETRAINED_GPT2_MODELS:
|
||||
return GPT2ModelNoPastState.from_pretrained(model_name, config=config, cache_dir=cache_dir)
|
||||
return AutoModel.from_pretrained(model_name, config=config, cache_dir=cache_dir)
|
||||
return GPT2ModelNoPastState
|
||||
|
||||
import re
|
||||
if (re.search('-squad$', model_name) != None):
|
||||
from transformers import AutoModelForQuestionAnswering
|
||||
return AutoModelForQuestionAnswering
|
||||
elif (re.search('-mprc$', model_name) != None):
|
||||
from transformers import AutoModelForSequenceClassification
|
||||
return AutoModelForSequenceClassification
|
||||
elif (re.search('gpt2', model_name) != None):
|
||||
from transformers import AutoModelWithLMHead
|
||||
return AutoModelWithLMHead
|
||||
|
||||
return AutoModel
|
||||
|
||||
|
||||
def export_onnx_model(model_name, opset_version, use_external_data_format, model_type, cache_dir, onnx_dir, input_names,
|
||||
use_gpu, precision, optimize_onnx, validate_onnx, use_raw_attention_mask, overwrite,
|
||||
def load_pretrained_model(model_name, config, cache_dir, custom_model_class):
|
||||
model_class = modelclass_dispatcher(model_name, custom_model_class)
|
||||
return model_class.from_pretrained(model_name, config=config, cache_dir=cache_dir)
|
||||
|
||||
|
||||
def export_onnx_model(model_name, opset_version, use_external_data_format, model_type, model_class, cache_dir, onnx_dir,
|
||||
input_names, use_gpu, precision, optimize_onnx, validate_onnx, use_raw_attention_mask, overwrite,
|
||||
model_fusion_statistics):
|
||||
config = AutoConfig.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
if hasattr(config, 'return_dict'):
|
||||
config.return_dict = False
|
||||
|
||||
model = load_pretrained_model(model_name, config=config, cache_dir=cache_dir)
|
||||
model = load_pretrained_model(model_name, config=config, cache_dir=cache_dir, custom_model_class=model_class)
|
||||
model.cpu()
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
|
|
|
|||
Loading…
Reference in a new issue