mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-04 23:59:56 +00:00
* Add t-test to compare two experiments * Ranking based on pair-wise T-test results and a custom scoring function
814 lines
37 KiB
Python
814 lines
37 KiB
Python
# -------------------------------------------------------------------------
|
|
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
# Licensed under the MIT License. See License.txt in the project root for
|
|
# license information.
|
|
# --------------------------------------------------------------------------
|
|
# This script helps onnx conversion and validation for GPT2 model with past state.
|
|
import os
|
|
import logging
|
|
import torch
|
|
import shutil
|
|
import random
|
|
import numpy
|
|
import time
|
|
import re
|
|
import pickle
|
|
from pathlib import Path
|
|
from typing import List, Dict, Tuple, Union
|
|
from transformers import GPT2Model, GPT2LMHeadModel, GPT2Config, TFGPT2Model
|
|
from benchmark_helper import Precision
|
|
|
|
# Logger named after this module.
logger = logging.getLogger(__name__)

# Names of the pretrained GPT-2 checkpoints listed by this helper.
PRETRAINED_GPT2_MODELS = [
    'distilgpt2',
    'gpt2',
    'gpt2-medium',
    'gpt2-large',
    'gpt2-xl',
]

# Default tolerance value for each precision.
DEFAULT_TOLERANCE = {
    Precision.FLOAT32: 0.0005,
    Precision.FLOAT16: 0.2,
    Precision.INT8: 3.0,
}
|
|
|
|
|
|
class GPT2ModelNoPastState(GPT2Model):
    """GPT2Model wrapper whose forward pass does not produce past state.

    ``forward`` accepts only ``input_ids`` and calls the parent with
    ``use_cache=False`` and ``return_dict=False``.
    """

    def __init__(self, config):
        super().__init__(config)

    def forward(self, input_ids):
        # Caching is switched off and the parent is asked for non-dict output.
        parent_kwargs = {"use_cache": False, "return_dict": False}
        return super().forward(input_ids, **parent_kwargs)
|
|
|
|
|
|
class TFGPT2ModelNoPastState(TFGPT2Model):
    """TFGPT2Model wrapper that disables past state output.

    Caching is disabled both on the config (in the constructor) and on each
    call made through ``forward``.
    """

    def __init__(self, config):
        # Turn caching off on the config before handing it to the parent constructor.
        config.use_cache = False
        super().__init__(config)

    def forward(self, input_ids):
        outputs = super().call(input_ids, use_cache=False)
        return outputs
|
|
|
|
|
|
class MyGPT2Model(GPT2Model):
    """GPT2Model wrapper used for ONNX conversion of the model with past state."""

    def __init__(self, config):
        super().__init__(config)

    @staticmethod
    def post_process(result, num_layer):
        """Merge separated key/value states into one tensor per layer.

        Args:
            result: model output; ``result[1]`` holds the per-layer state.
            num_layer: number of layers expected in ``result[1]``.

        Returns:
            ``result`` unchanged when the first layer state is not a tuple or
            list; otherwise ``(result[0], present)`` where ``present`` is a
            tuple with one tensor per layer, built by concatenating the two
            state entries along a new leading dimension.
        """
        first_layer_state = result[1][0]
        if not isinstance(first_layer_state, (tuple, list)):
            return result

        assert len(result[1]) == num_layer and len(first_layer_state) == 2

        # Since transformers v4.*, past key and values are separate outputs.
        # They are concatenated into one tensor to be compatible with the Attention operator.
        present = tuple(
            torch.cat((result[1][layer][0].unsqueeze(0), result[1][layer][1].unsqueeze(0)), dim=0)
            for layer in range(num_layer))
        return (result[0], present)

    def forward(self, input_ids, position_ids, attention_mask, *past):
        """Run the parent forward pass with ``past`` as ``past_key_values`` and post-process the output."""
        raw_outputs = super().forward(input_ids,
                                      position_ids=position_ids,
                                      attention_mask=attention_mask,
                                      past_key_values=past,
                                      return_dict=False)
        return MyGPT2Model.post_process(raw_outputs, self.config.n_layer)
|
|
|
|
|
|
class MyGPT2LMHeadModel(GPT2LMHeadModel):
    """GPT2LMHeadModel wrapper used for ONNX conversion of the model with past state."""

    def __init__(self, config):
        super().__init__(config)

    def forward(self, input_ids, position_ids, attention_mask, *past):
        """Run the parent forward pass with ``past`` as ``past_key_values`` and post-process the output."""
        raw_outputs = super().forward(input_ids,
                                      position_ids=position_ids,
                                      attention_mask=attention_mask,
                                      past_key_values=past,
                                      return_dict=False)
        num_layer = self.config.n_layer
        return MyGPT2Model.post_process(raw_outputs, num_layer)
|
|
|
|
|
|
class MyGPT2LMHeadModel_NoPadding(GPT2LMHeadModel):
    """GPT2LMHeadModel wrapper for ONNX conversion with past state and no padding.

    When inference always uses batch_size=1 there is no padding in the inputs,
    so position_ids and attention_mask need not be model inputs.
    """

    def __init__(self, config):
        super().__init__(config)

    def forward(self, input_ids, *past):
        """Run the parent forward pass with ``past`` as ``past_key_values`` and post-process the output."""
        raw_outputs = super().forward(input_ids, past_key_values=past, return_dict=False)
        num_layer = self.config.n_layer
        return MyGPT2Model.post_process(raw_outputs, num_layer)
