mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-27 22:45:57 +00:00
add tool for generating test data for longformer (#6415)
This commit is contained in:
parent
76dbd88526
commit
afd7b8b3f7
1 changed files with 241 additions and 0 deletions
|
|
@ -0,0 +1,241 @@
|
|||
#-------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
#--------------------------------------------------------------------------
|
||||
|
||||
# Generate test data for a longformer model, so that we can use onnxruntime_perf_test.exe to evaluate the inference latency.
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
import numpy as np
|
||||
import os
|
||||
import random
|
||||
from pathlib import Path
|
||||
from onnx import ModelProto, TensorProto, numpy_helper
|
||||
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
|
||||
from onnx_model import OnnxModel
|
||||
from bert_test_data import fake_input_ids_data, fake_input_mask_data, output_test_data
|
||||
|
||||
|
||||
def parse_arguments():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument('--model', required=True, type=str, help="bert onnx model path.")
|
||||
|
||||
parser.add_argument('--output_dir',
|
||||
required=False,
|
||||
type=str,
|
||||
default=None,
|
||||
help="output test data path. If not specified, .")
|
||||
|
||||
parser.add_argument('--batch_size', required=False, type=int, default=1, help="batch size of input")
|
||||
|
||||
parser.add_argument('--sequence_length',
|
||||
required=False,
|
||||
type=int,
|
||||
default=128,
|
||||
help="maximum sequence length of input")
|
||||
|
||||
parser.add_argument('--global_tokens', required=False, type=int, default=10, help="number of global tokens")
|
||||
|
||||
parser.add_argument('--input_ids_name', required=False, type=str, default=None, help="input name for input ids")
|
||||
|
||||
parser.add_argument('--input_mask_name',
|
||||
required=False,
|
||||
type=str,
|
||||
default=None,
|
||||
help="input name for attention mask")
|
||||
|
||||
parser.add_argument('--global_mask_name',
|
||||
required=False,
|
||||
type=str,
|
||||
default=None,
|
||||
help="input name for global attention mask")
|
||||
|
||||
parser.add_argument('--samples', required=False, type=int, default=1, help="number of test cases to be generated")
|
||||
|
||||
parser.add_argument('--seed', required=False, type=int, default=3, help="random seed")
|
||||
|
||||
parser.add_argument('--verbose', required=False, action='store_true', help="print verbose information")
|
||||
parser.set_defaults(verbose=False)
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def get_longformer_inputs(onnx_file, input_ids_name=None, input_mask_name=None, global_mask_name=None):
|
||||
"""
|
||||
Get graph inputs for longformer model.
|
||||
"""
|
||||
model = ModelProto()
|
||||
with open(onnx_file, "rb") as f:
|
||||
model.ParseFromString(f.read())
|
||||
|
||||
onnx_model = OnnxModel(model)
|
||||
graph_inputs = onnx_model.get_graph_inputs_excluding_initializers()
|
||||
|
||||
if input_ids_name is not None:
|
||||
input_ids = onnx_model.find_graph_input(input_ids_name)
|
||||
if input_ids is None:
|
||||
raise ValueError(f"Graph does not have input named {input_ids_name}")
|
||||
|
||||
input_mask = None
|
||||
if input_mask_name:
|
||||
input_mask = onnx_model.find_graph_input(input_mask_name)
|
||||
if input_mask is None:
|
||||
raise ValueError(f"Graph does not have input named {input_mask_name}")
|
||||
|
||||
global_mask = None
|
||||
if global_mask_name:
|
||||
global_mask = onnx_model.find_graph_input(global_mask_name)
|
||||
if global_mask is None:
|
||||
raise ValueError(f"Graph does not have input named {global_mask_name}")
|
||||
|
||||
expected_inputs = 1 + (1 if input_mask else 0) + (1 if global_mask else 0)
|
||||
if len(graph_inputs) != expected_inputs:
|
||||
raise ValueError(f"Expect the graph to have {expected_inputs} inputs. Got {len(graph_inputs)}")
|
||||
|
||||
return input_ids, input_mask, global_mask
|
||||
|
||||
if len(graph_inputs) != 3:
|
||||
raise ValueError("Expect the graph to have 3 inputs. Got {}".format(len(graph_inputs)))
|
||||
|
||||
# Try guess the inputs based on naming.
|
||||
input_ids = None
|
||||
input_mask = None
|
||||
global_mask = None
|
||||
for input in graph_inputs:
|
||||
input_name_lower = input.name.lower()
|
||||
if "global" in input_name_lower:
|
||||
global_mask = input
|
||||
elif "mask" in input_name_lower:
|
||||
input_mask = input
|
||||
else:
|
||||
input_ids = input
|
||||
|
||||
if input_ids and input_mask and global_mask:
|
||||
return input_ids, input_mask, global_mask
|
||||
|
||||
raise ValueError("Fail to assign 3 inputs. You might try rename the graph inputs.")
|
||||
|
||||
|
||||
def fake_global_mask_data(global_mask, batch_size, sequence_length, num_global_tokens):
|
||||
"""
|
||||
Fake data based on the graph input of segment_ids.
|
||||
Args:
|
||||
segment_ids (TensorProto): graph input of input tensor.
|
||||
Returns:
|
||||
data (np.array): the data for input tensor
|
||||
"""
|
||||
data_type = global_mask.type.tensor_type.elem_type
|
||||
assert data_type in [TensorProto.FLOAT, TensorProto.INT32, TensorProto.INT64]
|
||||
|
||||
if num_global_tokens > 0:
|
||||
assert num_global_tokens <= sequence_length
|
||||
data = np.zeros((batch_size, sequence_length), dtype=np.int32)
|
||||
temp = np.ones((batch_size, num_global_tokens), dtype=np.int32)
|
||||
data[:temp.shape[0], :temp.shape[1]] = temp
|
||||
else:
|
||||
data = np.zeros((batch_size, sequence_length), dtype=np.int32)
|
||||
|
||||
if data_type == TensorProto.FLOAT:
|
||||
data = np.float32(data)
|
||||
elif data_type == TensorProto.INT64:
|
||||
data = np.int64(data)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def fake_test_data(batch_size,
|
||||
sequence_length,
|
||||
test_cases,
|
||||
dictionary_size,
|
||||
verbose,
|
||||
random_seed,
|
||||
input_ids,
|
||||
input_mask,
|
||||
global_mask,
|
||||
num_global_tokens,
|
||||
random_mask_length=False):
|
||||
"""
|
||||
Generate fake input data for test.
|
||||
"""
|
||||
assert input_ids is not None
|
||||
|
||||
np.random.seed(random_seed)
|
||||
random.seed(random_seed)
|
||||
|
||||
all_inputs = []
|
||||
for _ in range(test_cases):
|
||||
input_1 = fake_input_ids_data(input_ids, batch_size, sequence_length, dictionary_size)
|
||||
inputs = {input_ids.name: input_1}
|
||||
|
||||
if input_mask:
|
||||
inputs[input_mask.name] = fake_input_mask_data(input_mask, batch_size, sequence_length, random_mask_length)
|
||||
|
||||
if global_mask:
|
||||
inputs[global_mask.name] = fake_global_mask_data(global_mask, batch_size, sequence_length,
|
||||
num_global_tokens)
|
||||
|
||||
if verbose and len(all_inputs) == 0:
|
||||
print("Example inputs", inputs)
|
||||
all_inputs.append(inputs)
|
||||
|
||||
return all_inputs
|
||||
|
||||
|
||||
def generate_test_data(batch_size,
|
||||
sequence_length,
|
||||
test_cases,
|
||||
seed,
|
||||
verbose,
|
||||
input_ids,
|
||||
input_mask,
|
||||
global_mask,
|
||||
num_global_tokens,
|
||||
random_mask_length=False):
|
||||
dictionary_size = 10000
|
||||
all_inputs = fake_test_data(batch_size, sequence_length, test_cases, dictionary_size, verbose, seed, input_ids,
|
||||
input_mask, global_mask, num_global_tokens, random_mask_length)
|
||||
if len(all_inputs) != test_cases:
|
||||
print("Failed to create test data for test.")
|
||||
return all_inputs
|
||||
|
||||
|
||||
def create_longformer_test_data(model, output_dir, batch_size, sequence_length, test_cases, seed, verbose,
|
||||
input_ids_name, input_mask_name, global_mask_name, num_global_tokens):
|
||||
|
||||
input_ids, input_mask, global_mask = get_longformer_inputs(model, input_ids_name, input_mask_name, global_mask_name)
|
||||
all_inputs = generate_test_data(batch_size, sequence_length, test_cases, seed, verbose, input_ids, input_mask,
|
||||
global_mask, num_global_tokens)
|
||||
|
||||
for i, inputs in enumerate(all_inputs):
|
||||
output_test_data(output_dir, i, inputs)
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_arguments()
|
||||
|
||||
output_dir = args.output_dir
|
||||
if output_dir is None:
|
||||
# Default output directory is a sub-directory under the directory of model.
|
||||
output_dir = os.path.join(
|
||||
Path(args.model).parent, "b{}_s{}_g{}".format(args.batch_size, args.sequence_length, args.global_tokens))
|
||||
|
||||
if output_dir is not None:
|
||||
# create the output directory if not existed
|
||||
path = Path(output_dir)
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
else:
|
||||
print("Directory existed. test data files will be overwritten.")
|
||||
|
||||
create_longformer_test_data(args.model, output_dir, args.batch_size, args.sequence_length, args.samples, args.seed,
|
||||
args.verbose, args.input_ids_name, args.input_mask_name, args.global_mask_name,
|
||||
args.global_tokens)
|
||||
|
||||
print("Test data is saved to directory:", output_dir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Loading…
Reference in a new issue