mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-29 23:06:41 +00:00
fix longformer benchmark io_binding output_buffers (#6345)
* fix longformer benchmark io_binding output_buffers
* format
* import benchmark_helper from parent directory.
This commit is contained in:
parent
ea6789b754
commit
5d9552cc8b
1 changed file with 10 additions and 8 deletions
|
|
@@ -11,6 +11,8 @@ import timeit
|
|||
from datetime import datetime
|
||||
import csv
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
import torch
|
||||
import onnxruntime
|
||||
|
||||
|
|
@@ -22,9 +24,10 @@ MODELS = {
|
|||
|
||||
is_debug = False
|
||||
|
||||
# Run onnx model with ORT
|
||||
from onnxruntime.transformers import benchmark_helper
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
|
||||
|
||||
# Run onnx model with ORT
|
||||
import benchmark_helper
|
||||
|
||||
def get_dummy_inputs(sequence_length, num_global_tokens, device):
|
||||
# Create dummy inputs
|
||||
|
|
@@ -152,6 +155,9 @@ def test_onnxruntime(device,
|
|||
"datetime": str(datetime.now()),
|
||||
}
|
||||
|
||||
max_last_state_size = max(batch_sizes) * max(sequence_lengths) * model.config.hidden_size
|
||||
max_pooler_size = max(batch_sizes) * max(sequence_lengths)
|
||||
|
||||
result = benchmark_helper.inference_ort_with_io_binding(
|
||||
ort_session,
|
||||
ort_inputs,
|
||||
|
|
@@ -159,12 +165,8 @@ def test_onnxruntime(device,
|
|||
repeat_times=test_times,
|
||||
ort_output_names=["last_state", "pooler"],
|
||||
ort_outputs=ort_outputs,
|
||||
output_buffers={
|
||||
"last_state": None,
|
||||
"pooler": None
|
||||
},
|
||||
max_last_state_size=max(batch_sizes) * max(sequence_lengths) * model.config.hidden_size,
|
||||
max_pooler_size=max(batch_sizes) * max(sequence_lengths),
|
||||
output_buffers=[],
|
||||
output_buffer_max_sizes=[max_last_state_size, max_pooler_size],
|
||||
batch_size=batch_size,
|
||||
device=device)
|
||||
print(result)
|
||||
|
|
|
|||
Loading…
Reference in a new issue