|
|
|
|
|
|
# Registry keyed by model class name. Each value is a tuple of
# (wrapper class, name of the first output, whether padding is used).
MODEL_CLASSES = {
    'GPT2LMHeadModel': (
        MyGPT2LMHeadModel,
        'logits',
        True,
    ),
    'GPT2LMHeadModel_NoPadding': (
        MyGPT2LMHeadModel_NoPadding,
        'logits',
        False,
    ),
    'GPT2Model': (
        MyGPT2Model,
        'last_state',
        True,
    ),
}
|
|
|
|
|
|
class Gpt2Inputs:
    """Container for the inputs of a GPT-2 model with past state.

    The constructor stores its arguments as given. ``position_ids``,
    ``attention_mask`` and ``past`` may each be None; the conversion helpers
    below skip whichever of them is None.
    """

    def __init__(self, input_ids, position_ids, attention_mask, past):
        self.input_ids: torch.LongTensor = input_ids
        self.position_ids: torch.LongTensor = position_ids
        self.attention_mask: Union[torch.FloatTensor, torch.HalfTensor] = attention_mask
        self.past: Union[List[torch.FloatTensor], List[torch.HalfTensor]] = past

    def to_list(self) -> List:
        """Return the non-None inputs as a flat list, followed by every past tensor."""
        input_list = [v for v in [self.input_ids, self.position_ids, self.attention_mask] if v is not None]
        if self.past:
            input_list.extend(self.past)

        return input_list

    def to_tuple(self) -> Tuple:
        """Return the non-None inputs as a tuple; ``past`` is kept as a single element."""
        return tuple(v for v in [self.input_ids, self.position_ids, self.attention_mask, self.past] if v is not None)

    def to_fp32(self):
        """Return a new Gpt2Inputs whose attention mask and past tensors are float32.

        ``input_ids`` and ``position_ids`` are passed through unchanged. A None
        ``attention_mask`` or None ``past`` stays None instead of raising.
        """
        attention_mask = self.attention_mask.to(dtype=torch.float32) if self.attention_mask is not None else None
        # past may legitimately be None (to_list/to_tuple already allow it), so guard the conversion.
        past = [p.to(dtype=torch.float32) for p in self.past] if self.past is not None else None
        return Gpt2Inputs(self.input_ids, self.position_ids, attention_mask, past)
|
|
|
|
|
|
class Gpt2Helper:
|
|
""" A helper class for Gpt2 model conversion, inference and verification.
|
|
"""
|
|
@staticmethod
def get_dummy_inputs(batch_size: int,
                     past_sequence_length: int,
                     sequence_length: int,
                     num_attention_heads: int,
                     hidden_size: int,
                     num_layer: int,
                     vocab_size: int,
                     device: torch.device,
                     float16: bool = False,
                     has_position_ids: bool = True,
                     has_attention_mask: bool = True) -> Gpt2Inputs:
    """Create random inputs for a GPT2 model.

    Args:
        batch_size: size of the batch dimension of every tensor.
        past_sequence_length: sequence length of the past state tensors.
        sequence_length: sequence length of ``input_ids``.
        num_attention_heads: number of attention heads.
        hidden_size: hidden size; divided by the head count to get the head size.
        num_layer: number of past state tensors to create.
        vocab_size: upper bound used when sampling ``input_ids``.
        device: device on which all tensors are created.
        float16: use float16 instead of float32 for the mask and past tensors.
        has_position_ids: create ``position_ids``; otherwise it is None.
        has_attention_mask: create ``attention_mask``; otherwise it is None.

    Returns:
        Gpt2Inputs holding input_ids, position_ids, attention_mask and a list of
        past state tensors.
    """
    float_type = torch.float16 if float16 else torch.float32
    past_shape = [2, batch_size, num_attention_heads, past_sequence_length, int(hidden_size / num_attention_heads)]

    # torch.rand samples from [0, 1); rescale to [-1, 1).
    past = [(torch.rand(past_shape, dtype=float_type, device=device) * 2.0 - 1.0) for _ in range(num_layer)]
    input_ids = torch.randint(low=0,
                              high=vocab_size - 1,
                              size=(batch_size, sequence_length),
                              dtype=torch.int64,
                              device=device)

    total_sequence_length = past_sequence_length + sequence_length

    attention_mask = None
    if has_attention_mask:
        attention_mask = torch.ones([batch_size, total_sequence_length], dtype=float_type, device=device)
        if total_sequence_length >= 2:
            padding_position = random.randint(0, total_sequence_length - 1)  # test input with padding.
            attention_mask[:, padding_position] = 0

    position_ids = None
    if has_position_ids:
        if attention_mask is not None:
            # Deduce position_ids from attention mask
            position_ids = (attention_mask.long().cumsum(-1) - 1)
            position_ids.masked_fill_(position_ids < 0, 0)
            # The column slice is a strided view; prepare_io_binding asserts contiguity,
            # so materialize it as a contiguous tensor.
            position_ids = position_ids[:, past_sequence_length:].contiguous()
        else:
            # Without a mask there is no padding: positions simply continue after the past.
            position_ids = torch.arange(past_sequence_length,
                                        total_sequence_length,
                                        dtype=torch.int64,
                                        device=device).unsqueeze(0).repeat(batch_size, 1)

    return Gpt2Inputs(input_ids, position_ids, attention_mask, past)
