Fix benchmark_gpt2 model verification (#5343)

This commit is contained in:
Tianlei Wu 2020-10-02 13:53:02 -07:00 committed by GitHub
parent 6e4949e235
commit f5e4c0ea04
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 83 additions and 24 deletions

View file

@ -24,7 +24,7 @@ from benchmark_helper import create_onnxruntime_session, setup_logger, prepare_e
logger = logging.getLogger('')
def parse_arguments():
def parse_arguments(argv=None):
parser = argparse.ArgumentParser()
parser.add_argument('-m',
@ -101,15 +101,12 @@ def parse_arguments():
parser.add_argument('--verbose', required=False, action='store_true')
parser.set_defaults(verbose=False)
args = parser.parse_args()
args = parser.parse_args(argv)
return args
def main():
args = parse_arguments()
setup_logger(args.verbose)
def main(args):
logger.info(f"Arguments:{args}")
if args.precision == Precision.FLOAT16:
assert args.optimize_onnx and args.use_gpu, "fp16 requires --optimize_onnx --use_gpu"
@ -229,10 +226,13 @@ def main():
f'Pytorch and ONNX Runtime outputs are all close (tolerance={DEFAULT_TOLERANCE[args.precision]}).'
)
for i in ort_io_outputs:
ort_io_outputs[i] = ort_io_outputs[i].cpu().numpy()
# Results of IO binding might be in GPU. Copy outputs to CPU for comparison.
copy_outputs = []
for output in ort_io_outputs:
copy_outputs.append(output.cpu().numpy())
if Gpt2Helper.compare_outputs(outputs,
ort_io_outputs,
copy_outputs,
rtol=DEFAULT_TOLERANCE[args.precision],
atol=DEFAULT_TOLERANCE[args.precision]):
logger.info(
@ -261,7 +261,10 @@ def main():
logger.error(f"Exception", exc_info=True)
logger.info(f"Results are saved to file {csv_filename}")
return csv_filename
if __name__ == '__main__':
main()
args = parse_arguments()
setup_logger(args.verbose)
main(args)

View file

@ -8,6 +8,7 @@ import logging
import numpy
import os
import torch
from pathlib import Path
from transformers import AutoConfig, AutoTokenizer, AutoModel
from benchmark_helper import create_onnxruntime_session, Precision
from gpt2_helper import GPT2ModelNoPastState, PRETRAINED_GPT2_MODELS
@ -133,6 +134,7 @@ def get_onnx_file_path(onnx_dir: str, model_name: str, input_count: int, optimiz
filename += f"_ort"
directory = onnx_dir
# ONNXRuntime will not write external data so the raw and optimized models shall be in same directory.
if use_external_data and not optimized_by_onnxruntime:
directory = os.path.join(onnx_dir, filename)
@ -142,6 +144,18 @@ def get_onnx_file_path(onnx_dir: str, model_name: str, input_count: int, optimiz
return os.path.join(directory, f"{filename}.onnx")
def add_filename_suffix(file_path: str, suffix: str) -> str:
"""
Append a suffix at the filename (before the extension).
Args:
path: pathlib.Path The actual path object we would like to add a suffix
suffix: The suffix to add
Returns: path with suffix appended at the end of the filename and before extension
"""
path = Path(file_path)
return str(path.parent.joinpath(path.stem + suffix).with_suffix(path.suffix))
def optimize_onnx_model_by_ort(onnx_model_path, ort_model_path, use_gpu, overwrite, model_fusion_statistics):
if overwrite or not os.path.exists(ort_model_path):
from optimizer import optimize_by_onnxruntime, get_fusion_statistics
@ -286,16 +300,15 @@ def validate_and_optimize_onnx(model_name, use_external_data_format, model_type,
else: # Use OnnxRuntime to optimize
if is_valid_onnx_model:
ort_model_path = get_onnx_file_path(onnx_dir, model_name, len(input_names), False, use_gpu, precision, True,
use_external_data_format)
ort_model_path = add_filename_suffix(onnx_model_path, '_ort')
optimize_onnx_model_by_ort(onnx_model_path, ort_model_path, use_gpu, overwrite, model_fusion_statistics)
return onnx_model_path, is_valid_onnx_model, config.vocab_size
def export_onnx_model_from_pt(model_name, opset_version, use_external_data_format, model_type, model_class, cache_dir, onnx_dir,
input_names, use_gpu, precision, optimize_onnx, validate_onnx, use_raw_attention_mask, overwrite,
model_fusion_statistics):
def export_onnx_model_from_pt(model_name, opset_version, use_external_data_format, model_type, model_class, cache_dir,
onnx_dir, input_names, use_gpu, precision, optimize_onnx, validate_onnx,
use_raw_attention_mask, overwrite, model_fusion_statistics):
config, model = load_pt_model(model_name, model_class, cache_dir)
# config, model = load_pt_model_from_tf(model_name)
@ -341,15 +354,16 @@ def export_onnx_model_from_pt(model_name, opset_version, use_external_data_forma
logger.info(f"Skip export since model existed: {onnx_model_path}")
onnx_model_file, is_valid_onnx_model, vocab_size = validate_and_optimize_onnx(
model_name, use_external_data_format, model_type, onnx_dir, input_names, use_gpu, precision, optimize_onnx, validate_onnx,
use_raw_attention_mask, overwrite, config, model_fusion_statistics, onnx_model_path, example_inputs, example_outputs_flatten)
model_name, use_external_data_format, model_type, onnx_dir, input_names, use_gpu, precision, optimize_onnx,
validate_onnx, use_raw_attention_mask, overwrite, config, model_fusion_statistics, onnx_model_path,
example_inputs, example_outputs_flatten)
return onnx_model_file, is_valid_onnx_model, vocab_size, max_input_size
def export_onnx_model_from_tf(model_name, opset_version, use_external_data_format, model_type, model_class, cache_dir, onnx_dir,
input_names, use_gpu, precision, optimize_onnx, validate_onnx, use_raw_attention_mask, overwrite,
model_fusion_statistics):
def export_onnx_model_from_tf(model_name, opset_version, use_external_data_format, model_type, model_class, cache_dir,
onnx_dir, input_names, use_gpu, precision, optimize_onnx, validate_onnx,
use_raw_attention_mask, overwrite, model_fusion_statistics):
config, model = load_tf_model(model_name, model_class, cache_dir)
@ -359,7 +373,11 @@ def export_onnx_model_from_tf(model_name, opset_version, use_external_data_forma
max_input_size = tokenizer.max_model_input_sizes[
model_name] if model_name in tokenizer.max_model_input_sizes else 1024
example_inputs = tokenizer.encode_plus("This is a sample input", return_tensors="tf", max_length=max_input_size, pad_to_max_length=True, truncation=True)
example_inputs = tokenizer.encode_plus("This is a sample input",
return_tensors="tf",
max_length=max_input_size,
pad_to_max_length=True,
truncation=True)
example_inputs = filter_inputs(example_inputs, input_names)
@ -383,8 +401,8 @@ def export_onnx_model_from_tf(model_name, opset_version, use_external_data_forma
model_type = model_type + '_keras'
onnx_model_file, is_valid_onnx_model, vocab_size = validate_and_optimize_onnx(
model_name, use_external_data_format, model_type, onnx_dir, input_names, use_gpu, precision, optimize_onnx, validate_onnx,
use_raw_attention_mask, overwrite, config, model_fusion_statistics, onnx_model_path, example_inputs, example_outputs_flatten)
model_name, use_external_data_format, model_type, onnx_dir, input_names, use_gpu, precision, optimize_onnx,
validate_onnx, use_raw_attention_mask, overwrite, config, model_fusion_statistics, onnx_model_path,
example_inputs, example_outputs_flatten)
return onnx_model_file, is_valid_onnx_model, vocab_size, max_input_size

View file

@ -0,0 +1,38 @@
#!/usr/bin/env python
# coding: utf-8
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import unittest
import os
import onnxruntime
import logging
import coloredlogs
import pytest
class TestGpt2(unittest.TestCase):
def run_benchmark_gpt2(self, arguments: str):
from benchmark_gpt2 import parse_arguments, main
args = parse_arguments(arguments.split())
csv_filename = main(args)
self.assertTrue(os.path.exists(csv_filename))
def test_gpt2_fp32(self):
self.run_benchmark_gpt2('-m gpt2 --precision fp32 -v -b 1 -s 128')
def test_gpt2_fp16(self):
if 'CUDAExecutionProvider' in onnxruntime.get_available_providers():
self.run_benchmark_gpt2('-m gpt2 --precision fp16 -o -b 1 -s 128')
def test_gpt2_int8(self):
self.run_benchmark_gpt2('-m gpt2 --precision int8 -o -b 1 -s 128')
if __name__ == '__main__':
coloredlogs.install(fmt='%(message)s')
logging.getLogger("transformers").setLevel(logging.ERROR)
unittest.main()