Update BeamSearch operator spec to support t5 (#10777)

* change BeamSearch op to support encoder decoder model

* check model_type and decoder attribute

* fix

* update comments

* warn shape inference issue with onnx v1.11 or T5

* skip parity test when temperature != 1.0

* fix build
This commit is contained in:
Tianlei Wu 2022-03-04 21:52:45 -08:00 committed by GitHub
parent 6be5185088
commit 0e335aba37
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 150 additions and 57 deletions

View file

@ -350,12 +350,16 @@ This version of the operator has been available since version 1 of the 'com.micr
#### Attributes
<dl>
<dt><tt>body</tt> : graph (required)</dt>
<dd>The GPT-2 subgraph with input_ids, position_ids, attention_mask, past_0, past_1, ... as inputs, and logits, present_0, present_1, ... as output</dd>
<dt><tt>decoder</tt> : graph (required)</dt>
<dd>Decoder subgraph to execute in a loop.</dd>
<dt><tt>early_stopping</tt> : int</dt>
<dd>early stop or not</dd>
<dt><tt>encoder_decoder_init</tt> : graph</dt>
<dd>subgraph for initialization of encoder and decoder. It will be called once before decoder subgraph.</dd>
<dt><tt>eos_token_id</tt> : int (required)</dt>
<dd>The id of the end-of-sequence token</dd>
<dt><tt>model_type</tt> : int</dt>
<dd>model type: 0 for GPT-2; 1 for encoder decoder like T5</dd>
<dt><tt>no_repeat_ngram_size</tt> : int</dt>
<dd>no repeat ngrams size</dd>
<dt><tt>pad_token_id</tt> : int (required)</dt>

View file

@ -256,35 +256,41 @@ class BeamSearchImpl {
};
void BeamSearch::Init(const OpKernelInfo& info) {
// Make sure the body attribute was present even though we don't need it here.
// Make sure the decoder attribute was present even though we don't need it here.
ONNX_NAMESPACE::GraphProto proto;
ORT_ENFORCE(info.GetAttr<ONNX_NAMESPACE::GraphProto>("body", &proto).IsOK());
ORT_ENFORCE(info.GetAttr<ONNX_NAMESPACE::GraphProto>("decoder", &proto).IsOK());
ORT_IGNORE_RETURN_VALUE(proto);
parameters_.ParseFromAttributes(info);
cuda_stream_ = nullptr;
}
Status BeamSearch::SetupSubgraphExecutionInfo(const SessionState& session_state,
const std::string& attribute_name,
const SessionState& subgraph_session_state) {
ORT_ENFORCE(gpt_subgraph_ == nullptr, "SetupSubgraphExecutionInfo should only be called once for each subgraph.");
const auto& node = Node();
gpt_subgraph_ = std::make_unique<GptSubgraph>(node, attribute_name, subgraph_session_state.GetGraphViewer());
ORT_RETURN_IF_ERROR(gpt_subgraph_->Setup(session_state, subgraph_session_state));
feeds_fetches_manager_ = gpt_subgraph_->GetFeedsFetchesManager();
parameters_.SetSubgraphParameters(gpt_subgraph_->vocab_size,
gpt_subgraph_->num_heads,
gpt_subgraph_->head_size,
gpt_subgraph_->num_layers);
// TODO: handle another subgraph with attribute name "encoder_decoder_init"
if (attribute_name == "decoder") {
const auto& node = Node();
gpt_subgraph_ = std::make_unique<GptSubgraph>(node, attribute_name, subgraph_session_state.GetGraphViewer());
ORT_RETURN_IF_ERROR(gpt_subgraph_->Setup(session_state, subgraph_session_state));
feeds_fetches_manager_ = gpt_subgraph_->GetFeedsFetchesManager();
parameters_.SetSubgraphParameters(gpt_subgraph_->vocab_size,
gpt_subgraph_->num_heads,
gpt_subgraph_->head_size,
gpt_subgraph_->num_layers);
}
return Status::OK();
}
Status BeamSearch::Compute(OpKernelContext* ctx) const {
if (parameters_.model_type != 0) {
// TODO: support encoder decoder model like T5
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Support of 'model_type' != 0 is not implemented");
}
auto* ctx_internal = static_cast<OpKernelContextInternal*>(ctx);
auto* session_state = ctx_internal->SubgraphSessionState("body");
ORT_ENFORCE(session_state, "Subgraph SessionState was not found for 'body' attribute.");
auto* session_state = ctx_internal->SubgraphSessionState("decoder");
ORT_ENFORCE(session_state, "Subgraph SessionState was not found for 'decoder' attribute.");
ORT_ENFORCE(feeds_fetches_manager_, "CreateFeedsFetchesManager must be called prior to execution of graph.");
concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool();

View file

@ -19,7 +19,8 @@ using namespace onnxruntime::controlflow; // namespace of IControlFlowKernel
class BeamSearch : public IControlFlowKernel {
public:
BeamSearch(const OpKernelInfo& info) : IControlFlowKernel(info), cuda_stream_(nullptr), dumper_(nullptr) {
BeamSearch(const OpKernelInfo& info)
: IControlFlowKernel(info), feeds_fetches_manager_(nullptr), cuda_stream_(nullptr), dumper_(nullptr) {
Init(info);
}
@ -87,7 +88,7 @@ class BeamSearch : public IControlFlowKernel {
void* cuda_stream_;
IConsoleDumper* dumper_;
BeamSearchParameters parameters_;
};

View file

@ -17,6 +17,7 @@ Status BeamSearchParameters::Validate() const {
}
void BeamSearchParameters::ParseFromAttributes(const OpKernelInfo& info) {
model_type = static_cast<int>(info.GetAttrOrDefault<int64_t>("model_type", 0));
early_stopping = info.GetAttrOrDefault<int64_t>("early_stopping", 0) == 1;
eos_token_id = static_cast<int>(info.GetAttrOrDefault<int64_t>("eos_token_id", -1));
pad_token_id = static_cast<int>(info.GetAttrOrDefault<int64_t>("pad_token_id", -1));

View file

@ -73,6 +73,7 @@ class IBeamScorer {
struct IBeamSearchParameters {
// Parameters from node attributes
int model_type;
int eos_token_id;
int pad_token_id;
int no_repeat_ngram_size;

View file

@ -457,9 +457,8 @@ Status UpdateFeeds(
for (size_t i = 1; i < last_outputs.size(); ++i) {
next_inputs[i + 2] = last_outputs[i];
}
return Status::OK();
} else {
return PickPastState<T>(last_outputs, next_inputs, beam_indices, allocator, stream);
ORT_RETURN_IF_ERROR(PickPastState<T>(last_outputs, next_inputs, beam_indices, allocator, stream));
}
// Make sure data is ready before next subgraph execution.

View file

@ -919,10 +919,9 @@ ONNX_MS_OPERATOR_SET_SCHEMA(BeamSearch, 1,
.Attr("pad_token_id", "The id of the padding token", AttributeProto::INT)
.Attr("no_repeat_ngram_size", "no repeat ngrams size", AttributeProto::INT, static_cast<int64_t>(0))
.Attr("early_stopping", "early stop or not", AttributeProto::INT, static_cast<int64_t>(0))
.Attr(
"body",
"The GPT-2 subgraph with input_ids, position_ids, attention_mask, past_0, past_1, ... as inputs, and logits, present_0, present_1, ... as output",
AttributeProto::GRAPH)
.Attr("model_type", "model type: 0 for GPT-2; 1 for encoder decoder like T5", AttributeProto::INT, static_cast<int64_t>(0))
.Attr("encoder_decoder_init", "subgraph for initialization of encoder and decoder. It will be called once before decoder subgraph.", AttributeProto::GRAPH, OPTIONAL_VALUE)
.Attr("decoder", "Decoder subgraph to execute in a loop.", AttributeProto::GRAPH)
.Input(0, "input_ids", "The sequence used as a prompt for the generation. Shape is (batch_size, sequence_length)", "I")
.Input(1, "max_length", "The maximum length of the sequence to be generated. Shape is (1)", "I")
.Input(2, "min_length", "The minimum length below which the score of eos_token_id is set to -Inf. Shape is (1)", "I", OpSchema::Optional)
@ -951,8 +950,6 @@ ONNX_MS_OPERATOR_SET_SCHEMA(BeamSearch, 1,
BeamSearchShapeInference(ctx);
}));
ONNX_MS_OPERATOR_SET_SCHEMA(SampleOp, 1,
OpSchema()
.Input(0, "X", "input", "T")

View file

@ -1,5 +1,18 @@
#-------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
#-------------------------------------------------------------------------
"""
This converts GPT2 or T5 model to onnx with beam search operator.
Example 1: convert gpt2 model with beam search:
python convert_beam_search.py -m gpt2 --decoder_onnx .\onnx_models\gpt2_past_fp32.onnx --output .\onnx_models\gpt2_beam_search.onnx --output_sequences_scores
Example 2: convert T5 model with beam search:
python ./models/t5/convert_to_onnx.py -m t5-small -s
python convert_beam_search.py -m t5-small --model_type t5 --decoder_onnx ./onnx_models/t5-small_decoder.onnx --encoder_decoder_init_onnx ./onnx_models/t5-small_encoder_decoder_init.onnx --output ./onnx_models/t5_small_beam_search.onnx
"""
import os
import time
@ -9,21 +22,17 @@ import argparse
from pathlib import Path
from onnx import helper
import numpy as np
from typing import List
from typing import List, Union
import torch
from transformers import GPT2Config
from packaging import version
from transformers import GPT2Config, T5Config
from gpt2_helper import PRETRAINED_GPT2_MODELS
from convert_to_onnx import main as convert_gpt2_to_onnx
from benchmark_helper import Precision
from onnx import onnx_pb as onnx_proto
"""
This converts GPT2 model to onnx with beam search operator.
Examples:
python convert_beam_search.py -m gpt2 --gpt2_onnx .\onnx_models\gpt2_past_fp32.onnx --output .\onnx_models\gpt2_beam_search.onnx --output_sequences_scores
"""
config: GPT2Config = None
config: Union[GPT2Config, T5Config] = None
logger = logging.getLogger('')
@ -37,16 +46,29 @@ def parse_arguments(argv=None):
type=str,
help='Model path, or pretrained model name in the list: ' + ', '.join(PRETRAINED_GPT2_MODELS))
parser.add_argument('--model_type',
required=False,
type=str,
default="gpt2",
choices=["gpt2", "t5"],
help='Model type in the list: ' + ', '.join(["gpt2", "t5"]))
parser.add_argument('--cache_dir',
required=False,
type=str,
default=os.path.join('.', 'cache_models'),
help='Directory to cache pre-trained models')
parser.add_argument('--gpt2_onnx',
parser.add_argument('--decoder_onnx',
required=True,
type=str,
help='Output directory for GPT-2 onnx model, or model path ends with .onnx')
help='Output directory for decoder onnx model, or model path ends with .onnx')
parser.add_argument('--encoder_decoder_init_onnx',
required=False,
type=str,
default="",
help='path of ONNX model for encoder and decoder initialization. Required for t5 model type.')
parser.add_argument('--output',
required=False,
@ -153,12 +175,12 @@ def parse_arguments(argv=None):
def gpt2_to_onnx(args):
model_name = args.model_name_or_path
print(f"use convert_to_onnx.py to convert model {model_name} to onnx {args.gpt2_onnx} ...")
print(f"use convert_to_onnx.py to convert model {model_name} to onnx {args.decoder_onnx} ...")
arguments = [
'--model_name_or_path',
model_name,
'--output',
args.gpt2_onnx,
args.decoder_onnx,
'--optimize_onnx',
'--precision',
'fp32' if args.precision == Precision.FLOAT32 else 'fp16',
@ -182,13 +204,16 @@ def gpt2_to_onnx(args):
convert_gpt2_to_onnx(arguments)
def shape_inference(gpt2_onnx_path):
def shape_inference(decoder_onnx_path):
if version.parse(onnx.__version__) >= version.parse('1.11.0'):
logger.warn("SymbolicShapeInference might fail using onnx version 1.11. Please install 1.10.0 for now.")
# Run symbolic shape inference to work around ORT shape inference issue for subgraph.
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference
out = SymbolicShapeInference.infer_shapes(onnx.load(gpt2_onnx_path), auto_merge=True, guess_output_rank=False)
out = SymbolicShapeInference.infer_shapes(onnx.load(decoder_onnx_path), auto_merge=True, guess_output_rank=False)
if out:
# TODO: Use external format if input has extra data.
onnx.save(out, gpt2_onnx_path)
onnx.save(out, decoder_onnx_path)
else:
print("Failed to run symbolic shape inference on the model.")
@ -224,7 +249,7 @@ def verify_gpt2_subgraph(graph, precision):
expected_type = onnx_proto.TensorProto.INT32
if i >= 3:
expected_type = onnx_proto.TensorProto.FLOAT16 if is_float16 else onnx_proto.TensorProto.FLOAT32
expected_type = onnx_proto.TensorProto.FLOAT16 if is_float16 else onnx_proto.TensorProto.FLOAT
if graph.input[i].type.tensor_type.elem_type != expected_type:
raise ValueError(
@ -240,7 +265,7 @@ def verify_gpt2_subgraph(graph, precision):
if graph.output[i].name != expected_output:
raise ValueError(f"Output {i} is expected to be {expected_output}. Got {graph.output[i].name}")
expected_type = onnx_proto.TensorProto.FLOAT16 if is_float16 else onnx_proto.TensorProto.FLOAT32
expected_type = onnx_proto.TensorProto.FLOAT16 if is_float16 else onnx_proto.TensorProto.FLOAT
if graph.output[i].type.tensor_type.elem_type != expected_type:
raise ValueError(
f"Input {i} is expected to have onnx data type {expected_type}. Got {graph.output[i].type.tensor_type.elem_type}"
@ -251,17 +276,35 @@ def verify_gpt2_subgraph(graph, precision):
return
def verify_t5_decoder_subgraph(graph, precision):
# TODO: implement it
pass
def verify_t5_encoder_decoder_init_subgraph(graph, precision):
# TODO: implement it
pass
def convert_model(args):
if os.path.exists(args.gpt2_onnx):
print(f"skip convert_to_onnx since path existed: {args.gpt2_onnx}")
if os.path.exists(args.decoder_onnx):
print(f"skip convert_to_onnx since path existed: {args.decoder_onnx}")
else:
assert args.model_type == "gpt2", "please have onnx model ready for model type that is not gpt2"
gpt2_to_onnx(args)
print(f"Run symbolic shape inference on {args.gpt2_onnx}. The file will be overwritten.")
shape_inference(args.gpt2_onnx)
# TODO: fix shape inference for T5. Currently symbolic shape inference on T5 is broken.
enable_shape_inference = args.model_type == "gpt2"
if enable_shape_inference:
print(f"Run symbolic shape inference on {args.decoder_onnx}. The file will be overwritten.")
shape_inference(args.decoder_onnx)
global config
config = GPT2Config.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
if args.model_type == "gpt2":
config = GPT2Config.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
else:
config = T5Config.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
print(config)
eos_token_id = config.eos_token_id
@ -272,10 +315,14 @@ def convert_model(args):
if args.vocab_size != -1:
vocab_size = args.vocab_size
model = onnx.load(args.gpt2_onnx)
verify_gpt2_subgraph(model.graph, args.precision)
model = onnx.load(args.decoder_onnx)
model.graph.name = f"{args.model_type} decoder subgraph"
if args.model_type == "gpt2":
verify_gpt2_subgraph(model.graph, args.precision)
else:
verify_t5_decoder_subgraph(model.graph, args.precision)
model.graph.name = "gpt2 subgraph"
inputs = [
"input_ids", "max_length", "min_length", "num_beams", "num_return_sequences", "temperature", "length_penalty",
"repetition_penalty", "vocab_mask"
@ -291,16 +338,28 @@ def convert_model(args):
assert args.output_sequences_scores, "--output_token_scores requires --output_sequences_scores"
outputs.append("scores")
node = helper.make_node('BeamSearch', inputs=inputs, outputs=outputs, name='BeamSearch_GPT2')
node = helper.make_node('BeamSearch', inputs=inputs, outputs=outputs, name=f'BeamSearch_{args.model_type}')
node.domain = "com.microsoft"
node.attribute.extend([
helper.make_attribute("eos_token_id", eos_token_id),
helper.make_attribute("pad_token_id", pad_token_id),
helper.make_attribute("no_repeat_ngram_size", args.no_repeat_ngram_size),
helper.make_attribute("early_stopping", 1 if args.early_stopping else 0),
helper.make_attribute("body", model.graph),
helper.make_attribute("model_type", 0 if args.model_type == "gpt2" else 1),
helper.make_attribute("decoder", model.graph),
])
if args.model_type == "t5":
if enable_shape_inference:
print(f"Run symbolic shape inference on {args.encoder_decoder_init_onnx}. The file will be overwritten.")
shape_inference(args.encoder_decoder_init_onnx)
init_model = onnx.load(args.encoder_decoder_init_onnx)
init_model.graph.name = f"{args.model_type} encoder decoder init subgraph"
verify_t5_encoder_decoder_init_subgraph(init_model.graph, args.precision)
node.attribute.extend([
helper.make_attribute("encoder_decoder_init", init_model.graph),
])
from onnx import TensorProto
# graph inputs
@ -344,7 +403,7 @@ def convert_model(args):
if args.output_token_scores:
graph_outputs.append(scores)
new_graph = helper.make_graph([node], 'gpt2-beam-search', graph_inputs, graph_outputs, initializers)
new_graph = helper.make_graph([node], f'{args.model_type}-beam-search', graph_inputs, graph_outputs, initializers)
# Create the model
new_model = helper.make_model(new_graph, producer_name='onnxruntime.transformers', opset_imports=model.opset_import)
@ -392,10 +451,20 @@ def test_torch_performance(args, model, input_ids, attention_mask, eos_token_id,
def test_model(args, use_vocab_mask: bool = False, sentences: List[str] = None):
if args.model_type != "gpt2":
print(
f"Skipping parity test since the support for model type {args.model_type} is not implemented in OnnxRuntime"
)
return True
if args.temperature != 1.0:
# TODO: implement temperature in BeamSearch operator.
print("Skipping parity test as temperature is not implemented in BeamSearch operator")
return True
if args.prefix_vocab_mask:
print("Skipping parity test as prefix vocab mask is not implemented by Hugging Face")
return
return True
from transformers import GPT2Tokenizer, GPT2LMHeadModel
@ -547,6 +616,8 @@ def test_model(args, use_vocab_mask: bool = False, sentences: List[str] = None):
def main(argv=None, sentences=None):
args = parse_arguments(argv)
if args.model_type == "t5":
assert args.encoder_decoder_init_onnx, "please export t5 to onnx models before using this tool"
if os.path.exists(args.output):
print(f"skip conversion since path existed: {args.output}")

View file

@ -0,0 +1,8 @@
#-------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
#-------------------------------------------------------------------------
import os
import sys
sys.path.append(os.path.dirname(__file__))

View file

@ -1,3 +1,7 @@
#-------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
#-------------------------------------------------------------------------
import os
import sys

View file

@ -18,12 +18,13 @@ else:
class TestBeamSearch(unittest.TestCase):
def setUp(self):
#TODO: use a smaller model and enable tests in CI pipeline
self.model_name = "gpt2"
self.gpt2_onnx_path = os.path.join('.', 'onnx_models', 'gpt2_past_fp32_shape.onnx')
self.beam_search_onnx_path = os.path.join('.', 'onnx_models', 'gpt2_beam_search.onnx')
self.cpu_params = f'-m {self.model_name} --gpt2_onnx {self.gpt2_onnx_path} --output {self.beam_search_onnx_path} --output_sequences_score --repetition_penalty 2.0'
self.cpu_params = f'-m {self.model_name} --decoder_onnx {self.gpt2_onnx_path} --output {self.beam_search_onnx_path} --output_sequences_score --repetition_penalty 2.0'
def run_beam_search(self, arguments: str, sentences=None):
return run(arguments.split(), sentences=sentences)