Benchmark With IO Binding (#4206)

* add io binding to benchmark.py
This commit is contained in:
Cecilia Liu 2020-06-15 10:06:33 -07:00 committed by GitHub
parent b4b1c6440a
commit 0b5bbb16b8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 174 additions and 83 deletions

View file

@ -222,7 +222,7 @@ def get_onnx_file_path(onnx_dir: str, model_name: str, input_count: int, optimiz
def optimize_onnx_model_by_ort(onnx_model_path, ort_model_path, use_gpu, overwrite):
if overwrite or not os.path.exists(ort_model_path):
from optimizer import optimize_model, get_fusion_statistics
from optimizer import optimize_by_onnxruntime, get_fusion_statistics
# Use onnxruntime to optimize model, which will be saved to *_ort.onnx
opt_model = optimize_by_onnxruntime(onnx_model_path,
use_gpu=use_gpu,
@ -339,8 +339,57 @@ def get_latency_result(runtimes, batch_size):
}
def inference_ort(ort_session, ort_inputs, result_template, repeat_times, batch_size):
result = {}
runtimes = timeit.repeat(lambda: ort_session.run(None, ort_inputs), number=1, repeat=repeat_times)
result.update(result_template)
result.update({"io_binding": False})
result.update(get_latency_result(runtimes, batch_size))
return result
def inference_ort_with_io_binding(ort_session, ort_inputs, result_template, repeat_times, ort_output_names, ort_outputs,
output_buffers, max_last_state_size, max_pooler_size, batch_size, device):
result = {}
# Bind inputs and outputs to onnxruntime session
io_binding = ort_session.io_binding()
# Bind inputs to device
for name in ort_inputs.keys():
np_input = torch.from_numpy(ort_inputs[name]).to(device)
io_binding.bind_input(name, np_input.device.type, 0, numpy.longlong, np_input.shape, np_input.data_ptr())
has_pooler = True if len(ort_output_names) == 2 else False
# Bind outputs buffers with the sizes needed if not allocated already
if output_buffers["last_state"] is None:
allocateOutputBuffers(output_buffers, max_last_state_size, max_pooler_size, device, has_pooler)
last_state_buffer = output_buffers["last_state"]
pooler_buffer = output_buffers["pooler"]
io_binding.bind_output(ort_output_names[0], last_state_buffer.device.type, 0, numpy.float32, ort_outputs[0].shape,
last_state_buffer.data_ptr())
if has_pooler:
io_binding.bind_output(ort_output_names[1], pooler_buffer.device.type, 0, numpy.float32, ort_outputs[1].shape,
pooler_buffer.data_ptr())
runtimes = timeit.repeat(lambda: ort_session.run_with_iobinding(io_binding), number=1, repeat=repeat_times)
result.update(result_template)
result.update({"io_binding": True})
result.update(get_latency_result(runtimes, batch_size))
return result
def allocateOutputBuffers(output_buffers, max_last_state_size, max_pooler_size, device, has_pooler=False):
# Allocate output tensors with the largest test size needed. So the allocated memory can be reused
# for each test run.
# dummy last state
if output_buffers["last_state"] is None:
output_buffers["last_state"] = torch.empty(max_last_state_size, dtype=torch.float32, device=device)
# create dummy pooler
if output_buffers["pooler"] is None and has_pooler:
output_buffers["pooler"] = torch.empty(max_pooler_size, dtype=torch.float32, device=device)
def run_onnxruntime(use_gpu, model_names, fp16, batch_sizes, sequence_lengths, repeat_times, input_counts,
optimize_onnx, validate_onnx, cache_dir, onnx_dir, verbose, overwrite):
optimize_onnx, validate_onnx, cache_dir, onnx_dir, verbose, overwrite, disable_ort_io_binding):
import onnxruntime
results = []
@ -372,6 +421,14 @@ def run_onnxruntime(use_gpu, model_names, fp16, batch_sizes, sequence_lengths, r
if ort_session is None:
continue
ort_output_names = [node_arg.name for node_arg in ort_session.get_outputs()]
output_buffers = {"last_state": None, "pooler": None}
device = "cuda" if use_gpu else "cpu"
config = AutoConfig.from_pretrained(model_name, cache_dir=cache_dir)
max_last_state_size = numpy.prod(
[max(batch_sizes), max(sequence_lengths),
max(vocab_size, config.hidden_size)])
max_pooler_size = numpy.prod([max(batch_sizes), config.hidden_size])
for batch_size in batch_sizes:
if batch_size <= 0:
continue
@ -379,29 +436,38 @@ def run_onnxruntime(use_gpu, model_names, fp16, batch_sizes, sequence_lengths, r
if max_sequence_length is not None and sequence_length > max_sequence_length:
continue
ort_input = create_onnxruntime_input(vocab_size, batch_size, sequence_length, input_names)
ort_inputs = create_onnxruntime_input(vocab_size, batch_size, sequence_length, input_names)
logger.info("Run onnxruntime on {} with input shape {}".format(model_name,
[batch_size, sequence_length]))
runtimes = timeit.repeat(lambda: ort_session.run(None, ort_input), number=1, repeat=repeat_times)
result = {
result_template = {
"engine": "onnxruntime",
"version": onnxruntime.__version__,
"device": "cuda" if use_gpu else "cpu",
"device": device,
"optimizer": optimize_onnx,
"fp16": fp16,
"io_binding": False,
"model_name": model_name,
"inputs": num_inputs,
"batch_size": batch_size,
"sequence_length": sequence_length,
"datetime": str(datetime.now()),
}
result.update(get_latency_result(runtimes, batch_size))
logger.info("Run onnxruntime on {} with input shape {}".format(model_name,
[batch_size, sequence_length]))
result = inference_ort(ort_session, ort_inputs, result_template, repeat_times, batch_size)
logger.info(result)
results.append(result)
if not disable_ort_io_binding:
logger.info("Run onnxruntime with io binding on {} with input shape {}".format(
model_name, [batch_size, sequence_length]))
# Get output sizes from a dummy ort run
ort_outputs = ort_session.run(ort_output_names, ort_inputs)
result = inference_ort_with_io_binding(ort_session, ort_inputs, result_template, repeat_times,
ort_output_names, ort_outputs, output_buffers,
max_last_state_size, max_pooler_size, batch_size, device)
logger.info(result)
results.append(result)
return results
@ -455,6 +521,7 @@ def run_pytorch(use_gpu, model_names, fp16, batch_sizes, sequence_lengths, repea
"device": "cuda" if use_gpu else "cpu",
"optimizer": "",
"fp16": fp16,
"io_binding": "",
"model_name": model_name,
"inputs": 1,
"batch_size": batch_size,
@ -474,9 +541,9 @@ def run_pytorch(use_gpu, model_names, fp16, batch_sizes, sequence_lengths, repea
def output_details(results, csv_filename):
with open(csv_filename, mode="a", newline='') as csv_file:
column_names = [
"engine", "version", "device", "fp16", "optimizer", "model_name", "inputs", "batch_size", "sequence_length",
"datetime", "test_times", "QPS", "average_latency_ms", "latency_variance", "latency_90_percentile",
"latency_95_percentile", "latency_99_percentile"
"engine", "version", "device", "fp16", "optimizer", "io_binding", "model_name", "inputs", "batch_size",
"sequence_length", "datetime", "test_times", "QPS", "average_latency_ms", "latency_variance",
"latency_90_percentile", "latency_95_percentile", "latency_99_percentile"
]
csv_writer = csv.DictWriter(csv_file, fieldnames=column_names)
@ -489,7 +556,7 @@ def output_details(results, csv_filename):
def output_summary(results, csv_filename, args):
with open(csv_filename, mode="a", newline='') as csv_file:
header_names = ["model_name", "inputs", "engine", "version", "device", "fp16", "optimizer"]
header_names = ["model_name", "inputs", "engine", "version", "device", "fp16", "optimizer", "io_binding"]
data_names = []
for batch_size in args.batch_sizes:
for sequence_length in args.sequence_lengths:
@ -500,21 +567,23 @@ def output_summary(results, csv_filename, args):
for model_name in args.models:
for input_count in [1, 2, 3]:
for engine_name in args.engines:
row = {}
for result in results:
if result["model_name"] == model_name and result["inputs"] == input_count and result[
"engine"] == engine_name:
headers = {k: v for k, v in result.items() if k in header_names}
if not row:
row.update(headers)
row.update({k: "" for k in data_names})
else:
for k in header_names:
assert row[k] == headers[k]
b = result["batch_size"]
s = result["sequence_length"]
row[f"b{b}_s{s}"] = result["average_latency_ms"]
csv_writer.writerow(row)
for io_binding in [True, False, ""]:
row = {}
for result in results:
if result["model_name"] == model_name and result["inputs"] == input_count and result[
"engine"] == engine_name and result["io_binding"] == io_binding:
headers = {k: v for k, v in result.items() if k in header_names}
if not row:
row.update(headers)
row.update({k: "" for k in data_names})
else:
for k in header_names:
assert row[k] == headers[k]
b = result["batch_size"]
s = result["sequence_length"]
row[f"b{b}_s{s}"] = result["average_latency_ms"]
if row:
csv_writer.writerow(row)
logger.info(f"Summary results are saved to csv file: {csv_filename}")
@ -615,6 +684,12 @@ def parse_arguments():
parser.add_argument("-s", "--sequence_lengths", nargs="+", type=int, default=[8, 16, 32, 64, 128, 256])
parser.add_argument('--disable_ort_io_binding',
required=False,
action='store_true',
help='Disable running ONNX Runtime with binded inputs and outputs. ')
parser.set_defaults(disable_ort_io_binding=False)
args = parser.parse_args()
return args
@ -665,7 +740,8 @@ def main():
try:
results += run_onnxruntime(args.use_gpu, args.models, args.fp16, args.batch_sizes, args.sequence_lengths,
args.test_times, args.input_counts, args.optimize_onnx, args.validate_onnx,
args.cache_dir, args.onnx_dir, args.verbose, args.overwrite)
args.cache_dir, args.onnx_dir, args.verbose, args.overwrite,
args.disable_ort_io_binding)
except:
logger.error(f"Exception", exc_info=True)

View file

@ -81,11 +81,12 @@ def onnxruntime_inference(ort_session, input_ids, past=None, attention_mask=None
def onnxruntime_inference_with_binded_io(ort_session,
input_ids,
last_state,
output_buffers,
max_last_state_size,
last_state_shape,
past=None,
attention_mask=None,
present=None,
max_present_size=None,
present_shape=None,
total_runs=100):
# Bind inputs and outputs to onnxruntime session
@ -94,21 +95,28 @@ def onnxruntime_inference_with_binded_io(ort_session,
io_binding.bind_input('input_ids', input_ids.device.type, 0, numpy.longlong, list(input_ids.size()),
input_ids.data_ptr())
if attention_mask is not None:
io_binding.bind_input('attention_mask', attention_mask.device.type, 0, numpy.float32, list(attention_mask.size()),
attention_mask.data_ptr())
io_binding.bind_input('attention_mask', attention_mask.device.type, 0, numpy.float32,
list(attention_mask.size()), attention_mask.data_ptr())
n_layer = None
if past is not None:
n_layer = len(past)
for i, past_i in enumerate(past):
io_binding.bind_input(f'past_{i}', past[i].device.type, 0, numpy.float32, list(past[i].size()),
past[i].data_ptr())
# Bind outputs
io_binding.bind_output(ort_session.get_outputs()[0].name, last_state.device.type, 0, numpy.float32, last_state_shape,
last_state.data_ptr())
if present is not None:
for i, present_i in enumerate(present):
io_binding.bind_output(f'present_{i}', present[i].device.type, 0, numpy.float32, present_shape,
present[i].data_ptr())
if output_buffers["last_state"] is None or output_buffers["present"] is None:
# Allocate output buffers with the largest size need by current model
allocateOutputBuffers(output_buffers, max_last_state_size, max_present_size, n_layer, input_ids.device)
last_state_buffer = output_buffers["last_state"]
present_buffers = output_buffers["present"]
io_binding.bind_output(ort_session.get_outputs()[0].name, last_state_buffer.device.type, 0, numpy.float32,
last_state_shape, last_state_buffer.data_ptr())
if present_buffers is not None:
for i, present_i in enumerate(present_buffers):
io_binding.bind_output(f'present_{i}', present_buffers[i].device.type, 0, numpy.float32, present_shape,
present_buffers[i].data_ptr())
latency = []
for _ in range(total_runs):
@ -121,10 +129,10 @@ def onnxruntime_inference_with_binded_io(ort_session,
logger.debug("OnnxRuntime with IO binding inference time = {} ms".format(format(average_latency, '.2f')))
# Copy results to cpu
ort_outputs = [last_state[0:numpy.prod(last_state_shape)].reshape(last_state_shape).cpu()]
if present is not None:
for i, present_i in enumerate(present):
ort_outputs.append(present[i][0:numpy.prod(present_shape)].reshape(present_shape).cpu())
ort_outputs = [last_state_buffer[0:numpy.prod(last_state_shape)].reshape(last_state_shape).cpu()]
if present_buffers is not None:
for i, present_i in enumerate(present_buffers):
ort_outputs.append(present_buffers[i][0:numpy.prod(present_shape)].reshape(present_shape).cpu())
return ort_outputs, average_latency
@ -133,25 +141,27 @@ def inference(model,
input_ids,
past=None,
attention_mask=None,
last_state=None,
present=None,
output_buffers=None,
max_last_state_size=None,
last_state_shape=None,
max_present_size=None,
present_shape=None,
total_runs=100,
verify_outputs=True,
with_io_binding=False):
disable_ort_io_binding=False):
outputs, torch_latency = pytorch_inference(model, input_ids, past, attention_mask, total_runs)
ort_outputs, ort_latency = onnxruntime_inference(ort_session, input_ids, past, attention_mask, total_runs)
latencies = [torch_latency, ort_latency]
if with_io_binding:
ort_io_outputs, ort_io_latency = onnxruntime_inference_with_binded_io(ort_session, input_ids, last_state,
last_state_shape, past, attention_mask, present,
if not disable_ort_io_binding:
ort_io_outputs, ort_io_latency = onnxruntime_inference_with_binded_io(ort_session, input_ids, output_buffers,
max_last_state_size, last_state_shape,
past, attention_mask, max_present_size,
present_shape, total_runs)
latencies.append(ort_io_latency)
if verify_outputs:
logger.debug('Verifying Pytorch and ONNX Runtime outputs.')
verify_ort_outputs(model, outputs, ort_outputs)
if with_io_binding:
if not disable_ort_io_binding:
logger.debug('Verifying Pytorch and ONNX Runtime with io binding outputs.')
verify_ort_outputs(model, outputs, ort_io_outputs)
@ -172,6 +182,18 @@ def verify_ort_outputs(model, torch_outputs, ort_outputs):
logger.warning(f'PyTorch and OnnxRuntime results are not all close.')
def allocateOutputBuffers(output_buffers, max_last_state_size, max_present_size, n_layer, device):
# Allocate output tensors with the largest test size needed. So the allocated memory can be reused
# for each test run.
# dummy last state
if output_buffers["last_state"] is None:
output_buffers["last_state"] = torch.empty(max_last_state_size, dtype=torch.float32, device=device)
# create dummy present
if n_layer is not None and output_buffers["present"] is None:
present_buffers = [torch.empty(max_present_size, dtype=torch.float32, device=device) for _ in range(n_layer)]
output_buffers["present"] = present_buffers
def parse_arguments():
parser = argparse.ArgumentParser()
@ -211,12 +233,11 @@ def parse_arguments():
help='Use optimizer.py to optimize onnx model')
parser.set_defaults(optimize_onnx=False)
parser.add_argument('-i',
'--with_io_binding',
parser.add_argument('--disable_ort_io_binding',
required=False,
action='store_true',
help='Run ONNX Runtime with binded inputs and outputs. ')
parser.set_defaults(with_io_binding=False)
help='Disable running ONNX Runtime with binded inputs and outputs. ')
parser.set_defaults(disable_ort_io_binding=False)
parser.add_argument('--use_gpu', required=False, action='store_true')
parser.set_defaults(use_gpu=False)
@ -285,7 +306,7 @@ def export_onnx(model, config, tokenizer, device, output_dir, use_LMHead=False,
# Shape of output tensors:
# last_state: (batch_size, all_seq_len, hidden_size)
# present_{i}: (2, batch_size, num_heads, all_seq_len, hidden_size/num_heads)
dynamic_axes = {'input_ids': {0: 'batch_size', 1 : 'seq_len'}, output_names[0]: {0: 'batch_size', 1: 'all_seq_len'}}
dynamic_axes = {'input_ids': {0: 'batch_size', 1: 'seq_len'}, output_names[0]: {0: 'batch_size', 1: 'all_seq_len'}}
for name in past_names:
dynamic_axes[name] = {1: 'batch_size', 3: 'past_seq_len'}
for name in present_names:
@ -312,7 +333,8 @@ def export_onnx(model, config, tokenizer, device, output_dir, use_LMHead=False,
logger.debug(f"present_0 shape={outputs[1][0].shape}")
torch.onnx.export(model,
args=(dummy_input_ids, tuple(dummy_past), dummy_mask) if use_attention_mask else (dummy_input_ids, tuple(dummy_past)),
args=(dummy_input_ids, tuple(dummy_past), dummy_mask) if use_attention_mask else
(dummy_input_ids, tuple(dummy_past)),
f=export_model_path,
input_names=['input_ids'] + past_names + (['attention_mask'] if use_attention_mask else []),
output_names=output_names,
@ -376,29 +398,21 @@ def main():
logger.info(f"Start inferencing onnx model: {onnx_model_path}")
session = onnxruntime.InferenceSession(onnx_model_path, sess_options)
dummy_present = None
dummy_last_state = None
# Calculate the largest size needed for each output
max_batch_size = max(args.batch_sizes)
max_seq_len = max(args.sequence_lengths)
if args.with_io_binding:
# Allocate output tensors with the largest test size needed. So the allocated memory can be reused
# for each test run.
# create dummy present
present_state_size = numpy.prod([
2, max_batch_size, config.num_attention_heads, max_seq_len + 1,
int(config.hidden_size / config.num_attention_heads)
])
dummy_present = [
torch.empty(present_state_size, dtype=torch.float32, device=device) for _ in range(config.n_layer)
]
max_present_size = numpy.prod([
2, max_batch_size, config.num_attention_heads, max_seq_len + 1,
int(config.hidden_size / config.num_attention_heads)
])
# dummy last state
if use_LMHead:
last_state_size = numpy.prod([max_batch_size, 1, config.vocab_size])
else:
last_state_size = numpy.prod([max_batch_size, 1, config.hidden_size])
dummy_last_state = torch.empty(last_state_size).to(device)
# dummy last state
if use_LMHead:
max_last_state_size = numpy.prod([max_batch_size, 1, config.vocab_size])
else:
max_last_state_size = numpy.prod([max_batch_size, 1, config.hidden_size])
output_buffers = {"last_state": None, "present": None}
for batch_size in args.batch_sizes:
for sequence_length in args.sequence_lengths:
past_shape = [
@ -414,7 +428,7 @@ def main():
dummy_mask = torch.ones([batch_size, 1], dtype=torch.float32, device=device) if use_attention_mask else None
# Calculate the expected output shapes
last_state_shape = [batch_size, 1, config.hidden_size]
last_state_shape = [batch_size, 1, config.vocab_size] if use_LMHead else [batch_size, 1, config.hidden_size]
present_shape = [
2, batch_size, config.num_attention_heads, sequence_length + 1,
int(config.hidden_size / config.num_attention_heads)
@ -425,14 +439,15 @@ def main():
dummy_input_ids,
dummy_past,
dummy_mask,
dummy_last_state,
dummy_present,
output_buffers,
max_last_state_size,
last_state_shape,
max_present_size,
present_shape,
args.test_times,
verify_outputs=args.validate_onnx,
with_io_binding=args.with_io_binding)
ort_io_latency_info = f", ort_io_latency={latencies[2]}" if args.with_io_binding else ""
disable_ort_io_binding=args.disable_ort_io_binding)
ort_io_latency_info = f", ort_io_latency={latencies[2]}" if not args.disable_ort_io_binding else ""
logger.info(
f"batch_size={batch_size}, sequence_length={sequence_length}, torch_latency={latencies[0]}, ort_latency={latencies[1]}{ort_io_latency_info}"
)