From fa9090f2597b1bba6db8fe5c85737f50f9dd0ef8 Mon Sep 17 00:00:00 2001
From: Tianlei Wu
Date: Tue, 1 Mar 2022 19:04:34 -0800
Subject: [PATCH] check gpt-2 graph in converting beam search (#10712)

---
Review notes (ignored when the patch is applied):
 * Line breaks and indentation were reconstructed from a whitespace-collapsed
   copy of this patch. Confirm context lines against the target file.
 * verify_gpt2_subgraph: use TensorProto.FLOAT (the ONNX enum has no FLOAT32
   member), report "Output" in the output data type error, add a docstring.

 .../tools/transformers/convert_beam_search.py | 104 +++++++++++--------
 1 file changed, 60 insertions(+), 44 deletions(-)

diff --git a/onnxruntime/python/tools/transformers/convert_beam_search.py b/onnxruntime/python/tools/transformers/convert_beam_search.py
index c292d31908..738f6f33dd 100644
--- a/onnxruntime/python/tools/transformers/convert_beam_search.py
+++ b/onnxruntime/python/tools/transformers/convert_beam_search.py
@@ -15,6 +15,7 @@ from transformers import GPT2Config
 from gpt2_helper import PRETRAINED_GPT2_MODELS
 from convert_to_onnx import main as convert_gpt2_to_onnx
 from benchmark_helper import Precision
+from onnx import onnx_pb as onnx_proto
 
 """
 This converts GPT2 model to onnx with beam search operator.
@@ -144,34 +145,6 @@ def parse_arguments(argv=None):
                                    help="This vocab mask applies only to first iteration, enable if last word in query might need auto complete")
     beam_search_group.set_defaults(prefix_vocab_mask=False)
 
-    mixed_precision_option_group = parser.add_argument_group(
-        "mixed precision conversion parameters that works when \"--precision fp16\" is specified")
-
-    mixed_precision_option_group.add_argument('--io_block_list',
-                                              nargs='+',
-                                              required=False,
-                                              default=[],
-                                              help='List of inputs or outputs in float32')
-
-    mixed_precision_option_group.add_argument(
-        '--op_block_list',
-        nargs='+',
-        required=False,
-        default=[],
-        help='List of operators (like Add LayerNormalization FastGelu) to compute in float32.')
-
-    mixed_precision_option_group.add_argument('--node_block_list',
-                                              nargs='+',
-                                              required=False,
-                                              default=[],
-                                              help='List of node names to compute in float32.')
-
-    mixed_precision_option_group.add_argument('--force_fp16_initializers',
-                                              required=False,
-                                              action='store_true',
-                                              help='Convert all float initializers to float16.')
-    mixed_precision_option_group.set_defaults(force_fp16_initializers=False)
-
     args = parser.parse_args(argv)
 
     return args
@@ -200,21 +173,12 @@ def gpt2_to_onnx(args):
     if args.use_external_data_format:
         arguments.append('--use_external_data_format')
 
-    # mixed precision conversion options
     if args.precision == Precision.FLOAT16:
         assert args.use_gpu, "fp16 or mixed precision model cannot run in CPU. Please add --use_gpu"
-        if args.io_block_list:
-            arguments.append('--io_block_list')
-            arguments.extend(args.io_block_list)
-        if args.op_block_list:
-            arguments.append('--op_block_list')
-            arguments.extend(args.op_block_list)
-        if args.node_block_list:
-            arguments.append('--node_block_list')
-            arguments.extend(args.node_block_list)
-        if args.force_fp16_initializers:
-            arguments.append('--force_fp16_initializers')
-
+        # TODO: Use auto mixed precision for fp16 conversion: arguments.append('--auto_mixed_precision')
+        #       Need change cuda kernel to support a combination of fp32 logits and fp16 past state.
+        #       Currently logits and past state shall be same data type.
+        arguments.extend(['--op_block_list', 'Add', 'LayerNormalization', 'FastGelu'])
 
     convert_gpt2_to_onnx(arguments)
 
@@ -244,8 +208,60 @@ def create_ort_session(model_path, use_gpu):
     return ort_session
 
 
-def verify_gpt2_subgraph(graph):
-    #TODO: verify names, data types and shapes of inputs and outputs.
+def verify_gpt2_subgraph(graph, precision):
+    """Verify names and data types of inputs and outputs of the GPT-2 graph.
+
+    Expected inputs are input_ids, position_ids and attention_mask (all int32), followed by
+    past_0 ... past_{N-1}, where N is the number of graph inputs minus 3. Expected outputs are
+    logits followed by present_0 ... present_{N-1}. Past inputs and all outputs shall be float16
+    when precision is Precision.FLOAT16, and float (32 bits) otherwise.
+
+    Args:
+        graph: graph of the GPT-2 onnx model (the caller passes model.graph from onnx.load).
+        precision (Precision): precision of the model. Precision.FLOAT16 means float16.
+
+    Raises:
+        ValueError: number, name or data type of an input or output is not expected.
+    """
+    is_float16 = (Precision.FLOAT16 == precision)
+
+    input_count = len(graph.input)
+    layer_count = input_count - 3
+
+    expected_inputs = ['input_ids', 'position_ids', 'attention_mask'] + [f"past_{i}" for i in range(layer_count)]
+    if len(graph.input) != len(expected_inputs):
+        raise ValueError(f"Number of inputs expected to be {len(expected_inputs)}. Got {len(graph.input)}")
+
+    for i, expected_input in enumerate(expected_inputs):
+        if graph.input[i].name != expected_input:
+            raise ValueError(f"Input {i} is expected to be {expected_input}. Got {graph.input[i].name}")
+
+        expected_type = onnx_proto.TensorProto.INT32
+        if i >= 3:
+            expected_type = onnx_proto.TensorProto.FLOAT16 if is_float16 else onnx_proto.TensorProto.FLOAT
+
+        if graph.input[i].type.tensor_type.elem_type != expected_type:
+            raise ValueError(
+                f"Input {i} is expected to have onnx data type {expected_type}. Got {graph.input[i].type.tensor_type.elem_type}"
+            )
+    print("Verifying GPT-2 graph inputs: name and data type are good.")
+
+    expected_outputs = ['logits'] + [f"present_{i}" for i in range(layer_count)]
+    if len(graph.output) != len(expected_outputs):
+        raise ValueError(f"Number of outputs expected to be {len(expected_outputs)}. Got {len(graph.output)}")
+
+    for i, expected_output in enumerate(expected_outputs):
+        if graph.output[i].name != expected_output:
+            raise ValueError(f"Output {i} is expected to be {expected_output}. Got {graph.output[i].name}")
+
+        expected_type = onnx_proto.TensorProto.FLOAT16 if is_float16 else onnx_proto.TensorProto.FLOAT
+        if graph.output[i].type.tensor_type.elem_type != expected_type:
+            raise ValueError(
+                f"Output {i} is expected to have onnx data type {expected_type}. Got {graph.output[i].type.tensor_type.elem_type}"
+            )
+    print("Verifying GPT-2 graph outputs: name and data type are good.")
+
+    # TODO: verify shapes of inputs and outputs.
     return
 
 
@@ -271,7 +287,7 @@ def convert_model(args):
         vocab_size = args.vocab_size
 
     model = onnx.load(args.gpt2_onnx)
-    verify_gpt2_subgraph(model.graph)
+    verify_gpt2_subgraph(model.graph, args.precision)
     model.graph.name = "gpt2 subgraph"
     inputs = [