fix longformer benchmark io_binding output_buffers (#6345)

* fix longformer benchmark io_binding output_buffers

* format

* import benchmark_helper from parent directory.
This commit is contained in:
Ye Wang 2021-01-14 11:29:31 -08:00 committed by GitHub
parent ea6789b754
commit 5d9552cc8b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -11,6 +11,8 @@ import timeit
from datetime import datetime
import csv
import argparse
import os
import sys
import torch
import onnxruntime
@ -22,9 +24,10 @@ MODELS = {
is_debug = False
# Run onnx model with ORT
from onnxruntime.transformers import benchmark_helper
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
# Run onnx model with ORT
import benchmark_helper
def get_dummy_inputs(sequence_length, num_global_tokens, device):
# Create dummy inputs
@ -152,6 +155,9 @@ def test_onnxruntime(device,
"datetime": str(datetime.now()),
}
max_last_state_size = max(batch_sizes) * max(sequence_lengths) * model.config.hidden_size
max_pooler_size = max(batch_sizes) * max(sequence_lengths)
result = benchmark_helper.inference_ort_with_io_binding(
ort_session,
ort_inputs,
@ -159,12 +165,8 @@ def test_onnxruntime(device,
repeat_times=test_times,
ort_output_names=["last_state", "pooler"],
ort_outputs=ort_outputs,
output_buffers={
"last_state": None,
"pooler": None
},
max_last_state_size=max(batch_sizes) * max(sequence_lengths) * model.config.hidden_size,
max_pooler_size=max(batch_sizes) * max(sequence_lengths),
output_buffers=[],
output_buffer_max_sizes=[max_last_state_size, max_pooler_size],
batch_size=batch_size,
device=device)
print(result)