diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc index 868356f70a..8744b0244e 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc @@ -103,6 +103,9 @@ Status BeamSearch::SetupSubgraphExecutionInfo(const SessionState& session_state, decoder_feeds_fetches_manager_ = gpt_subgraph_->GetFeedsFetchesManager(); } else if (attribute_name == "init_decoder") { ORT_ENFORCE(init_run_gpt_subgraph_ == nullptr, "SetupSubgraphExecutionInfo should only be called once for each subgraph."); + // TODO (hasesh): If 'init_decoder' is present, then we update 'parameters_' again based on its subgraph (it would have been + // updated once for the 'decoder' attribute). In the future, find a way to update 'parameters_' only once based on only one subgraph + // attribute. auto res = gpt_details::CreateGptSubgraphAndUpdateParameters(node, session_state, attribute_name, subgraph_session_state, parameters_); diff --git a/onnxruntime/contrib_ops/cpu/transformers/greedy_search.cc b/onnxruntime/contrib_ops/cpu/transformers/greedy_search.cc index a33d03738e..e1701ab53c 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/greedy_search.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/greedy_search.cc @@ -119,6 +119,9 @@ Status GreedySearch::SetupSubgraphExecutionInfo(const SessionState& session_stat decoder_feeds_fetches_manager_ = gpt_subgraph_->GetFeedsFetchesManager(); } else if (attribute_name == "init_decoder") { ORT_ENFORCE(init_run_gpt_subgraph_ == nullptr, "SetupSubgraphExecutionInfo should only be called once for each subgraph."); + // TODO (hasesh): If 'init_decoder' is present, then we update 'parameters_' again based on its subgraph (it would have been + // updated once for the 'decoder' attribute). In the future, find a way to update 'parameters_' only once based on only one subgraph + // attribute. 
auto res = gpt_details::CreateGptSubgraphAndUpdateParameters(node, session_state, attribute_name, subgraph_session_state, parameters_); @@ -171,8 +174,7 @@ Status GreedySearch::Compute(OpKernelContext* ctx) const { if (has_init_decoder_) { ORT_ENFORCE(init_run_decoder_session_state, "Subgraph SessionState was not found for 'decoder' attribute."); ORT_ENFORCE(init_run_decoder_feeds_fetches_manager_, "CreateFeedsFetchesManager must be called prior to execution of graph."); - ORT_ENFORCE(init_run_gpt_subgraph_ && gpt_subgraph_ - && init_run_gpt_subgraph_->past_present_share_buffer_ == gpt_subgraph_->past_present_share_buffer_, + ORT_ENFORCE(init_run_gpt_subgraph_ && gpt_subgraph_ && init_run_gpt_subgraph_->past_present_share_buffer_ == gpt_subgraph_->past_present_share_buffer_, "past_present_share_buffer mode must be same for init decoder and decoder subgraphes"); } diff --git a/onnxruntime/python/tools/transformers/convert_generation.py b/onnxruntime/python/tools/transformers/convert_generation.py index ed9f84cf0b..122a574064 100644 --- a/onnxruntime/python/tools/transformers/convert_generation.py +++ b/onnxruntime/python/tools/transformers/convert_generation.py @@ -178,22 +178,22 @@ def parse_arguments(argv: Optional[List[str]] = None) -> argparse.Namespace: output_group.set_defaults(run_shape_inference=False) output_group.add_argument( - "-pvs", - "--pad_vocab_size", + "-dpvs", + "--disable_pad_vocab_size", required=False, action="store_true", - help="Pad logits MatMul weight to be a multiple of 8 along the dimension where dim value is the vocab size", + help="Do not pad logits MatMul weight to be a multiple of 8 along the dimension where dim value is the vocab size. 
The logits MatMul may hence be of poor performance for fp16 precision.", ) - output_group.set_defaults(pad_vocab_size=True) + output_group.set_defaults(disable_pad_vocab_size=False) output_group.add_argument( - "-sgd", - "--separate_gpt2_decoder_for_init_run", + "-dsgd", + "--disable_separate_gpt2_decoder_for_init_run", required=False, action="store_true", - help="Have separate decoder subgraphs for initial and remaining runs. This allows for optimizations based on sequence lengths in each subgraph", + help="Do not create separate decoder subgraphs for initial and remaining runs. This does not allow for optimizations based on sequence lengths in each subgraph", ) - output_group.set_defaults(separate_gpt2_decoder_for_init_run=True) + output_group.set_defaults(disable_separate_gpt2_decoder_for_init_run=False) output_group.add_argument( "-i", @@ -1411,7 +1411,7 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati # This can be expanded to other models/decoding strategies later logits_matmul_weight_padded = False if ( - args.pad_vocab_size + not args.disable_pad_vocab_size and args.precision == Precision.FLOAT16 and is_gpt2 and (is_beamsearch or is_greedysearch or is_sampling) @@ -1428,7 +1428,11 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati gpt2_init_decoder_generated = False gpt2_init_decoder_onnx_path = None - if args.separate_gpt2_decoder_for_init_run and is_gpt2 and (is_beamsearch or is_greedysearch or is_sampling): + if ( + not args.disable_separate_gpt2_decoder_for_init_run + and is_gpt2 + and (is_beamsearch or is_greedysearch or is_sampling) + ): logger.info(f"Creating an initial run GPT2 decoder from {args.decoder_onnx}. 
") gpt2_init_decoder_onnx_filename = "gpt2_init_past_{}.onnx".format( diff --git a/onnxruntime/test/python/transformers/test_generation.py b/onnxruntime/test/python/transformers/test_generation.py index 9f91ee612d..5f790c7025 100644 --- a/onnxruntime/test/python/transformers/test_generation.py +++ b/onnxruntime/test/python/transformers/test_generation.py @@ -9,6 +9,7 @@ import os import unittest +import onnx import pytest import torch from parity_utilities import find_transformers_source @@ -59,6 +60,19 @@ class TestBeamSearchGpt(unittest.TestCase): if os.path.exists(self.beam_search_onnx_path): os.remove(self.beam_search_onnx_path) + def check_for_init_decoder_attr(self, model_path: str): + init_decoder_found = False + gpt2_beam_search_onnx_model = onnx.load(model_path) + graph_proto = gpt2_beam_search_onnx_model.graph + for node in graph_proto.node: + if node.op_type == "BeamSearch" or node.op_type == "GreedySearch": + for attr in node.attribute: + if attr.name == "init_decoder": + init_decoder_found = True + break + + self.assertTrue(init_decoder_found) + def run_beam_search(self, extra_arguments: str, sentences=None, append_arguments=True, is_greedy=False): if append_arguments: @@ -74,6 +88,8 @@ class TestBeamSearchGpt(unittest.TestCase): # Test CPU result = run(arguments, sentences=self.sentences if sentences is None else sentences) self.assertTrue(result["parity"], f"ORT and PyTorch result is different on CPU for arguments {arguments}") + # (CPU) Check for the presence of the "init_decoder" attribute + self.check_for_init_decoder_attr(self.beam_search_onnx_path) # Test GPU if self.enable_cuda: @@ -82,6 +98,9 @@ class TestBeamSearchGpt(unittest.TestCase): result = run(arguments, sentences=self.sentences if sentences is None else sentences) self.assertTrue(result["parity"], f"ORT and PyTorch result is different on GPU for arguments {arguments}") + # (GPU) Check for the presence of the "init_decoder" attribute + 
self.check_for_init_decoder_attr(self.beam_search_onnx_path) + os.remove(self.beam_search_onnx_path) @pytest.mark.slow