diff --git a/onnxruntime/python/tools/transformers/longformer/benchmark_longformer.py b/onnxruntime/python/tools/transformers/longformer/benchmark_longformer.py index 18d3973cc5..fe0278a28c 100644 --- a/onnxruntime/python/tools/transformers/longformer/benchmark_longformer.py +++ b/onnxruntime/python/tools/transformers/longformer/benchmark_longformer.py @@ -11,6 +11,8 @@ import timeit from datetime import datetime import csv import argparse +import os +import sys import torch import onnxruntime @@ -22,9 +24,10 @@ MODELS = { is_debug = False -# Run onnx model with ORT -from onnxruntime.transformers import benchmark_helper +sys.path.append(os.path.join(os.path.dirname(__file__), '..')) +# Run onnx model with ORT +import benchmark_helper def get_dummy_inputs(sequence_length, num_global_tokens, device): # Create dummy inputs @@ -152,6 +155,9 @@ def test_onnxruntime(device, "datetime": str(datetime.now()), } + max_last_state_size = max(batch_sizes) * max(sequence_lengths) * model.config.hidden_size + max_pooler_size = max(batch_sizes) * max(sequence_lengths) + result = benchmark_helper.inference_ort_with_io_binding( ort_session, ort_inputs, @@ -159,12 +165,8 @@ def test_onnxruntime(device, repeat_times=test_times, ort_output_names=["last_state", "pooler"], ort_outputs=ort_outputs, - output_buffers={ - "last_state": None, - "pooler": None - }, - max_last_state_size=max(batch_sizes) * max(sequence_lengths) * model.config.hidden_size, - max_pooler_size=max(batch_sizes) * max(sequence_lengths), + output_buffers=[], + output_buffer_max_sizes=[max_last_state_size, max_pooler_size], batch_size=batch_size, device=device) print(result)