|
|
|
|
@staticmethod
def get_output_shapes(batch_size: int,
                      past_sequence_length: int,
                      sequence_length: int,
                      config: GPT2Config,
                      model_class: str = "GPT2LMHeadModel") -> Dict[str, List[int]]:
    """Return a dictionary mapping each output name to its shape.

    The first output is named by ``MODEL_CLASSES[model_class]``; its last
    dimension is the vocabulary size for "logits" and the hidden size
    otherwise. One ``present_{i}`` entry is added per hidden layer.
    """
    head_count = config.num_attention_heads
    hidden = config.hidden_size
    layer_count = config.num_hidden_layers
    vocab = config.vocab_size

    first_output = MODEL_CLASSES[model_class][1]
    last_dim = vocab if first_output == "logits" else hidden

    total_sequence_length = past_sequence_length + sequence_length
    present_shape = [2, batch_size, head_count, total_sequence_length, int(hidden / head_count)]

    shapes = {first_output: [batch_size, sequence_length, last_dim]}
    # Every present entry refers to the same shape list object.
    shapes.update({"present_" + str(layer): present_shape for layer in range(layer_count)})
    return shapes
|
|
|
|
@staticmethod
|
|
def auto_increase_buffer_size(output_buffers, output_shapes):
|
|
for key in output_shapes:
|
|
assert key in output_buffers
|
|
buffer = output_buffers[key]
|
|
if numpy.prod(output_shapes[key]) > buffer.nelement():
|
|
output_buffers[key] = torch.empty(numpy.prod(output_shapes[key]),
|
|
dtype=buffer.dtype,
|
|
device=buffer.device)
|
|
|
|
@staticmethod
|
|
def get_output_buffers(output_shapes, device, is_float16=False):
|
|
""" Returns a dictionary of output name as key, and 1D tensor as value. The tensor has enough space for given shape.
|
|
"""
|
|
data_type = torch.float16 if is_float16 else torch.float32
|
|
|
|
output_buffers = {}
|
|
for name, shape in output_shapes.items():
|
|
output_buffers[name] = torch.empty(numpy.prod(shape), dtype=data_type, device=device)
|
|
return output_buffers
|
|
|
|
@staticmethod
|
|
def diff_outputs(torch_outputs, ort_outputs, relative=False):
|
|
""" Returns the maximum difference between PyTorch and OnnxRuntime outputs.
|
|
"""
|
|
expected_outputs = torch_outputs[0].cpu().numpy()
|
|
diff = numpy.abs(expected_outputs - ort_outputs[0])
|
|
if relative:
|
|
return numpy.amax(diff / (numpy.abs(expected_outputs) + 1e-6))
|
|
else:
|
|
return numpy.amax(diff)
|
|
|
|
@staticmethod
def compare_outputs(torch_outputs, ort_outputs, rtol=1e-03, atol=1e-03):
    """Return True if torch and ORT outputs are close for the given thresholds, False otherwise.

    The first output is compared against ``torch_outputs[0]``; every further
    ORT output is compared against the matching entry of ``torch_outputs[1]``.
    """

    def _close(ort_value, torch_value):
        return numpy.allclose(ort_value, torch_value.cpu().numpy(), rtol=rtol, atol=atol)

    is_close = _close(ort_outputs[0], torch_outputs[0])
    logger.debug(f'PyTorch and OnnxRuntime output 0 (last_state) are close: {is_close}')
    is_all_close = is_close

    # Everything after the first ORT output is treated as a per-layer state.
    for layer in range(len(ort_outputs) - 1):
        is_close = _close(ort_outputs[1 + layer], torch_outputs[1][layer])
        logger.debug(f'PyTorch and OnnxRuntime layer {layer} state (present_{layer}) are close:{is_close}')
        is_all_close = is_all_close and is_close

    if is_all_close:
        return is_all_close

    max_abs_diff = Gpt2Helper.diff_outputs(torch_outputs, ort_outputs)
    logger.info(f'PyTorch and OnnxRuntime results are not all close: max_abs_diff={max_abs_diff:.5f}')
    return is_all_close
