Support T5 benchmarking in transformers tool (#5133)

* init checkin

* review comments

* modify according to transformers release
This commit is contained in:
Ye Wang 2020-09-29 22:58:28 -07:00 committed by GitHub
parent 9ec1ed42a8
commit 1a12f510fc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 18 additions and 8 deletions

View file

@ -286,9 +286,17 @@ def run_tensorflow(use_gpu, model_names, model_class, precision, batch_sizes, se
input_ids = tf.constant(values, shape=(batch_size, sequence_length), dtype=tf.int32)
try:
model(input_ids, training=False)
def encoder_forward():
return model(input_ids, training=False)
runtimes = timeit.repeat(lambda: model(input_ids, training=False), repeat=repeat_times, number=1)
def encoder_decoder_forward():
return model(input_ids, decoder_input_ids=input_ids, training=False)
inference = encoder_decoder_forward if config.is_encoder_decoder else encoder_forward
inference()
runtimes = timeit.repeat(lambda: inference(), repeat=repeat_times, number=1)
result = {
"engine": "tensorflow",

View file

@ -89,11 +89,11 @@ MODELS = {
"albert-xlarge-v2": (["input_ids"], 12, True, "bert"),
"albert-xxlarge-v2": (["input_ids"], 12, True, "bert"),
# T5
#"t5-small": (["input_ids"], 11, False, "bert"),
#"t5-base": (["input_ids"], 11, False, "bert"),
#"t5-large": (["input_ids"], 11, False, "bert"),
#"t5-3b": (["input_ids"], 11, False, "bert"),
#"t5-11b": (["input_ids"], 11, False, "bert"),
"t5-small": (["input_ids"], 12, False, "bert"),
"t5-base": (["input_ids"], 12, False, "bert"),
"t5-large": (["input_ids"], 12, True, "bert"),
"t5-3b": (["input_ids"], 12, True, "bert"),
"t5-11b": (["input_ids"], 12, True, "bert"),
# XLM-RoBERTa
"xlm-roberta-base": (["input_ids"], 11, False, "bert"),
"xlm-roberta-large": (["input_ids"], 11, True, "bert"),

View file

@ -226,7 +226,9 @@ def load_pretrained_model(model_name, config, cache_dir, custom_model_class, is_
transformers_module = __import__("transformers", fromlist=[model_class_name])
model_class = getattr(transformers_module, model_class_name)
return model_class.from_pretrained(model_name, config=config, cache_dir=cache_dir)
use_cdn = False if model_name == 't5-11b' else True
return model_class.from_pretrained(model_name, config=config, cache_dir=cache_dir, use_cdn=use_cdn)
def validate_and_optimize_onnx(model_name, use_external_data_format, model_type, onnx_dir, input_names, use_gpu,