diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index f517a0afee..187bb13015 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -350,12 +350,16 @@ This version of the operator has been available since version 1 of the 'com.micr
#### Attributes
-- body : graph (required)
-- The GPT-2 subgraph with input_ids, position_ids, attention_mask, past_0, past_1, ... as inputs, and logits, present_0, present_1, ... as output
+- decoder : graph (required)
+- Decoder subgraph to execute in a loop.
- early_stopping : int
- early stop or not
+- encoder_decoder_init : graph
+- subgraph for initialization of encoder and decoder. It will be called once before decoder subgraph.
- eos_token_id : int (required)
- The id of the end-of-sequence token
+- model_type : int
+- model type: 0 for GPT-2; 1 for encoder decoder like T5
- no_repeat_ngram_size : int
- no repeat ngrams size
- pad_token_id : int (required)
diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc
index 09d647bd55..d82f1a7094 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc
@@ -256,35 +256,41 @@ class BeamSearchImpl {
};
void BeamSearch::Init(const OpKernelInfo& info) {
- // Make sure the body attribute was present even though we don't need it here.
+ // Make sure the decoder attribute was present even though we don't need it here.
ONNX_NAMESPACE::GraphProto proto;
- ORT_ENFORCE(info.GetAttr("body", &proto).IsOK());
+ ORT_ENFORCE(info.GetAttr("decoder", &proto).IsOK());
ORT_IGNORE_RETURN_VALUE(proto);
parameters_.ParseFromAttributes(info);
-
- cuda_stream_ = nullptr;
}
Status BeamSearch::SetupSubgraphExecutionInfo(const SessionState& session_state,
const std::string& attribute_name,
const SessionState& subgraph_session_state) {
ORT_ENFORCE(gpt_subgraph_ == nullptr, "SetupSubgraphExecutionInfo should only be called once for each subgraph.");
- const auto& node = Node();
- gpt_subgraph_ = std::make_unique(node, attribute_name, subgraph_session_state.GetGraphViewer());
- ORT_RETURN_IF_ERROR(gpt_subgraph_->Setup(session_state, subgraph_session_state));
- feeds_fetches_manager_ = gpt_subgraph_->GetFeedsFetchesManager();
- parameters_.SetSubgraphParameters(gpt_subgraph_->vocab_size,
- gpt_subgraph_->num_heads,
- gpt_subgraph_->head_size,
- gpt_subgraph_->num_layers);
+ // TODO: handle another subgraph with attribute name "encoder_decode_init"
+ if (attribute_name == "decoder") {
+ const auto& node = Node();
+ gpt_subgraph_ = std::make_unique(node, attribute_name, subgraph_session_state.GetGraphViewer());
+ ORT_RETURN_IF_ERROR(gpt_subgraph_->Setup(session_state, subgraph_session_state));
+ feeds_fetches_manager_ = gpt_subgraph_->GetFeedsFetchesManager();
+ parameters_.SetSubgraphParameters(gpt_subgraph_->vocab_size,
+ gpt_subgraph_->num_heads,
+ gpt_subgraph_->head_size,
+ gpt_subgraph_->num_layers);
+ }
return Status::OK();
}
Status BeamSearch::Compute(OpKernelContext* ctx) const {
+ if (parameters_.model_type != 0) {
+ // TODO: support encoder decoder model like T5
+ return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Support of 'model_type' != 0 is not implemented");
+ }
+
auto* ctx_internal = static_cast(ctx);
- auto* session_state = ctx_internal->SubgraphSessionState("body");
- ORT_ENFORCE(session_state, "Subgraph SessionState was not found for 'body' attribute.");
+ auto* session_state = ctx_internal->SubgraphSessionState("decoder");
+ ORT_ENFORCE(session_state, "Subgraph SessionState was not found for 'decoder' attribute.");
ORT_ENFORCE(feeds_fetches_manager_, "CreateFeedsFetchesManager must be called prior to execution of graph.");
concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool();
diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search.h
index 9e2e64498e..217dc080cc 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search.h
@@ -19,7 +19,8 @@ using namespace onnxruntime::controlflow; // namespace of IControlFlowKernel
class BeamSearch : public IControlFlowKernel {
public:
- BeamSearch(const OpKernelInfo& info) : IControlFlowKernel(info), cuda_stream_(nullptr), dumper_(nullptr) {
+ BeamSearch(const OpKernelInfo& info)
+ : IControlFlowKernel(info), feeds_fetches_manager_(nullptr), cuda_stream_(nullptr), dumper_(nullptr) {
Init(info);
}
@@ -87,7 +88,7 @@ class BeamSearch : public IControlFlowKernel {
void* cuda_stream_;
IConsoleDumper* dumper_;
-
+
BeamSearchParameters parameters_;
};
diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc
index 236bfab68a..4fc9f2f383 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc
@@ -17,6 +17,7 @@ Status BeamSearchParameters::Validate() const {
}
void BeamSearchParameters::ParseFromAttributes(const OpKernelInfo& info) {
+ model_type = static_cast(info.GetAttrOrDefault("model_type", 0));
early_stopping = info.GetAttrOrDefault("early_stopping", 0) == 1;
eos_token_id = static_cast(info.GetAttrOrDefault("eos_token_id", -1));
pad_token_id = static_cast(info.GetAttrOrDefault("pad_token_id", -1));
diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_shared.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_shared.h
index 9ca095eadf..3cf4a48f55 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_shared.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_shared.h
@@ -73,6 +73,7 @@ class IBeamScorer {
struct IBeamSearchParameters {
// Parameters from node attributes
+ int model_type;
int eos_token_id;
int pad_token_id;
int no_repeat_ngram_size;
diff --git a/onnxruntime/contrib_ops/cuda/transformers/beam_search_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/beam_search_device_helper.cc
index 3b8497c32b..b8d871f320 100644
--- a/onnxruntime/contrib_ops/cuda/transformers/beam_search_device_helper.cc
+++ b/onnxruntime/contrib_ops/cuda/transformers/beam_search_device_helper.cc
@@ -457,9 +457,8 @@ Status UpdateFeeds(
for (size_t i = 1; i < last_outputs.size(); ++i) {
next_inputs[i + 2] = last_outputs[i];
}
- return Status::OK();
} else {
- return PickPastState(last_outputs, next_inputs, beam_indices, allocator, stream);
+ ORT_RETURN_IF_ERROR(PickPastState(last_outputs, next_inputs, beam_indices, allocator, stream));
}
// Make sure data is ready before next subgraph execution.
diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
index de208dc2b9..a3f2fa6cf5 100644
--- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -919,10 +919,9 @@ ONNX_MS_OPERATOR_SET_SCHEMA(BeamSearch, 1,
.Attr("pad_token_id", "The id of the padding token", AttributeProto::INT)
.Attr("no_repeat_ngram_size", "no repeat ngrams size", AttributeProto::INT, static_cast(0))
.Attr("early_stopping", "early stop or not", AttributeProto::INT, static_cast(0))
- .Attr(
- "body",
- "The GPT-2 subgraph with input_ids, position_ids, attention_mask, past_0, past_1, ... as inputs, and logits, present_0, present_1, ... as output",
- AttributeProto::GRAPH)
+ .Attr("model_type", "model type: 0 for GPT-2; 1 for encoder decoder like T5", AttributeProto::INT, static_cast(0))
+ .Attr("encoder_decoder_init", "subgraph for initialization of encoder and decoder. It will be called once before decoder subgraph.", AttributeProto::GRAPH, OPTIONAL_VALUE)
+ .Attr("decoder", "Decoder subgraph to execute in a loop.", AttributeProto::GRAPH)
.Input(0, "input_ids", "The sequence used as a prompt for the generation. Shape is (batch_size, sequence_length)", "I")
.Input(1, "max_length", "The maximum length of the sequence to be generated. Shape is (1)", "I")
.Input(2, "min_length", "The minimum length below which the score of eos_token_id is set to -Inf. Shape is (1)", "I", OpSchema::Optional)
@@ -951,8 +950,6 @@ ONNX_MS_OPERATOR_SET_SCHEMA(BeamSearch, 1,
BeamSearchShapeInference(ctx);
}));
-
-
ONNX_MS_OPERATOR_SET_SCHEMA(SampleOp, 1,
OpSchema()
.Input(0, "X", "input", "T")
diff --git a/onnxruntime/python/tools/transformers/convert_beam_search.py b/onnxruntime/python/tools/transformers/convert_beam_search.py
index 738f6f33dd..8efef8c15f 100644
--- a/onnxruntime/python/tools/transformers/convert_beam_search.py
+++ b/onnxruntime/python/tools/transformers/convert_beam_search.py
@@ -1,5 +1,18 @@
+#-------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
+#-------------------------------------------------------------------------
+
+"""
+This converts GPT2 or T5 model to onnx with beam search operator.
+
+Example 1: convert gpt2 model with beam search:
+ python convert_beam_search.py -m gpt2 --decoder_onnx .\onnx_models\gpt2_past_fp32.onnx --output .\onnx_models\gpt2_beam_search.onnx --output_sequences_scores
+
+Example 2: convert T5 model with beam search:
+ python ./models/t5/convert_to_onnx.py -m t5-small -s
+ python convert_beam_search.py -m t5-small --model_type t5 --decoder_onnx ./onnx_models/t5-small_decoder.onnx --encoder_decoder_init_onnx ./onnx_models/t5-small_encoder_decoder_init.onnx --output ./onnx_models/t5_small_beam_search.onnx
+"""
import os
import time
@@ -9,21 +22,17 @@ import argparse
from pathlib import Path
from onnx import helper
import numpy as np
-from typing import List
+from typing import List, Union
import torch
-from transformers import GPT2Config
+from packaging import version
+from transformers import GPT2Config, T5Config
from gpt2_helper import PRETRAINED_GPT2_MODELS
from convert_to_onnx import main as convert_gpt2_to_onnx
from benchmark_helper import Precision
from onnx import onnx_pb as onnx_proto
-"""
-This converts GPT2 model to onnx with beam search operator.
-Examples:
- python convert_beam_search.py -m gpt2 --gpt2_onnx .\onnx_models\gpt2_past_fp32.onnx --output .\onnx_models\gpt2_beam_search.onnx --output_sequences_scores
-"""
-config: GPT2Config = None
+config: Union[GPT2Config, T5Config] = None
logger = logging.getLogger('')
@@ -37,16 +46,29 @@ def parse_arguments(argv=None):
type=str,
help='Model path, or pretrained model name in the list: ' + ', '.join(PRETRAINED_GPT2_MODELS))
+ parser.add_argument('--model_type',
+ required=False,
+ type=str,
+ default="gpt2",
+ choices=["gpt2", "t5"],
+ help='Model type in the list: ' + ', '.join(["gpt2", "t5"]))
+
parser.add_argument('--cache_dir',
required=False,
type=str,
default=os.path.join('.', 'cache_models'),
help='Directory to cache pre-trained models')
- parser.add_argument('--gpt2_onnx',
+ parser.add_argument('--decoder_onnx',
required=True,
type=str,
- help='Output directory for GPT-2 onnx model, or model path ends with .onnx')
+ help='Output directory for decoder onnx model, or model path ends with .onnx')
+
+ parser.add_argument('--encoder_decoder_init_onnx',
+ required=False,
+ type=str,
+ default="",
+ help='path of ONNX model for encoder and decoder initialization. Required for t5 model type.')
parser.add_argument('--output',
required=False,
@@ -153,12 +175,12 @@ def parse_arguments(argv=None):
def gpt2_to_onnx(args):
model_name = args.model_name_or_path
- print(f"use convert_to_onnx.py to convert model {model_name} to onnx {args.gpt2_onnx} ...")
+ print(f"use convert_to_onnx.py to convert model {model_name} to onnx {args.decoder_onnx} ...")
arguments = [
'--model_name_or_path',
model_name,
'--output',
- args.gpt2_onnx,
+ args.decoder_onnx,
'--optimize_onnx',
'--precision',
'fp32' if args.precision == Precision.FLOAT32 else 'fp16',
@@ -182,13 +204,16 @@ def gpt2_to_onnx(args):
convert_gpt2_to_onnx(arguments)
-def shape_inference(gpt2_onnx_path):
+def shape_inference(decoder_onnx_path):
+ if version.parse(onnx.__version__) >= version.parse('1.11.0'):
+ logger.warn("SymbolicShapeInference might fail using onnx version 1.11. Please install 1.10.0 for now.")
+
# Run symbolic shape inference to walk around ORT shape inference issue for subgraph.
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference
- out = SymbolicShapeInference.infer_shapes(onnx.load(gpt2_onnx_path), auto_merge=True, guess_output_rank=False)
+ out = SymbolicShapeInference.infer_shapes(onnx.load(decoder_onnx_path), auto_merge=True, guess_output_rank=False)
if out:
# TODO: Use external format if input has extra data.
- onnx.save(out, gpt2_onnx_path)
+ onnx.save(out, decoder_onnx_path)
else:
print("Failed to run symbolic shape inference on the model.")
@@ -224,7 +249,7 @@ def verify_gpt2_subgraph(graph, precision):
expected_type = onnx_proto.TensorProto.INT32
if i >= 3:
- expected_type = onnx_proto.TensorProto.FLOAT16 if is_float16 else onnx_proto.TensorProto.FLOAT32
+ expected_type = onnx_proto.TensorProto.FLOAT16 if is_float16 else onnx_proto.TensorProto.FLOAT
if graph.input[i].type.tensor_type.elem_type != expected_type:
raise ValueError(
@@ -240,7 +265,7 @@ def verify_gpt2_subgraph(graph, precision):
if graph.output[i].name != expected_output:
raise ValueError(f"Output {i} is expected to be {expected_output}. Got {graph.output[i].name}")
- expected_type = onnx_proto.TensorProto.FLOAT16 if is_float16 else onnx_proto.TensorProto.FLOAT32
+ expected_type = onnx_proto.TensorProto.FLOAT16 if is_float16 else onnx_proto.TensorProto.FLOAT
if graph.output[i].type.tensor_type.elem_type != expected_type:
raise ValueError(
f"Input {i} is expected to have onnx data type {expected_type}. Got {graph.output[i].type.tensor_type.elem_type}"
@@ -251,17 +276,35 @@ def verify_gpt2_subgraph(graph, precision):
return
+def verify_t5_decoder_subgraph(graph, precision):
+ # TODO: implement it
+ pass
+
+
+def verify_t5_encoder_decoder_init_subgraph(graph, precision):
+ # TODO: implement it
+ pass
+
+
def convert_model(args):
- if os.path.exists(args.gpt2_onnx):
- print(f"skip convert_to_onnx since path existed: {args.gpt2_onnx}")
+ if os.path.exists(args.decoder_onnx):
+ print(f"skip convert_to_onnx since path existed: {args.decoder_onnx}")
else:
+ assert args.model_type == "gpt2", "please have onnx model ready for model type that is not gpt2"
gpt2_to_onnx(args)
- print(f"Run symbolic shape inference on {args.gpt2_onnx}. The file will be overwritten.")
- shape_inference(args.gpt2_onnx)
-
+ # TODO: fix shape inference for T5. Currently symbolic shape inference on T5 is broken.
+ enable_shape_inference = args.model_type == "gpt2"
+
+ if enable_shape_inference:
+ print(f"Run symbolic shape inference on {args.decoder_onnx}. The file will be overwritten.")
+ shape_inference(args.decoder_onnx)
+
global config
- config = GPT2Config.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
+ if args.model_type == "gpt2":
+ config = GPT2Config.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
+ else:
+ config = T5Config.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
print(config)
eos_token_id = config.eos_token_id
@@ -272,10 +315,14 @@ def convert_model(args):
if args.vocab_size != -1:
vocab_size = args.vocab_size
- model = onnx.load(args.gpt2_onnx)
- verify_gpt2_subgraph(model.graph, args.precision)
+ model = onnx.load(args.decoder_onnx)
+ model.graph.name = f"{args.model_type} decoder subgraph"
+
+ if args.model_type == "gpt2":
+ verify_gpt2_subgraph(model.graph, args.precision)
+ else:
+ verify_t5_decoder_subgraph(model.graph, args.precision)
- model.graph.name = "gpt2 subgraph"
inputs = [
"input_ids", "max_length", "min_length", "num_beams", "num_return_sequences", "temperature", "length_penalty",
"repetition_penalty", "vocab_mask"
@@ -291,16 +338,28 @@ def convert_model(args):
assert args.output_sequences_scores, "--output_token_scores requires --output_sequences_scores"
outputs.append("scores")
- node = helper.make_node('BeamSearch', inputs=inputs, outputs=outputs, name='BeamSearch_GPT2')
+ node = helper.make_node('BeamSearch', inputs=inputs, outputs=outputs, name=f'BeamSearch_{args.model_type}')
node.domain = "com.microsoft"
node.attribute.extend([
helper.make_attribute("eos_token_id", eos_token_id),
helper.make_attribute("pad_token_id", pad_token_id),
helper.make_attribute("no_repeat_ngram_size", args.no_repeat_ngram_size),
helper.make_attribute("early_stopping", 1 if args.early_stopping else 0),
- helper.make_attribute("body", model.graph),
+ helper.make_attribute("model_type", 0 if args.model_type == "gpt2" else 1),
+ helper.make_attribute("decoder", model.graph),
])
+ if args.model_type == "t5":
+ if enable_shape_inference:
+ print(f"Run symbolic shape inference on {args.encoder_decoder_init_onnx}. The file will be overwritten.")
+ shape_inference(args.encoder_decoder_init_onnx)
+ init_model = onnx.load(args.encoder_decoder_init_onnx)
+ init_model.graph.name = f"{args.model_type} encoder decoder init subgraph"
+ verify_t5_encoder_decoder_init_subgraph(init_model.graph, args.precision)
+ node.attribute.extend([
+ helper.make_attribute("encoder_decoder_init", init_model.graph),
+ ])
+
from onnx import TensorProto
# graph inputs
@@ -344,7 +403,7 @@ def convert_model(args):
if args.output_token_scores:
graph_outputs.append(scores)
- new_graph = helper.make_graph([node], 'gpt2-beam-search', graph_inputs, graph_outputs, initializers)
+ new_graph = helper.make_graph([node], f'{args.model_type}-beam-search', graph_inputs, graph_outputs, initializers)
# Create the model
new_model = helper.make_model(new_graph, producer_name='onnxruntime.transformers', opset_imports=model.opset_import)
@@ -392,10 +451,20 @@ def test_torch_performance(args, model, input_ids, attention_mask, eos_token_id,
def test_model(args, use_vocab_mask: bool = False, sentences: List[str] = None):
+ if args.model_type != "gpt2":
+ print(
+ f"Skipping parity test since the support for model type {args.model_type} is not implemented in OnnxRuntime"
+ )
+ return True
+
+ if args.temperature != 1.0:
+ # TODO: implement temperature in BeamSearch operator.
+ print("Skipping parity test as temperature is not implemented in BeamSearch operator")
+ return True
if args.prefix_vocab_mask:
print("Skipping parity test as prefix vocab mask is not implemented by Hugging Face")
- return
+ return True
from transformers import GPT2Tokenizer, GPT2LMHeadModel
@@ -547,6 +616,8 @@ def test_model(args, use_vocab_mask: bool = False, sentences: List[str] = None):
def main(argv=None, sentences=None):
args = parse_arguments(argv)
+ if args.model_type == "t5":
+ assert args.encoder_decoder_init_onnx, "please export t5 to onnx models before using this tool"
if os.path.exists(args.output):
print(f"skip conversion since path existed: {args.output}")
diff --git a/onnxruntime/python/tools/transformers/models/__init__.py b/onnxruntime/python/tools/transformers/models/__init__.py
new file mode 100644
index 0000000000..d31e7ad45e
--- /dev/null
+++ b/onnxruntime/python/tools/transformers/models/__init__.py
@@ -0,0 +1,8 @@
+#-------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+#-------------------------------------------------------------------------
+import os
+import sys
+
+sys.path.append(os.path.dirname(__file__))
diff --git a/onnxruntime/python/tools/transformers/models/t5/__init__.py b/onnxruntime/python/tools/transformers/models/t5/__init__.py
index ad5632855c..d31e7ad45e 100644
--- a/onnxruntime/python/tools/transformers/models/t5/__init__.py
+++ b/onnxruntime/python/tools/transformers/models/t5/__init__.py
@@ -1,3 +1,7 @@
+#-------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+#-------------------------------------------------------------------------
import os
import sys
diff --git a/onnxruntime/test/python/transformers/test_beam_search.py b/onnxruntime/test/python/transformers/test_beam_search.py
index f469cdaaaa..64e5762637 100644
--- a/onnxruntime/test/python/transformers/test_beam_search.py
+++ b/onnxruntime/test/python/transformers/test_beam_search.py
@@ -18,12 +18,13 @@ else:
class TestBeamSearch(unittest.TestCase):
+
def setUp(self):
#TODO: use a smaller model and enable tests in CI pipeline
self.model_name = "gpt2"
self.gpt2_onnx_path = os.path.join('.', 'onnx_models', 'gpt2_past_fp32_shape.onnx')
self.beam_search_onnx_path = os.path.join('.', 'onnx_models', 'gpt2_beam_search.onnx')
- self.cpu_params = f'-m {self.model_name} --gpt2_onnx {self.gpt2_onnx_path} --output {self.beam_search_onnx_path} --output_sequences_score --repetition_penalty 2.0'
+ self.cpu_params = f'-m {self.model_name} --decoder_onnx {self.gpt2_onnx_path} --output {self.beam_search_onnx_path} --output_sequences_score --repetition_penalty 2.0'
def run_beam_search(self, arguments: str, sentences=None):
return run(arguments.split(), sentences=sentences)