|
|
|
|
@staticmethod
|
|
def compare_outputs_v2(torch_outputs, ort_outputs, atol=1e-06):
|
|
"""Compare outputs from PyTorch and OnnxRuntime
|
|
|
|
Args:
|
|
torch_outputs (Tuple[Torch.Tensor]): PyTorch model output
|
|
ort_outputs (List[numpy.ndarray]): OnnxRuntime output
|
|
atol (float, optional): Absolute tollerance. Defaults to 1e-06.
|
|
|
|
Returns:
|
|
is_all_close(bool): whether all elements are close.
|
|
max_abs_diff(float): maximum absolute difference.
|
|
messages(str): a list of debug message for each output
|
|
"""
|
|
is_all_close = True
|
|
is_top1_matched = False
|
|
max_diffs = []
|
|
messages = []
|
|
for i in range(len(ort_outputs)):
|
|
ort_output = ort_outputs[i]
|
|
torch_output = (torch_outputs[0] if i == 0 else torch_outputs[1][i - 1]).cpu().numpy()
|
|
is_close = numpy.allclose(ort_output, torch_output, atol=atol, rtol=0)
|
|
max_diffs.append(numpy.amax(numpy.abs(torch_output - ort_output)))
|
|
is_all_close = is_all_close and is_close
|
|
|
|
if numpy.isnan(torch_output).any():
|
|
logger.debug(f'PyTorch output {i} has nan')
|
|
if numpy.isinf(torch_output).any():
|
|
logger.debug(f'PyTorch output {i} has inf')
|
|
if numpy.isnan(ort_output).any():
|
|
logger.debug(f'ORT output {i} has nan')
|
|
if numpy.isinf(ort_output).any():
|
|
logger.debug(f'ORT output {i} has inf')
|
|
|
|
diff = numpy.fabs(ort_output - torch_output)
|
|
idx = numpy.unravel_index(diff.argmax(), diff.shape)
|
|
messages.append(
|
|
f'diff={diff[idx]:.9f} index={idx} ort={ort_output[idx]:.9f} torch={float(torch_output[idx]):.9f}')
|
|
|
|
if i == 0: # logits
|
|
ort_max_index = numpy.unravel_index(numpy.argmax(ort_output, axis=None), ort_output.shape)
|
|
torch_max_index = numpy.unravel_index(numpy.argmax(torch_output, axis=None), torch_output.shape)
|
|
is_top1_matched = numpy.array_equal(ort_max_index, torch_max_index)
|
|
|
|
max_diff_output_index = max_diffs.index(max(max_diffs))
|
|
return is_all_close, max(max_diffs), max_diff_output_index, messages, is_top1_matched
|
|
|
|
@staticmethod
def export_onnx(model,
                device,
                onnx_model_path: str,
                verbose: bool = False,
                use_external_data_format: bool = False,
                has_position_ids: bool = True,
                has_attention_mask: bool = True):
    """Export a GPT-2 model with past state to an ONNX model.

    The model is run once on dummy inputs (batch size 1, one past token, one
    new token) to obtain example outputs, then exported with opset 11 and
    dynamic batch/sequence axes. The parent directory of ``onnx_model_path``
    is created when missing.
    """
    config: GPT2Config = model.config
    num_layer = config.n_layer
    dummy_inputs = Gpt2Helper.get_dummy_inputs(batch_size=1,
                                               past_sequence_length=1,
                                               sequence_length=1,
                                               num_attention_heads=config.num_attention_heads,
                                               hidden_size=config.hidden_size,
                                               num_layer=num_layer,
                                               vocab_size=config.vocab_size,
                                               device=device,
                                               float16=False,
                                               has_position_ids=has_position_ids,
                                               has_attention_mask=has_attention_mask)
    input_list = dummy_inputs.to_list()

    with torch.no_grad():
        outputs = model(*input_list)

    past_names = [f'past_{i}' for i in range(num_layer)]
    present_names = [f'present_{i}' for i in range(num_layer)]

    # GPT2Model outputs last_state; GPT2LMHeadModel outputs logits (prediction_scores).
    # The two are told apart by the size of the last dimension of the first output.
    last_dim = outputs[0].shape[2]
    assert last_dim == config.vocab_size or last_dim == config.hidden_size
    first_output_name = "logits" if last_dim == config.vocab_size else "last_state"
    output_names = [first_output_name] + present_names

    # Shape of input tensors:
    #    input_ids: (batch_size, seq_len)
    #    past_{i}:  (2, batch_size, num_heads, past_seq_len, hidden_size/num_heads)
    #    attention_mask: (batch_size, past_seq_len + seq_len)
    # Shape of output tensors:
    #    last_state: (batch_size, seq_len, hidden_size)
    #      or logits: (batch_size, seq_len, vocab_size)
    #    present_{i}:  (2, batch_size, num_heads, past_seq_len + seq_len, hidden_size/num_heads)
    dynamic_axes = {
        'input_ids': {0: 'batch_size', 1: 'seq_len'},
        first_output_name: {0: 'batch_size', 1: 'seq_len'},
    }
    dynamic_axes.update({name: {1: 'batch_size', 3: 'past_seq_len'} for name in past_names})
    dynamic_axes.update({name: {1: 'batch_size', 3: 'total_seq_len'} for name in present_names})

    # Input names follow the order produced by Gpt2Inputs.to_list().
    input_names = ['input_ids']
    if has_position_ids:
        dynamic_axes['position_ids'] = {0: 'batch_size', 1: 'seq_len'}
        input_names.append('position_ids')
    if has_attention_mask:
        dynamic_axes['attention_mask'] = {0: 'batch_size', 1: 'total_seq_len'}
        input_names.append('attention_mask')
    input_names += past_names

    assert len(outputs) == 2 and len(outputs[1]) == num_layer

    logger.info(
        f"Shapes: input_ids={dummy_inputs.input_ids.shape} past={dummy_inputs.past[0].shape} output={outputs[0].shape} present={outputs[1][0].shape}"
    )

    Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)

    export_options = {
        "input_names": input_names,
        "output_names": output_names,
        "example_outputs": outputs,
        "dynamic_axes": dynamic_axes,
        "opset_version": 11,
        "do_constant_folding": True,
        "use_external_data_format": use_external_data_format,
        "verbose": verbose,
    }
    torch.onnx.export(model, args=tuple(input_list), f=onnx_model_path, **export_options)
