mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-05 04:17:53 +00:00
Support T5 benchmarking in transformers tool (#5133)
* init checkin * review comments * modify according to transformers release
This commit is contained in:
parent
9ec1ed42a8
commit
1a12f510fc
3 changed files with 18 additions and 8 deletions
|
|
@ -286,9 +286,17 @@ def run_tensorflow(use_gpu, model_names, model_class, precision, batch_sizes, se
|
|||
input_ids = tf.constant(values, shape=(batch_size, sequence_length), dtype=tf.int32)
|
||||
|
||||
try:
|
||||
model(input_ids, training=False)
|
||||
def encoder_forward():
|
||||
return model(input_ids, training=False)
|
||||
|
||||
runtimes = timeit.repeat(lambda: model(input_ids, training=False), repeat=repeat_times, number=1)
|
||||
def encoder_decoder_forward():
|
||||
return model(input_ids, decoder_input_ids=input_ids, training=False)
|
||||
|
||||
inference = encoder_decoder_forward if config.is_encoder_decoder else encoder_forward
|
||||
|
||||
inference()
|
||||
|
||||
runtimes = timeit.repeat(lambda: inference(), repeat=repeat_times, number=1)
|
||||
|
||||
result = {
|
||||
"engine": "tensorflow",
|
||||
|
|
|
|||
|
|
@ -89,11 +89,11 @@ MODELS = {
|
|||
"albert-xlarge-v2": (["input_ids"], 12, True, "bert"),
|
||||
"albert-xxlarge-v2": (["input_ids"], 12, True, "bert"),
|
||||
# T5
|
||||
#"t5-small": (["input_ids"], 11, False, "bert"),
|
||||
#"t5-base": (["input_ids"], 11, False, "bert"),
|
||||
#"t5-large": (["input_ids"], 11, False, "bert"),
|
||||
#"t5-3b": (["input_ids"], 11, False, "bert"),
|
||||
#"t5-11b": (["input_ids"], 11, False, "bert"),
|
||||
"t5-small": (["input_ids"], 12, False, "bert"),
|
||||
"t5-base": (["input_ids"], 12, False, "bert"),
|
||||
"t5-large": (["input_ids"], 12, True, "bert"),
|
||||
"t5-3b": (["input_ids"], 12, True, "bert"),
|
||||
"t5-11b": (["input_ids"], 12, True, "bert"),
|
||||
# XLM-RoBERTa
|
||||
"xlm-roberta-base": (["input_ids"], 11, False, "bert"),
|
||||
"xlm-roberta-large": (["input_ids"], 11, True, "bert"),
|
||||
|
|
|
|||
|
|
@ -226,7 +226,9 @@ def load_pretrained_model(model_name, config, cache_dir, custom_model_class, is_
|
|||
transformers_module = __import__("transformers", fromlist=[model_class_name])
|
||||
model_class = getattr(transformers_module, model_class_name)
|
||||
|
||||
return model_class.from_pretrained(model_name, config=config, cache_dir=cache_dir)
|
||||
use_cdn = False if model_name == 't5-11b' else True
|
||||
|
||||
return model_class.from_pretrained(model_name, config=config, cache_dir=cache_dir, use_cdn=use_cdn)
|
||||
|
||||
|
||||
def validate_and_optimize_onnx(model_name, use_external_data_format, model_type, onnx_dir, input_names, use_gpu,
|
||||
|
|
|
|||
Loading…
Reference in a new issue