Fix benchmark bugs and add Pytorch version control (#10928)

This commit is contained in:
Ye Wang 2022-03-18 09:24:19 -07:00 committed by GitHub
parent 6f844522c8
commit ee05c591e5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 105 additions and 24 deletions

View file

@ -206,7 +206,7 @@ def run_pytorch(use_gpu, model_names, model_class, config_modifier, precision, n
for model_name in model_names:
config = AutoConfig.from_pretrained(model_name, torchscript=torchscript, cache_dir=cache_dir)
config_modifier(config)
config_modifier.modify(config)
model = load_pretrained_model(model_name, config=config, cache_dir=cache_dir, custom_model_class=model_class)
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
@ -248,6 +248,7 @@ def run_pytorch(use_gpu, model_names, model_class, config_modifier, precision, n
result = {
"engine": "torchscript" if torchscript else "torch",
"version": torch.__version__,
"providers": "NA",
"device": "cuda" if use_gpu else "cpu",
"optimizer": "",
"precision": precision,
@ -323,7 +324,7 @@ def run_tensorflow(use_gpu, model_names, model_class, config_modifier, precision
for model_name in model_names:
config = AutoConfig.from_pretrained(model_name, cache_dir=cache_dir)
config_modifier(config)
config_modifier.modify(config)
model = load_pretrained_model(model_name,
config=config,
@ -381,6 +382,7 @@ def run_tensorflow(use_gpu, model_names, model_class, config_modifier, precision
result = {
"engine": "tensorflow",
"version": tf.__version__,
"providers": "NA",
"device": "cuda" if use_gpu else "cpu",
"optimizer": "",
"precision": precision,

View file

@ -17,6 +17,7 @@ from typing import List, Dict, Tuple, Union
from transformers import GPT2LMHeadModel, GPT2Config
from benchmark_helper import Precision
from gpt2_helper import Gpt2Helper, Gpt2Inputs, GPT2ModelNoPastState, MyGPT2Model, MyGPT2LMHeadModel, MyGPT2LMHeadModel_NoPadding
from torch_onnx_export_helper import torch_onnx_export
logger = logging.getLogger(__name__)
@ -36,7 +37,7 @@ class Gpt2HelperFactory:
class GPT2LMHeadModel_BeamSearchStep(GPT2LMHeadModel):
"""Here we wrap a class for Onnx model conversion for GPT2LMHeadModel with past state and one
"""Here we wrap a class for Onnx model conversion for GPT2LMHeadModel with past state and one
step beam search."""
def __init__(self, config, batch_size, beam_size):
super().__init__(config)
@ -120,7 +121,7 @@ class GPT2LMHeadModel_BeamSearchStep(GPT2LMHeadModel):
class GPT2LMHeadModel_ConfigurableOneStepSearch(GPT2LMHeadModel):
"""Here we wrap a class for Onnx model conversion for GPT2LMHeadModel with past state and one
"""Here we wrap a class for Onnx model conversion for GPT2LMHeadModel with past state and one
step beam search with configuration support."""
def __init__(self,
config,
@ -628,7 +629,7 @@ class Gpt2BeamSearchHelper(Gpt2Helper):
Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
torch.onnx.export(
torch_onnx_export(
model,
args=tuple(input_list),
f=onnx_model_path,

View file

@ -21,6 +21,7 @@ from onnx_model import OnnxModel
from fusion_utils import FusionUtils
from benchmark_helper import Precision
from io_binding_helper import IOBindingHelper
from torch_onnx_export_helper import torch_onnx_export
logger = logging.getLogger(__name__)
@ -402,7 +403,7 @@ class Gpt2Helper:
Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
torch.onnx.export(model,
torch_onnx_export(model,
args=tuple(input_list),
f=onnx_model_path,
input_names=input_names,

View file

@ -15,6 +15,8 @@
#
# For inference of the onnx model, you will need onnxruntime-gpu 1.7.0 or above.
import sys
import os
import torch
import numpy as np
import argparse
@ -25,6 +27,9 @@ from packaging import version
from pathlib import Path
from longformer_helper import LongformerHelper, PRETRAINED_LONGFORMER_MODELS
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from torch_onnx_export_helper import torch_onnx_export
@parse_args('v', 'v', 'v', 'v', 'v', 'v', 'v', 'i', 'i')
def my_longformer_attention(g, input, weight, bias, mask, global_weight, global_bias, global_mask, num_heads, window):
@ -223,7 +228,7 @@ def export_longformer(model, onnx_model_path, export_padding):
Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
torch.onnx.export(model,
torch_onnx_export(model,
example_inputs,
onnx_model_path,
opset_version=11,

View file

@ -11,7 +11,6 @@ logger = logging.getLogger(__name__)
class PastKeyValuesHelper:
""" Helper functions to process past key values for encoder-decoder model"""
@staticmethod
def get_past_names(num_layers, present: bool = False):
past_self_names = []

View file

@ -6,6 +6,8 @@
from pathlib import Path
from typing import List, Union
import sys
import os
import logging
import numpy
import torch
@ -14,6 +16,9 @@ from onnxruntime import InferenceSession
from t5_encoder import T5EncoderInputs
from past_helper import PastKeyValuesHelper
sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))
from torch_onnx_export_helper import torch_onnx_export
logger = logging.getLogger(__name__)
@ -21,7 +26,6 @@ class T5DecoderInit(torch.nn.Module):
""" A T5 decoder with LM head to create initial past key values.
This model is only called once during starting decoding.
"""
def __init__(self,
decoder: torch.nn.Module,
lm_head: torch.nn.Module,
@ -58,7 +62,6 @@ class T5DecoderInit(torch.nn.Module):
class T5Decoder(torch.nn.Module):
""" A T5 decoder with LM head and past key values"""
def __init__(self, decoder, lm_head, config):
super().__init__()
self.decoder = decoder
@ -89,7 +92,6 @@ class T5Decoder(torch.nn.Module):
class T5DecoderInputs:
def __init__(self, decoder_input_ids, encoder_attention_mask, encoder_hidden_states, past_key_values=None):
self.decoder_input_ids: torch.LongTensor = decoder_input_ids
self.encoder_attention_mask: torch.LongTensor = encoder_attention_mask
@ -160,7 +162,6 @@ class T5DecoderInputs:
class T5DecoderHelper:
@staticmethod
def export_onnx(decoder: Union[T5Decoder, T5DecoderInit],
device: torch.device,
@ -250,7 +251,7 @@ class T5DecoderHelper:
}
Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
torch.onnx.export(decoder,
torch_onnx_export(decoder,
args=tuple(input_list),
f=onnx_model_path,
export_params=True,

View file

@ -5,6 +5,8 @@
# --------------------------------------------------------------------------
import random
import sys
import os
from pathlib import Path
from typing import List
import logging
@ -13,12 +15,14 @@ import torch
from transformers import T5Config
from onnxruntime import InferenceSession
sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))
from torch_onnx_export_helper import torch_onnx_export
logger = logging.getLogger(__name__)
class T5Encoder(torch.nn.Module):
""" T5 encoder outputs only the last hidden state"""
def __init__(self, encoder, config: T5Config):
super().__init__()
self.encoder = encoder
@ -29,7 +33,6 @@ class T5Encoder(torch.nn.Module):
class T5EncoderInputs:
def __init__(self, input_ids, attention_mask):
self.input_ids: torch.LongTensor = input_ids
self.attention_mask: torch.LongTensor = attention_mask
@ -44,7 +47,7 @@ class T5EncoderInputs:
sequence_length (int): sequence length
vocab_size (int): vocaburary size
device (torch.device): device of output tensors
Returns:
T5EncoderInputs: dummy inputs for encoder
"""
@ -67,7 +70,6 @@ class T5EncoderInputs:
class T5EncoderHelper:
@staticmethod
def export_onnx(encoder: T5Encoder,
device: torch.device,
@ -93,7 +95,7 @@ class T5EncoderHelper:
outputs = encoder(encoder_inputs.input_ids, encoder_inputs.attention_mask)
Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
torch.onnx.export(encoder,
torch_onnx_export(encoder,
args=tuple(encoder_inputs.to_list()),
f=onnx_model_path,
export_params=True,

View file

@ -6,6 +6,8 @@
from pathlib import Path
from typing import List
import sys
import os
import logging
import numpy
import torch
@ -15,13 +17,15 @@ from t5_encoder import T5Encoder, T5EncoderInputs
from t5_decoder import T5DecoderInit
from past_helper import PastKeyValuesHelper
sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))
from torch_onnx_export_helper import torch_onnx_export
logger = logging.getLogger(__name__)
class T5EncoderDecoderInit(torch.nn.Module):
""" A combination of T5Encoder and T5DecoderInit.
"""
def __init__(self,
encoder: torch.nn.Module,
decoder: torch.nn.Module,
@ -44,7 +48,6 @@ class T5EncoderDecoderInit(torch.nn.Module):
class T5EncoderDecoderInitInputs:
def __init__(self, encoder_input_ids, encoder_attention_mask, decoder_input_ids=None):
self.encoder_input_ids: torch.LongTensor = encoder_input_ids
self.encoder_attention_mask: torch.LongTensor = encoder_attention_mask
@ -70,7 +73,6 @@ class T5EncoderDecoderInitInputs:
class T5EncoderDecoderInitHelper:
@staticmethod
def export_onnx(model: T5EncoderDecoderInit,
device: torch.device,
@ -153,7 +155,7 @@ class T5EncoderDecoderInitHelper:
dynamic_axes[name] = {0: 'batch_size', 1: num_heads, 2: sequence_length, 3: head_size}
Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
torch.onnx.export(model,
torch_onnx_export(model,
args=tuple(input_list),
f=onnx_model_path,
export_params=True,

View file

@ -22,7 +22,6 @@ PRETRAINED_T5_MODELS = ["t5-small", "t5-base", "t5-large", "t5-3B", "t5-11B"]
class T5Helper:
@staticmethod
def get_onnx_path(output_dir: str, model_name_or_path: str, suffix: str = "", new_folder: bool = False) -> str:
"""Build onnx path

View file

@ -15,6 +15,7 @@ from benchmark_helper import create_onnxruntime_session, Precision, OptimizerInf
from gpt2_helper import GPT2ModelNoPastState, PRETRAINED_GPT2_MODELS, TFGPT2ModelNoPastState
from quantize_helper import QuantizeHelper
from huggingface_models import MODEL_CLASSES
from torch_onnx_export_helper import torch_onnx_export
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
@ -386,7 +387,7 @@ def export_onnx_model_from_pt(model_name, opset_version, use_external_data_forma
dynamic_axes, output_names = build_dynamic_axes(example_inputs, example_outputs_flatten)
replace_torch_functions()
torch.onnx.export(model=model,
torch_onnx_export(model=model,
args=tuple(example_inputs.values()),
f=onnx_model_path,
input_names=list(example_inputs.keys()),

View file

@ -0,0 +1,68 @@
#-------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
#--------------------------------------------------------------------------
import torch
TrainingMode = torch.onnx.TrainingMode
from packaging.version import Version
def torch_onnx_export(
model,
args,
f,
export_params=True,
verbose=False,
training=TrainingMode.EVAL,
input_names=None,
output_names=None,
operator_export_type=None,
opset_version=None,
_retain_param_name=None,
do_constant_folding=True,
example_outputs=None,
strip_doc_string=None,
dynamic_axes=None,
keep_initializers_as_inputs=None,
custom_opsets=None,
enable_onnx_checker=None,
use_external_data_format=None,
export_modules_as_functions=False):
if Version(torch.__version__) >= Version("1.11.0"):
torch.onnx.export(
model=model,
args=args,
f=f,
export_params=export_params,
verbose=verbose,
training=training,
input_names=input_names,
output_names=output_names,
operator_export_type=operator_export_type,
opset_version=opset_version,
do_constant_folding=do_constant_folding,
dynamic_axes=dynamic_axes,
keep_initializers_as_inputs=keep_initializers_as_inputs,
custom_opsets=custom_opsets,
export_modules_as_functions=export_modules_as_functions)
else:
torch.onnx.export(
model=model,
args=args,
f=f,
export_params=export_params,
verbose=verbose,
training=training,
input_names=input_names,
output_names=output_names,
operator_export_type=operator_export_type,
opset_version=opset_version,
_retain_param_name=_retain_param_name,
do_constant_folding=do_constant_folding,
example_outputs=example_outputs,
strip_doc_string=strip_doc_string,
dynamic_axes=dynamic_axes,
keep_initializers_as_inputs=keep_initializers_as_inputs,
custom_opsets=custom_opsets,
enable_onnx_checker=enable_onnx_checker,
use_external_data_format=use_external_data_format)