|
|
|
|
@staticmethod
def optimize_onnx(onnx_model_path,
                  optimized_model_path,
                  is_float16,
                  num_attention_heads,
                  hidden_size,
                  use_external_data_format=False,
                  **kwargs):
    """Optimize an ONNX model, with an option to convert it to use mixed precision.

    The model at ``onnx_model_path`` is optimized as a 'gpt2' model and saved
    to ``optimized_model_path``. When ``is_float16`` is true it is first
    converted to float16, with ``kwargs`` forwarded to the conversion.
    """
    from optimizer import optimize_model

    from fusion_options import FusionOptions
    # Individual fusions can be switched off on this object when debugging, e.g.
    # enable_gelu, enable_layer_norm or enable_attention set to False.
    fusion_options = FusionOptions('gpt2')
    m = optimize_model(onnx_model_path,
                       model_type='gpt2',
                       num_heads=num_attention_heads,
                       hidden_size=hidden_size,
                       opt_level=0,
                       optimization_options=fusion_options,
                       use_gpu=False)

    if is_float16:
        # Log which operator types are blocked and which remain.
        op_full_list = {node.op_type for node in m.nodes()}
        op_block_list = set(kwargs.get("op_block_list", ()))
        op_remain_list = op_full_list - op_block_list
        logger.info(f"op_block_list={op_block_list} op_remain_list={op_remain_list}")
        m.convert_float_to_float16(use_symbolic_shape_infer=True, **kwargs)

    m.save_model_to_file(optimized_model_path, use_external_data_format)
|
|
|
|
@staticmethod
def pytorch_inference(model, inputs: Gpt2Inputs, total_runs: int = 0):
    """Run inference of the PyTorch model.

    Returns the outputs alone when ``total_runs`` is 0; otherwise returns the
    outputs together with the average latency in ms over ``total_runs`` runs.
    """
    logger.debug("start pytorch_inference")

    # Convert it to fp32 as the PyTorch model cannot deal with half input.
    input_list = inputs.to_fp32().to_list()

    with torch.no_grad():
        # This first run is not timed.
        outputs = model(*input_list)
        if total_runs == 0:
            return outputs

        latency = []
        for _ in range(total_runs):
            started_at = time.time()
            outputs = model(*input_list)
            elapsed = time.time() - started_at
            latency.append(elapsed)

    average_latency = sum(latency) * 1000 / len(latency)
    logger.debug("PyTorch inference time = {} ms".format(format(average_latency, '.2f')))

    return outputs, average_latency
|
|
|
|
@staticmethod
def onnxruntime_inference(ort_session, inputs: Gpt2Inputs, total_runs: int = 0):
    """Run inference of the ONNX model.

    Returns the outputs alone when ``total_runs`` is 0; otherwise returns the
    outputs together with the average latency in ms over ``total_runs`` runs.
    """
    logger.debug("start onnxruntime_inference")

    def _as_numpy(tensor):
        # Copy to CPU and make sure the array is contiguous.
        return numpy.ascontiguousarray(tensor.cpu().numpy())

    ort_inputs = {'input_ids': _as_numpy(inputs.input_ids)}

    if inputs.past is not None:
        ort_inputs.update({f'past_{i}': _as_numpy(past_i) for i, past_i in enumerate(inputs.past)})

    if inputs.attention_mask is not None:
        ort_inputs['attention_mask'] = _as_numpy(inputs.attention_mask)

    if inputs.position_ids is not None:
        ort_inputs['position_ids'] = _as_numpy(inputs.position_ids)

    # This first run is not timed.
    ort_outputs = ort_session.run(None, ort_inputs)
    if total_runs == 0:
        return ort_outputs

    latency = []
    for _ in range(total_runs):
        started_at = time.time()
        ort_outputs = ort_session.run(None, ort_inputs)
        elapsed = time.time() - started_at
        latency.append(elapsed)

    average_latency = sum(latency) * 1000 / len(latency)
    logger.debug("OnnxRuntime Inference time = {} ms".format(format(average_latency, '.2f')))

    return ort_outputs, average_latency
|
|
|
|
@staticmethod
def prepare_io_binding(ort_session, input_ids, position_ids, attention_mask, past, output_buffers, output_shapes):
    """Return an IO binding object for a session.

    Inputs are bound by raw data pointer, so every bound tensor must be
    contiguous. ``position_ids``, ``attention_mask`` and ``past`` are skipped
    when None. Each session output is bound to its buffer in
    ``output_buffers`` with the shape from ``output_shapes``.
    """
    io_binding = ort_session.io_binding()

    # The float element type follows the dtype of the first output's buffer.
    first_output_name = ort_session.get_outputs()[0].name
    data_type = output_buffers[first_output_name].dtype
    float_type = numpy.float16 if data_type == torch.float16 else numpy.float32

    def _bind_input(name, tensor, element_type, data_ptr):
        io_binding.bind_input(name, tensor.device.type, 0, element_type, list(tensor.size()), data_ptr)

    # Bind inputs
    assert input_ids.is_contiguous()
    _bind_input('input_ids', input_ids, numpy.longlong, input_ids.data_ptr())

    if past is not None:
        for i, past_i in enumerate(past):
            assert past_i.is_contiguous()

            # When past_sequence_length is 0, its data_ptr will be zero. IO Binding asserts that data_ptr shall not be zero.
            # Work around it by passing the data pointer of input_ids. Actual data is not used for past so it does not matter.
            data_ptr = past_i.data_ptr() or input_ids.data_ptr()
            _bind_input(f'past_{i}', past_i, float_type, data_ptr)

    if attention_mask is not None:
        assert attention_mask.is_contiguous()
        _bind_input('attention_mask', attention_mask, float_type, attention_mask.data_ptr())

    if position_ids is not None:
        assert position_ids.is_contiguous()
        _bind_input('position_ids', position_ids, numpy.longlong, position_ids.data_ptr())

    # Bind outputs
    for output in ort_session.get_outputs():
        output_name = output.name
        output_buffer = output_buffers[output_name]
        logger.debug(f"{output_name} device type={output_buffer.device.type} shape={list(output_buffer.size())}")
        io_binding.bind_output(output_name, output_buffer.device.type, 0, float_type, output_shapes[output_name],
                               output_buffer.data_ptr())

    return io_binding
