mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-02 03:55:34 +00:00
check gpt-2 graph in converting beam search (#10712)
This commit is contained in:
parent
d07a2377b1
commit
fa9090f259
1 changed files with 46 additions and 44 deletions
|
|
@ -15,6 +15,7 @@ from transformers import GPT2Config
|
|||
from gpt2_helper import PRETRAINED_GPT2_MODELS
|
||||
from convert_to_onnx import main as convert_gpt2_to_onnx
|
||||
from benchmark_helper import Precision
|
||||
from onnx import onnx_pb as onnx_proto
|
||||
"""
|
||||
This converts GPT2 model to onnx with beam search operator.
|
||||
|
||||
|
|
@ -144,34 +145,6 @@ def parse_arguments(argv=None):
|
|||
help="This vocab mask applies only to first iteration, enable if last word in query might need auto complete")
|
||||
beam_search_group.set_defaults(prefix_vocab_mask=False)
|
||||
|
||||
mixed_precision_option_group = parser.add_argument_group(
|
||||
"mixed precision conversion parameters that works when \"--precision fp16\" is specified")
|
||||
|
||||
mixed_precision_option_group.add_argument('--io_block_list',
|
||||
nargs='+',
|
||||
required=False,
|
||||
default=[],
|
||||
help='List of inputs or outputs in float32')
|
||||
|
||||
mixed_precision_option_group.add_argument(
|
||||
'--op_block_list',
|
||||
nargs='+',
|
||||
required=False,
|
||||
default=[],
|
||||
help='List of operators (like Add LayerNormalization FastGelu) to compute in float32.')
|
||||
|
||||
mixed_precision_option_group.add_argument('--node_block_list',
|
||||
nargs='+',
|
||||
required=False,
|
||||
default=[],
|
||||
help='List of node names to compute in float32.')
|
||||
|
||||
mixed_precision_option_group.add_argument('--force_fp16_initializers',
|
||||
required=False,
|
||||
action='store_true',
|
||||
help='Convert all float initializers to float16.')
|
||||
mixed_precision_option_group.set_defaults(force_fp16_initializers=False)
|
||||
|
||||
args = parser.parse_args(argv)
|
||||
|
||||
return args
|
||||
|
|
@ -200,21 +173,12 @@ def gpt2_to_onnx(args):
|
|||
if args.use_external_data_format:
|
||||
arguments.append('--use_external_data_format')
|
||||
|
||||
# mixed precision conversion options
|
||||
if args.precision == Precision.FLOAT16:
|
||||
assert args.use_gpu, "fp16 or mixed precision model cannot run in CPU. Please add --use_gpu"
|
||||
if args.io_block_list:
|
||||
arguments.append('--io_block_list')
|
||||
arguments.extend(args.io_block_list)
|
||||
if args.op_block_list:
|
||||
arguments.append('--op_block_list')
|
||||
arguments.extend(args.op_block_list)
|
||||
if args.node_block_list:
|
||||
arguments.append('--node_block_list')
|
||||
arguments.extend(args.node_block_list)
|
||||
if args.force_fp16_initializers:
|
||||
arguments.append('--force_fp16_initializers')
|
||||
|
||||
# TODO: Use auto mixed precision for fp16 conversion: arguments.append('--auto_mixed_precision')
|
||||
# Need change cuda kernel to support a combination of fp32 logits and fp16 past state.
|
||||
# Currently logits and past state shall be same data type.
|
||||
arguments.extend(['--op_block_list', 'Add', 'LayerNormalization', 'FastGelu'])
|
||||
convert_gpt2_to_onnx(arguments)
|
||||
|
||||
|
||||
|
|
@ -244,8 +208,46 @@ def create_ort_session(model_path, use_gpu):
|
|||
return ort_session
|
||||
|
||||
|
||||
def verify_gpt2_subgraph(graph):
|
||||
#TODO: verify names, data types and shapes of inputs and outputs.
|
||||
def verify_gpt2_subgraph(graph, precision):
|
||||
is_float16 = (Precision.FLOAT16 == precision)
|
||||
|
||||
input_count = len(graph.input)
|
||||
layer_count = input_count - 3
|
||||
|
||||
expected_inputs = ['input_ids', 'position_ids', 'attention_mask'] + [f"past_{i}" for i in range(layer_count)]
|
||||
if len(graph.input) != len(expected_inputs):
|
||||
raise ValueError(f"Number of inputs expected to be {len(expected_inputs)}. Got {len(graph.input)}")
|
||||
|
||||
for i, expected_input in enumerate(expected_inputs):
|
||||
if graph.input[i].name != expected_input:
|
||||
raise ValueError(f"Input {i} is expected to be {expected_input}. Got {graph.input[i].name}")
|
||||
|
||||
expected_type = onnx_proto.TensorProto.INT32
|
||||
if i >= 3:
|
||||
expected_type = onnx_proto.TensorProto.FLOAT16 if is_float16 else onnx_proto.TensorProto.FLOAT32
|
||||
|
||||
if graph.input[i].type.tensor_type.elem_type != expected_type:
|
||||
raise ValueError(
|
||||
f"Input {i} is expected to have onnx data type {expected_type}. Got {graph.input[i].type.tensor_type.elem_type}"
|
||||
)
|
||||
print("Verifying GPT-2 graph inputs: name and data type are good.")
|
||||
|
||||
expected_outputs = ['logits'] + [f"present_{i}" for i in range(layer_count)]
|
||||
if len(graph.output) != len(expected_outputs):
|
||||
raise ValueError(f"Number of outputs expected to be {len(expected_outputs)}. Got {len(graph.output)}")
|
||||
|
||||
for i, expected_output in enumerate(expected_outputs):
|
||||
if graph.output[i].name != expected_output:
|
||||
raise ValueError(f"Output {i} is expected to be {expected_output}. Got {graph.output[i].name}")
|
||||
|
||||
expected_type = onnx_proto.TensorProto.FLOAT16 if is_float16 else onnx_proto.TensorProto.FLOAT32
|
||||
if graph.output[i].type.tensor_type.elem_type != expected_type:
|
||||
raise ValueError(
|
||||
f"Input {i} is expected to have onnx data type {expected_type}. Got {graph.output[i].type.tensor_type.elem_type}"
|
||||
)
|
||||
print("Verifying GPT-2 graph outputs: name and data type are good.")
|
||||
|
||||
# TODO: verify shapes of inputs and outputs.
|
||||
return
|
||||
|
||||
|
||||
|
|
@ -271,7 +273,7 @@ def convert_model(args):
|
|||
vocab_size = args.vocab_size
|
||||
|
||||
model = onnx.load(args.gpt2_onnx)
|
||||
verify_gpt2_subgraph(model.graph)
|
||||
verify_gpt2_subgraph(model.graph, args.precision)
|
||||
|
||||
model.graph.name = "gpt2 subgraph"
|
||||
inputs = [
|
||||
|
|
|
|||
Loading…
Reference in a new issue