diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index f517a0afee..187bb13015 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -350,12 +350,16 @@ This version of the operator has been available since version 1 of the 'com.micr #### Attributes
-
body : graph (required)
-
The GPT-2 subgraph with input_ids, position_ids, attention_mask, past_0, past_1, ... as inputs, and logits, present_0, present_1, ... as output
+
decoder : graph (required)
+
Decoder subgraph to execute in a loop.
early_stopping : int
early stop or not
+
encoder_decoder_init : graph
+
subgraph for initialization of encoder and decoder. It will be called once before decoder subgraph.
eos_token_id : int (required)
The id of the end-of-sequence token
+
model_type : int
+
model type: 0 for GPT-2; 1 for encoder decoder like T5
no_repeat_ngram_size : int
no repeat ngrams size
pad_token_id : int (required)
diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc index 09d647bd55..d82f1a7094 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc @@ -256,35 +256,41 @@ class BeamSearchImpl { }; void BeamSearch::Init(const OpKernelInfo& info) { - // Make sure the body attribute was present even though we don't need it here. + // Make sure the decoder attribute was present even though we don't need it here. ONNX_NAMESPACE::GraphProto proto; - ORT_ENFORCE(info.GetAttr("body", &proto).IsOK()); + ORT_ENFORCE(info.GetAttr("decoder", &proto).IsOK()); ORT_IGNORE_RETURN_VALUE(proto); parameters_.ParseFromAttributes(info); - - cuda_stream_ = nullptr; } Status BeamSearch::SetupSubgraphExecutionInfo(const SessionState& session_state, const std::string& attribute_name, const SessionState& subgraph_session_state) { ORT_ENFORCE(gpt_subgraph_ == nullptr, "SetupSubgraphExecutionInfo should only be called once for each subgraph."); - const auto& node = Node(); - gpt_subgraph_ = std::make_unique(node, attribute_name, subgraph_session_state.GetGraphViewer()); - ORT_RETURN_IF_ERROR(gpt_subgraph_->Setup(session_state, subgraph_session_state)); - feeds_fetches_manager_ = gpt_subgraph_->GetFeedsFetchesManager(); - parameters_.SetSubgraphParameters(gpt_subgraph_->vocab_size, - gpt_subgraph_->num_heads, - gpt_subgraph_->head_size, - gpt_subgraph_->num_layers); + // TODO: handle another subgraph with attribute name "encoder_decode_init" + if (attribute_name == "decoder") { + const auto& node = Node(); + gpt_subgraph_ = std::make_unique(node, attribute_name, subgraph_session_state.GetGraphViewer()); + ORT_RETURN_IF_ERROR(gpt_subgraph_->Setup(session_state, subgraph_session_state)); + feeds_fetches_manager_ = gpt_subgraph_->GetFeedsFetchesManager(); + parameters_.SetSubgraphParameters(gpt_subgraph_->vocab_size, + gpt_subgraph_->num_heads, + gpt_subgraph_->head_size, + gpt_subgraph_->num_layers); + } return Status::OK(); } Status BeamSearch::Compute(OpKernelContext* ctx) const { + if (parameters_.model_type != 0) { + // TODO: support encoder decoder model like T5 + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Support of 'model_type' != 0 is not implemented"); + } + auto* ctx_internal = static_cast(ctx); - auto* session_state = ctx_internal->SubgraphSessionState("body"); - ORT_ENFORCE(session_state, "Subgraph SessionState was not found for 'body' attribute."); + auto* session_state = ctx_internal->SubgraphSessionState("decoder"); + ORT_ENFORCE(session_state, "Subgraph SessionState was not found for 'decoder' attribute."); ORT_ENFORCE(feeds_fetches_manager_, "CreateFeedsFetchesManager must be called prior to execution of graph."); concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool(); diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search.h index 9e2e64498e..217dc080cc 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search.h @@ -19,7 +19,8 @@ using namespace onnxruntime::controlflow; // namespace of IControlFlowKernel class BeamSearch : public IControlFlowKernel { public: - BeamSearch(const OpKernelInfo& info) : IControlFlowKernel(info), cuda_stream_(nullptr), dumper_(nullptr) { + BeamSearch(const OpKernelInfo& info) + : IControlFlowKernel(info), feeds_fetches_manager_(nullptr), cuda_stream_(nullptr), dumper_(nullptr) { Init(info); } @@ -87,7 +88,7 @@ class BeamSearch : public IControlFlowKernel { void* cuda_stream_; IConsoleDumper* dumper_; - + BeamSearchParameters parameters_; }; diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc index 236bfab68a..4fc9f2f383 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc @@ -17,6 +17,7 @@ Status BeamSearchParameters::Validate() const { } void BeamSearchParameters::ParseFromAttributes(const OpKernelInfo& info) { + model_type = static_cast(info.GetAttrOrDefault("model_type", 0)); early_stopping = info.GetAttrOrDefault("early_stopping", 0) == 1; eos_token_id = static_cast(info.GetAttrOrDefault("eos_token_id", -1)); pad_token_id = static_cast(info.GetAttrOrDefault("pad_token_id", -1)); diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_shared.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_shared.h index 9ca095eadf..3cf4a48f55 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_shared.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_shared.h @@ -73,6 +73,7 @@ class IBeamScorer { struct IBeamSearchParameters { // Parameters from node attributes + int model_type; int eos_token_id; int pad_token_id; int no_repeat_ngram_size; diff --git a/onnxruntime/contrib_ops/cuda/transformers/beam_search_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/beam_search_device_helper.cc index 3b8497c32b..b8d871f320 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/beam_search_device_helper.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/beam_search_device_helper.cc @@ -457,9 +457,8 @@ Status UpdateFeeds( for (size_t i = 1; i < last_outputs.size(); ++i) { next_inputs[i + 2] = last_outputs[i]; } - return Status::OK(); } else { - return PickPastState(last_outputs, next_inputs, beam_indices, allocator, stream); + ORT_RETURN_IF_ERROR(PickPastState(last_outputs, next_inputs, beam_indices, allocator, stream)); } // Make sure data is ready before next subgraph execution. diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index de208dc2b9..a3f2fa6cf5 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -919,10 +919,9 @@ ONNX_MS_OPERATOR_SET_SCHEMA(BeamSearch, 1, .Attr("pad_token_id", "The id of the padding token", AttributeProto::INT) .Attr("no_repeat_ngram_size", "no repeat ngrams size", AttributeProto::INT, static_cast(0)) .Attr("early_stopping", "early stop or not", AttributeProto::INT, static_cast(0)) - .Attr( - "body", - "The GPT-2 subgraph with input_ids, position_ids, attention_mask, past_0, past_1, ... as inputs, and logits, present_0, present_1, ... as output", - AttributeProto::GRAPH) + .Attr("model_type", "model type: 0 for GPT-2; 1 for encoder decoder like T5", AttributeProto::INT, static_cast(0)) + .Attr("encoder_decoder_init", "subgraph for initialization of encoder and decoder. It will be called once before decoder subgraph.", AttributeProto::GRAPH, OPTIONAL_VALUE) + .Attr("decoder", "Decoder subgraph to execute in a loop.", AttributeProto::GRAPH) .Input(0, "input_ids", "The sequence used as a prompt for the generation. Shape is (batch_size, sequence_length)", "I") .Input(1, "max_length", "The maximum length of the sequence to be generated. Shape is (1)", "I") .Input(2, "min_length", "The minimum length below which the score of eos_token_id is set to -Inf. Shape is (1)", "I", OpSchema::Optional) @@ -951,8 +950,6 @@ ONNX_MS_OPERATOR_SET_SCHEMA(BeamSearch, 1, BeamSearchShapeInference(ctx); })); - - ONNX_MS_OPERATOR_SET_SCHEMA(SampleOp, 1, OpSchema() .Input(0, "X", "input", "T") diff --git a/onnxruntime/python/tools/transformers/convert_beam_search.py b/onnxruntime/python/tools/transformers/convert_beam_search.py index 738f6f33dd..8efef8c15f 100644 --- a/onnxruntime/python/tools/transformers/convert_beam_search.py +++ b/onnxruntime/python/tools/transformers/convert_beam_search.py @@ -1,5 +1,18 @@ +#------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. +#------------------------------------------------------------------------- + +""" +This converts GPT2 or T5 model to onnx with beam search operator. + +Example 1: convert gpt2 model with beam search: + python convert_beam_search.py -m gpt2 --decoder_onnx .\onnx_models\gpt2_past_fp32.onnx --output .\onnx_models\gpt2_beam_search.onnx --output_sequences_scores + +Example 2: convert T5 model with beam search: + python ./models/t5/convert_to_onnx.py -m t5-small -s + python convert_beam_search.py -m t5-small --model_type t5 --decoder_onnx ./onnx_models/t5-small_decoder.onnx --encoder_decoder_init_onnx ./onnx_models/t5-small_encoder_decoder_init.onnx --output ./onnx_models/t5_small_beam_search.onnx +""" import os import time @@ -9,21 +22,17 @@ import argparse from pathlib import Path from onnx import helper import numpy as np -from typing import List +from typing import List, Union import torch -from transformers import GPT2Config +from packaging import version +from transformers import GPT2Config, T5Config from gpt2_helper import PRETRAINED_GPT2_MODELS from convert_to_onnx import main as convert_gpt2_to_onnx from benchmark_helper import Precision from onnx import onnx_pb as onnx_proto -""" -This converts GPT2 model to onnx with beam search operator. -Examples: - python convert_beam_search.py -m gpt2 --gpt2_onnx .\onnx_models\gpt2_past_fp32.onnx --output .\onnx_models\gpt2_beam_search.onnx --output_sequences_scores -""" -config: GPT2Config = None +config: Union[GPT2Config, T5Config] = None logger = logging.getLogger('') @@ -37,16 +46,29 @@ def parse_arguments(argv=None): type=str, help='Model path, or pretrained model name in the list: ' + ', '.join(PRETRAINED_GPT2_MODELS)) + parser.add_argument('--model_type', + required=False, + type=str, + default="gpt2", + choices=["gpt2", "t5"], + help='Model type in the list: ' + ', '.join(["gpt2", "t5"])) + parser.add_argument('--cache_dir', required=False, type=str, default=os.path.join('.', 'cache_models'), help='Directory to cache pre-trained models') - parser.add_argument('--gpt2_onnx', + parser.add_argument('--decoder_onnx', required=True, type=str, - help='Output directory for GPT-2 onnx model, or model path ends with .onnx') + help='Output directory for decoder onnx model, or model path ends with .onnx') + + parser.add_argument('--encoder_decoder_init_onnx', + required=False, + type=str, + default="", + help='path of ONNX model for encoder and decoder initialization. Required for t5 model type.') parser.add_argument('--output', required=False, @@ -153,12 +175,12 @@ def parse_arguments(argv=None): def gpt2_to_onnx(args): model_name = args.model_name_or_path - print(f"use convert_to_onnx.py to convert model {model_name} to onnx {args.gpt2_onnx} ...") + print(f"use convert_to_onnx.py to convert model {model_name} to onnx {args.decoder_onnx} ...") arguments = [ '--model_name_or_path', model_name, '--output', - args.gpt2_onnx, + args.decoder_onnx, '--optimize_onnx', '--precision', 'fp32' if args.precision == Precision.FLOAT32 else 'fp16', @@ -182,13 +204,16 @@ def gpt2_to_onnx(args): convert_gpt2_to_onnx(arguments) -def shape_inference(gpt2_onnx_path): +def shape_inference(decoder_onnx_path): + if version.parse(onnx.__version__) >= version.parse('1.11.0'): + logger.warn("SymbolicShapeInference might fail using onnx version 1.11. Please install 1.10.0 for now.") + # Run symbolic shape inference to walk around ORT shape inference issue for subgraph. from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference - out = SymbolicShapeInference.infer_shapes(onnx.load(gpt2_onnx_path), auto_merge=True, guess_output_rank=False) + out = SymbolicShapeInference.infer_shapes(onnx.load(decoder_onnx_path), auto_merge=True, guess_output_rank=False) if out: # TODO: Use external format if input has extra data. - onnx.save(out, gpt2_onnx_path) + onnx.save(out, decoder_onnx_path) else: print("Failed to run symbolic shape inference on the model.") @@ -224,7 +249,7 @@ def verify_gpt2_subgraph(graph, precision): expected_type = onnx_proto.TensorProto.INT32 if i >= 3: - expected_type = onnx_proto.TensorProto.FLOAT16 if is_float16 else onnx_proto.TensorProto.FLOAT32 + expected_type = onnx_proto.TensorProto.FLOAT16 if is_float16 else onnx_proto.TensorProto.FLOAT if graph.input[i].type.tensor_type.elem_type != expected_type: raise ValueError( @@ -240,7 +265,7 @@ def verify_gpt2_subgraph(graph, precision): if graph.output[i].name != expected_output: raise ValueError(f"Output {i} is expected to be {expected_output}. Got {graph.output[i].name}") - expected_type = onnx_proto.TensorProto.FLOAT16 if is_float16 else onnx_proto.TensorProto.FLOAT32 + expected_type = onnx_proto.TensorProto.FLOAT16 if is_float16 else onnx_proto.TensorProto.FLOAT if graph.output[i].type.tensor_type.elem_type != expected_type: raise ValueError( f"Input {i} is expected to have onnx data type {expected_type}. Got {graph.output[i].type.tensor_type.elem_type}" @@ -251,17 +276,35 @@ def verify_gpt2_subgraph(graph, precision): return +def verify_t5_decoder_subgraph(graph, precision): + # TODO: implement it + pass + + +def verify_t5_encoder_decoder_init_subgraph(graph, precision): + # TODO: implement it + pass + + def convert_model(args): - if os.path.exists(args.gpt2_onnx): - print(f"skip convert_to_onnx since path existed: {args.gpt2_onnx}") + if os.path.exists(args.decoder_onnx): + print(f"skip convert_to_onnx since path existed: {args.decoder_onnx}") else: + assert args.model_type == "gpt2", "please have onnx model ready for model type that is not gpt2" gpt2_to_onnx(args) - print(f"Run symbolic shape inference on {args.gpt2_onnx}. The file will be overwritten.") - shape_inference(args.gpt2_onnx) - + # TODO: fix shape inference for T5. Currently symbolic shape inference on T5 is broken. + enable_shape_inference = args.model_type == "gpt2" + + if enable_shape_inference: + print(f"Run symbolic shape inference on {args.decoder_onnx}. The file will be overwritten.") + shape_inference(args.decoder_onnx) + global config - config = GPT2Config.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir) + if args.model_type == "gpt2": + config = GPT2Config.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir) + else: + config = T5Config.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir) print(config) eos_token_id = config.eos_token_id @@ -272,10 +315,14 @@ def convert_model(args): if args.vocab_size != -1: vocab_size = args.vocab_size - model = onnx.load(args.gpt2_onnx) - verify_gpt2_subgraph(model.graph, args.precision) + model = onnx.load(args.decoder_onnx) + model.graph.name = f"{args.model_type} decoder subgraph" + + if args.model_type == "gpt2": + verify_gpt2_subgraph(model.graph, args.precision) + else: + verify_t5_decoder_subgraph(model.graph, args.precision) - model.graph.name = "gpt2 subgraph" inputs = [ "input_ids", "max_length", "min_length", "num_beams", "num_return_sequences", "temperature", "length_penalty", "repetition_penalty", "vocab_mask" @@ -291,16 +338,28 @@ def convert_model(args): assert args.output_sequences_scores, "--output_token_scores requires --output_sequences_scores" outputs.append("scores") - node = helper.make_node('BeamSearch', inputs=inputs, outputs=outputs, name='BeamSearch_GPT2') + node = helper.make_node('BeamSearch', inputs=inputs, outputs=outputs, name=f'BeamSearch_{args.model_type}') node.domain = "com.microsoft" node.attribute.extend([ helper.make_attribute("eos_token_id", eos_token_id), helper.make_attribute("pad_token_id", pad_token_id), helper.make_attribute("no_repeat_ngram_size", args.no_repeat_ngram_size), helper.make_attribute("early_stopping", 1 if args.early_stopping else 0), - helper.make_attribute("body", model.graph), + helper.make_attribute("model_type", 0 if args.model_type == "gpt2" else 1), + helper.make_attribute("decoder", model.graph), ]) + if args.model_type == "t5": + if enable_shape_inference: + print(f"Run symbolic shape inference on {args.encoder_decoder_init_onnx}. The file will be overwritten.") + shape_inference(args.encoder_decoder_init_onnx) + init_model = onnx.load(args.encoder_decoder_init_onnx) + init_model.graph.name = f"{args.model_type} encoder decoder init subgraph" + verify_t5_encoder_decoder_init_subgraph(init_model.graph, args.precision) + node.attribute.extend([ + helper.make_attribute("encoder_decoder_init", init_model.graph), + ]) + from onnx import TensorProto # graph inputs @@ -344,7 +403,7 @@ def convert_model(args): if args.output_token_scores: graph_outputs.append(scores) - new_graph = helper.make_graph([node], 'gpt2-beam-search', graph_inputs, graph_outputs, initializers) + new_graph = helper.make_graph([node], f'{args.model_type}-beam-search', graph_inputs, graph_outputs, initializers) # Create the model new_model = helper.make_model(new_graph, producer_name='onnxruntime.transformers', opset_imports=model.opset_import) @@ -392,10 +451,20 @@ def test_torch_performance(args, model, input_ids, attention_mask, eos_token_id, def test_model(args, use_vocab_mask: bool = False, sentences: List[str] = None): + if args.model_type != "gpt2": + print( + f"Skipping parity test since the support for model type {args.model_type} is not implemented in OnnxRuntime" + ) + return True + + if args.temperature != 1.0: + # TODO: implement temperature in BeamSearch operator. + print("Skipping parity test as temperature is not implemented in BeamSearch operator") + return True if args.prefix_vocab_mask: print("Skipping parity test as prefix vocab mask is not implemented by Hugging Face") - return + return True from transformers import GPT2Tokenizer, GPT2LMHeadModel @@ -547,6 +616,8 @@ def test_model(args, use_vocab_mask: bool = False, sentences: List[str] = None): def main(argv=None, sentences=None): args = parse_arguments(argv) + if args.model_type == "t5": + assert args.encoder_decoder_init_onnx, "please export t5 to onnx models before using this tool" if os.path.exists(args.output): print(f"skip conversion since path existed: {args.output}") diff --git a/onnxruntime/python/tools/transformers/models/__init__.py b/onnxruntime/python/tools/transformers/models/__init__.py new file mode 100644 index 0000000000..d31e7ad45e --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/__init__.py @@ -0,0 +1,8 @@ +#------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +#------------------------------------------------------------------------- +import os +import sys + +sys.path.append(os.path.dirname(__file__)) diff --git a/onnxruntime/python/tools/transformers/models/t5/__init__.py b/onnxruntime/python/tools/transformers/models/t5/__init__.py index ad5632855c..d31e7ad45e 100644 --- a/onnxruntime/python/tools/transformers/models/t5/__init__.py +++ b/onnxruntime/python/tools/transformers/models/t5/__init__.py @@ -1,3 +1,7 @@ +#------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +#------------------------------------------------------------------------- import os import sys diff --git a/onnxruntime/test/python/transformers/test_beam_search.py b/onnxruntime/test/python/transformers/test_beam_search.py index f469cdaaaa..64e5762637 100644 --- a/onnxruntime/test/python/transformers/test_beam_search.py +++ b/onnxruntime/test/python/transformers/test_beam_search.py @@ -18,12 +18,13 @@ else: class TestBeamSearch(unittest.TestCase): + def setUp(self): #TODO: use a smaller model and enable tests in CI pipeline self.model_name = "gpt2" self.gpt2_onnx_path = os.path.join('.', 'onnx_models', 'gpt2_past_fp32_shape.onnx') self.beam_search_onnx_path = os.path.join('.', 'onnx_models', 'gpt2_beam_search.onnx') - self.cpu_params = f'-m {self.model_name} --gpt2_onnx {self.gpt2_onnx_path} --output {self.beam_search_onnx_path} --output_sequences_score --repetition_penalty 2.0' + self.cpu_params = f'-m {self.model_name} --decoder_onnx {self.gpt2_onnx_path} --output {self.beam_search_onnx_path} --output_sequences_score --repetition_penalty 2.0' def run_beam_search(self, arguments: str, sentences=None): return run(arguments.split(), sentences=sentences)