|
|
|
|
@staticmethod
|
|
def get_outputs_from_io_binding_buffer(ort_session, output_buffers, output_shapes, return_numpy=True):
|
|
""" Copy results to cpu. Returns a list of numpy array.
|
|
"""
|
|
ort_outputs = []
|
|
for output in ort_session.get_outputs():
|
|
output_name = output.name
|
|
buffer = output_buffers[output_name]
|
|
shape = output_shapes[output_name]
|
|
copy_tensor = buffer[0:numpy.prod(shape)].reshape(shape).clone().detach()
|
|
if return_numpy:
|
|
ort_outputs.append(copy_tensor.cpu().numpy())
|
|
else:
|
|
ort_outputs.append(copy_tensor)
|
|
return ort_outputs
|
|
|
|
@staticmethod
def onnxruntime_inference_with_binded_io(ort_session,
                                         inputs: Gpt2Inputs,
                                         output_buffers: Dict[str, torch.Tensor],
                                         output_shapes: Dict[str, List[int]],
                                         total_runs: int = 0,
                                         return_numpy: bool = True,
                                         include_copy_output_latency: bool = False):
    """Run inference with IO binding.

    Returns the outputs alone when ``total_runs`` is 0; otherwise returns the
    outputs together with the average latency in ms. The returned outputs are
    the copies taken after the first, untimed run. When
    ``include_copy_output_latency`` is set, each timed run also includes
    copying the outputs out of the buffers.
    """
    logger.debug("start onnxruntime_inference_with_binded_io")

    def _copy_outputs():
        return Gpt2Helper.get_outputs_from_io_binding_buffer(ort_session, output_buffers, output_shapes,
                                                             return_numpy)

    # Bind inputs and outputs to onnxruntime session
    io_binding = Gpt2Helper.prepare_io_binding(ort_session, inputs.input_ids, inputs.position_ids,
                                               inputs.attention_mask, inputs.past, output_buffers, output_shapes)

    # Run onnxruntime with io binding, then copy results for verification
    ort_session.run_with_iobinding(io_binding)
    ort_outputs = _copy_outputs()

    if total_runs == 0:
        return ort_outputs

    latency = []
    for _ in range(total_runs):
        started_at = time.time()
        ort_session.run_with_iobinding(io_binding)
        if include_copy_output_latency:
            _ = _copy_outputs()
        latency.append(time.time() - started_at)

    average_latency = sum(latency) * 1000 / len(latency)
    logger.debug("OnnxRuntime with IO binding inference time = {} ms".format(format(average_latency, '.2f')))

    return ort_outputs, average_latency
|
|
|
|
@staticmethod
def save_outputs(i, ort_outputs, torch_outputs):
    """Pickle the ORT and PyTorch outputs of test case ``i`` into the current directory.

    Writes ``ort_outputs_{i}.pickle`` and ``torch_outputs_{i}.pickle``.
    """
    ort_path = f'ort_outputs_{i}.pickle'
    with open(ort_path, 'wb') as ort_file:
        pickle.dump(ort_outputs, ort_file)
    logger.info(f"ORT output are saved to {ort_path}")

    torch_path = f'torch_outputs_{i}.pickle'
    with open(torch_path, 'wb') as torch_file:
        pickle.dump(torch_outputs, torch_file)
    logger.info(f"Torch output are saved to {torch_path}")
|
|
|
|
@staticmethod
def save_inputs(i, dummy_inputs, ort_outputs=None, torch_outputs=None):
    """Pickle the inputs of test case ``i`` to ``dummy_inputs_{i}.pickle`` in the current directory.

    Args:
        i: test case index used in the file name.
        dummy_inputs: object to pickle.
        ort_outputs: unused; accepted only for backward compatibility.
        torch_outputs: unused; accepted only for backward compatibility.
    """
    # The two output arguments were never read by this function; they are optional so
    # that the function can be called with just the index and the inputs.
    with open(f'dummy_inputs_{i}.pickle', 'wb') as f:
        pickle.dump(dummy_inputs, f)
    logger.info(f"inputs are saved to dummy_inputs_{i}.pickle")
|
|
|
|
@staticmethod
def test_parity(ort_session,
                model,
                device,
                is_float16=False,
                rtol=5e-4,
                atol=5e-4,
                test_cases_per_run=10000,
                total_runs=1,
                use_io_binding=True,
                model_class="GPT2LMHeadModel",
                has_position_ids=True,
                has_attention_mask=True,
                verbose=False,
                enable_pickle_output=False):
    """Generate random inputs and compare the results of PyTorch and Onnx Runtime.

    Args:
        ort_session: OnnxRuntime session of the model under test.
        model: PyTorch model that provides the reference outputs.
        device: device on which dummy inputs and output buffers are created.
        is_float16: create float16 inputs and output buffers.
        rtol: not used by this function.
        atol: absolute tolerance passed to ``compare_outputs_v2``.
        test_cases_per_run: number of random test cases per run.
        total_runs: number of runs; top-1 match rate is also reported per run.
        use_io_binding: run OnnxRuntime through IO binding when true.
        model_class: key of ``MODEL_CLASSES`` used to compute output shapes.
        has_position_ids: whether the model takes position_ids.
        has_attention_mask: whether the model takes attention_mask.
        verbose: log details of every test case that is not all close.
        enable_pickle_output: pickle inputs and outputs of cases whose maximum
            difference is NaN or above 100 * atol.

    Returns:
        Dictionary with max-diff percentiles (50/90/95/99), "top1_match_rate",
        "top1_match_rate_per_run", "diff_pass_rate" and "nan_rate".
    """

    config: GPT2Config = model.config

    logger.info(
        f"Running parity test (atol={atol}, test_cases={test_cases_per_run}, runs={total_runs}, use_io_binding={use_io_binding}, model_class={model_class}, is_float16={is_float16}) ..."
    )

    max_batch_size = 8
    max_past_seq_len = 4  # Do not use large number here for higher chance of hitting empty past (past_seq_len=0)
    max_seq_len = 2

    output_buffers = None
    if use_io_binding:
        # Buffers are allocated once, large enough for the biggest random test case.
        max_output_shapes = Gpt2Helper.get_output_shapes(max_batch_size, max_past_seq_len, max_seq_len, config,
                                                         model_class)
        output_buffers = Gpt2Helper.get_output_buffers(max_output_shapes, device, is_float16)

    passed_test_cases = 0
    top1_matched_cases = 0

    max_abs_diff_list = []
    top1_matched_cases_per_run = [0] * total_runs
    total_test_cases = test_cases_per_run * total_runs
    for i in range(total_test_cases):
        run_id = i // test_cases_per_run
        sequence_length = random.randint(1, max_seq_len)
        past_sequence_length = random.randint(0, max_past_seq_len)
        batch_size = random.randint(1, max_batch_size)

        logger.debug(
            f"Running parity test for batch_size={batch_size} past_sequence_length={past_sequence_length}...")
        dummy_inputs = Gpt2Helper.get_dummy_inputs(batch_size, past_sequence_length, sequence_length,
                                                   config.num_attention_heads, config.hidden_size, config.n_layer,
                                                   config.vocab_size, device, is_float16, has_position_ids,
                                                   has_attention_mask)
        outputs = Gpt2Helper.pytorch_inference(model, dummy_inputs)
        if use_io_binding:
            # IO binding passes raw data pointers and asserts that inputs are contiguous;
            # position_ids can be a strided slice, so make it contiguous first.
            if dummy_inputs.position_ids is not None:
                dummy_inputs.position_ids = dummy_inputs.position_ids.contiguous()
            output_shapes = Gpt2Helper.get_output_shapes(batch_size, past_sequence_length, sequence_length, config,
                                                         model_class)
            ort_outputs = Gpt2Helper.onnxruntime_inference_with_binded_io(ort_session, dummy_inputs, output_buffers,
                                                                          output_shapes)
        else:
            ort_outputs = Gpt2Helper.onnxruntime_inference(ort_session, dummy_inputs)

        is_all_close, max_abs_diff, max_diff_output_index, messages, is_top1_matched = Gpt2Helper.compare_outputs_v2(
            outputs, ort_outputs, atol=atol)
        if not numpy.isnan(max_abs_diff):
            max_abs_diff_list.append(max_abs_diff)
        if is_all_close:
            passed_test_cases += 1
        if is_top1_matched:
            top1_matched_cases += 1
            top1_matched_cases_per_run[run_id] += 1

        if verbose and not is_all_close:
            logger.info(
                f"test_case={i} batch_size={batch_size} past_sequence_length={past_sequence_length} sequence_length={sequence_length} MaxDiff={max_abs_diff}"
            )
            # A separate loop variable keeps the test case index `i` intact for the pickle file names below.
            for output_index, message in enumerate(messages):
                logger.info(f"\t{output_index}: Name={ort_session.get_outputs()[output_index].name}, {message}")

        # Collect data for debugging
        if enable_pickle_output and (numpy.isnan(max_abs_diff) or max_abs_diff > 100 * atol):
            Gpt2Helper.save_inputs(i, dummy_inputs, ort_outputs, outputs)
            Gpt2Helper.save_outputs(i, ort_outputs, outputs)

    if max_abs_diff_list:
        result = {
            f"max_diff_percentile_{p}": "{:.5f}".format(numpy.percentile(max_abs_diff_list, p))
            for p in [50, 90, 95, 99]
        }
    else:
        result = {f"max_diff_percentile_{p}": "nan" for p in [50, 90, 95, 99]}

    result["top1_match_rate"] = top1_matched_cases * 1.0 / total_test_cases
    result["top1_match_rate_per_run"] = [x * 1.0 / test_cases_per_run for x in top1_matched_cases_per_run]
    result["diff_pass_rate"] = passed_test_cases * 1.0 / total_test_cases
    # Cases whose maximum difference was NaN were left out of max_abs_diff_list.
    result["nan_rate"] = (total_test_cases - len(max_abs_diff_list)) * 1.0 / total_test_cases

    logger.info(
        f"Parity Test Cases={total_test_cases}; Passed={passed_test_cases}; Nan={total_test_cases-len(max_abs_diff_list)}; Top1_Matched={top1_matched_cases}"
    )

    if passed_test_cases > 0.95 * total_test_cases:
        logger.info(f"Parity is good: passed rate={int(passed_test_cases*100/total_test_cases):.0f}%")

    return result
|
|
|
|
@staticmethod
def test_performance(ort_session,
                     model,
                     device,
                     is_float16=False,
                     total_runs=100,
                     use_io_binding=True,
                     model_class="GPT2LMHeadModel",
                     has_position_ids=True,
                     has_attention_mask=True,
                     batch_size=8,
                     sequence_length=1,
                     past_sequence_length=32):
    """Generate random inputs and measure the average latency of Onnx Runtime.

    Args:
        ort_session: OnnxRuntime session of the model under test.
        model: PyTorch model; only its ``config`` is read here.
        device: device on which dummy inputs and output buffers are created.
        is_float16: create float16 inputs and output buffers.
        total_runs: number of timed runs.
        use_io_binding: run OnnxRuntime through IO binding when true.
        model_class: key of ``MODEL_CLASSES`` used to compute output shapes.
        has_position_ids: whether the model takes position_ids.
        has_attention_mask: whether the model takes attention_mask.
        batch_size, sequence_length, past_sequence_length: dimensions of the dummy inputs.

    Returns:
        Average latency in ms as reported by the inference helper.
    """

    config: GPT2Config = model.config

    dummy_inputs = Gpt2Helper.get_dummy_inputs(batch_size, past_sequence_length, sequence_length,
                                               config.num_attention_heads, config.hidden_size, config.n_layer,
                                               config.vocab_size, device, is_float16, has_position_ids,
                                               has_attention_mask)

    if not use_io_binding:
        _, latency = Gpt2Helper.onnxruntime_inference(ort_session, dummy_inputs, total_runs)
        return latency

    # IO binding passes raw data pointers and asserts that inputs are contiguous;
    # position_ids can be a strided slice, so make it contiguous first.
    if dummy_inputs.position_ids is not None:
        dummy_inputs.position_ids = dummy_inputs.position_ids.contiguous()

    output_shapes = Gpt2Helper.get_output_shapes(batch_size, past_sequence_length, sequence_length, config,
                                                 model_class)
    output_buffers = Gpt2Helper.get_output_buffers(output_shapes, device, is_float16)
    _, latency = Gpt2Helper.onnxruntime_inference_with_binded_io(ort_session, dummy_inputs, output_buffers,
                                                                 output_shapes, total_runs)

    return latency
|
|
|
|
@staticmethod
def torchscript(model, config, device, has_position_ids=True, has_attention_mask=True):
    """JIT trace the model for TorchScript.

    The trace uses dummy float32 inputs with batch size 1, one past token and
    one new token.
    """
    dummy_inputs = Gpt2Helper.get_dummy_inputs(batch_size=1,
                                               past_sequence_length=1,
                                               sequence_length=1,
                                               num_attention_heads=config.num_attention_heads,
                                               hidden_size=config.hidden_size,
                                               num_layer=config.n_layer,
                                               vocab_size=config.vocab_size,
                                               device=device,
                                               float16=False,
                                               has_position_ids=has_position_ids,
                                               has_attention_mask=has_attention_mask)
    example_inputs = dummy_inputs.to_list()
    return torch.jit.trace(model, example_inputs)
|
|
|
|
@staticmethod
|
|
def get_onnx_paths(output_dir,
|
|
model_name_or_path,
|
|
model_class: str = 'GPT2LMHeadModel',
|
|
has_past=True,
|
|
new_folder=False,
|
|
remove_existing=["raw", "fp32", "fp16", "int8"]):
|
|
""" Build a path name for given model based on given attributes.
|
|
"""
|
|
model_name = model_name_or_path
|
|
if not re.match(r'^[\w_-]+$', model_name_or_path): # It is not a name, shall be a path
|
|
assert os.path.isdir(model_name_or_path)
|
|
model_name = Path(model_name_or_path).parts[-1]
|
|
|
|
if model_class != 'GPT2LMHeadModel':
|
|
model_name += "_" + model_class
|
|
|
|
if has_past:
|
|
model_name += "_past"
|
|
|
|
if new_folder:
|
|
suffix = {"raw": "", "fp32": "_fp32", "fp16": "_fp16", "int8": "_int8"}
|
|
# Remove the directories if existed.
|
|
for model_type in ["raw", "fp32", "fp16", "int8"]:
|
|
new_dir = os.path.join(output_dir, model_name + suffix[model_type])
|
|
if os.path.exists(new_dir):
|
|
if (model_type in remove_existing):
|
|
try:
|
|
shutil.rmtree(new_dir)
|
|
logger.info(f"Removed the existed directory: {new_dir}")
|
|
except OSError as e:
|
|
logger.info(f"Failed to remove the directory {new_dir}: {e.strerror}")
|
|
else:
|
|
logger.info(f"Directory for {model_type} existed: {new_dir}")
|
|
|
|
# store each model to its own directory (for external data format).
|
|
return {
|
|
"raw": os.path.join(os.path.join(output_dir, model_name), model_name + ".onnx"),
|
|
"fp32": os.path.join(os.path.join(output_dir, model_name + "_fp32"), model_name + "_fp32.onnx"),
|
|
"fp16": os.path.join(os.path.join(output_dir, model_name + "_fp16"), model_name + "_fp16.onnx"),
|
|
"int8": os.path.join(os.path.join(output_dir, model_name + "_int8"), model_name + "_int8.onnx")
|
|
}
|
|
|
|
return {
|
|
"raw": os.path.join(output_dir, model_name + ".onnx"),
|
|
"fp32": os.path.join(output_dir, model_name + "_fp32.onnx"),
|
|
"fp16": os.path.join(output_dir, model_name + "_fp16.onnx"),
|
|
"int8": os.path.join(output_dir, model_name + "_int8.onnx")
|
|
}
|