mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-19 21:32:23 +00:00
Misc transformer fixes - 3 (#14320)
This commit is contained in:
parent
72821a6113
commit
2d8ee5251c
4 changed files with 40 additions and 12 deletions
|
|
@ -103,6 +103,9 @@ Status BeamSearch::SetupSubgraphExecutionInfo(const SessionState& session_state,
|
|||
decoder_feeds_fetches_manager_ = gpt_subgraph_->GetFeedsFetchesManager();
|
||||
} else if (attribute_name == "init_decoder") {
|
||||
ORT_ENFORCE(init_run_gpt_subgraph_ == nullptr, "SetupSubgraphExecutionInfo should only be called once for each subgraph.");
|
||||
// TODO (hasesh): If 'init_decoder' is present, then we update 'parameters_' again based on its subgraph (it would have been
|
||||
// updated once for the 'decoder' attribute). In future, find a way to update 'parameters' only once based on only one subgraph
|
||||
// attribute.
|
||||
auto res = gpt_details::CreateGptSubgraphAndUpdateParameters(node, session_state, attribute_name,
|
||||
subgraph_session_state, parameters_);
|
||||
|
||||
|
|
|
|||
|
|
@ -119,6 +119,9 @@ Status GreedySearch::SetupSubgraphExecutionInfo(const SessionState& session_stat
|
|||
decoder_feeds_fetches_manager_ = gpt_subgraph_->GetFeedsFetchesManager();
|
||||
} else if (attribute_name == "init_decoder") {
|
||||
ORT_ENFORCE(init_run_gpt_subgraph_ == nullptr, "SetupSubgraphExecutionInfo should only be called once for each subgraph.");
|
||||
// TODO (hasesh): If 'init_decoder' is present, then we update 'parameters_' again based on its subgraph (it would have been
|
||||
// updated once for the 'decoder' attribute). In future, find a way to update 'parameters' only once based on only one subgraph
|
||||
// attribute.
|
||||
auto res = gpt_details::CreateGptSubgraphAndUpdateParameters(node, session_state, attribute_name,
|
||||
subgraph_session_state, parameters_);
|
||||
|
||||
|
|
@ -171,8 +174,7 @@ Status GreedySearch::Compute(OpKernelContext* ctx) const {
|
|||
if (has_init_decoder_) {
|
||||
ORT_ENFORCE(init_run_decoder_session_state, "Subgraph SessionState was not found for 'decoder' attribute.");
|
||||
ORT_ENFORCE(init_run_decoder_feeds_fetches_manager_, "CreateFeedsFetchesManager must be called prior to execution of graph.");
|
||||
ORT_ENFORCE(init_run_gpt_subgraph_ && gpt_subgraph_
|
||||
&& init_run_gpt_subgraph_->past_present_share_buffer_ == gpt_subgraph_->past_present_share_buffer_,
|
||||
ORT_ENFORCE(init_run_gpt_subgraph_ && gpt_subgraph_ && init_run_gpt_subgraph_->past_present_share_buffer_ == gpt_subgraph_->past_present_share_buffer_,
|
||||
"past_present_share_buffer mode must be same for init decoder and decoder subgraphes");
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -178,22 +178,22 @@ def parse_arguments(argv: Optional[List[str]] = None) -> argparse.Namespace:
|
|||
output_group.set_defaults(run_shape_inference=False)
|
||||
|
||||
output_group.add_argument(
|
||||
"-pvs",
|
||||
"--pad_vocab_size",
|
||||
"-dpvs",
|
||||
"--disable_pad_vocab_size",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Pad logits MatMul weight to be a multiple of 8 along the dimension where dim value is the vocab size",
|
||||
help="Do not pad logits MatMul weight to be a multiple of 8 along the dimension where dim value is the vocab size. The logits MatMul may hence be of poor performance for fp16 precision.",
|
||||
)
|
||||
output_group.set_defaults(pad_vocab_size=True)
|
||||
output_group.set_defaults(disable_pad_vocab_size=False)
|
||||
|
||||
output_group.add_argument(
|
||||
"-sgd",
|
||||
"--separate_gpt2_decoder_for_init_run",
|
||||
"-dsgd",
|
||||
"--disable_separate_gpt2_decoder_for_init_run",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Have separate decoder subgraphs for initial and remaining runs. This allows for optimizations based on sequence lengths in each subgraph",
|
||||
help="Do not create separate decoder subgraphs for initial and remaining runs. This does not allow for optimizations based on sequence lengths in each subgraph",
|
||||
)
|
||||
output_group.set_defaults(separate_gpt2_decoder_for_init_run=True)
|
||||
output_group.set_defaults(disable_separate_gpt2_decoder_for_init_run=False)
|
||||
|
||||
output_group.add_argument(
|
||||
"-i",
|
||||
|
|
@ -1411,7 +1411,7 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati
|
|||
# This can be expanded to other models/decoding strategies later
|
||||
logits_matmul_weight_padded = False
|
||||
if (
|
||||
args.pad_vocab_size
|
||||
not args.disable_pad_vocab_size
|
||||
and args.precision == Precision.FLOAT16
|
||||
and is_gpt2
|
||||
and (is_beamsearch or is_greedysearch or is_sampling)
|
||||
|
|
@ -1428,7 +1428,11 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati
|
|||
|
||||
gpt2_init_decoder_generated = False
|
||||
gpt2_init_decoder_onnx_path = None
|
||||
if args.separate_gpt2_decoder_for_init_run and is_gpt2 and (is_beamsearch or is_greedysearch or is_sampling):
|
||||
if (
|
||||
not args.disable_separate_gpt2_decoder_for_init_run
|
||||
and is_gpt2
|
||||
and (is_beamsearch or is_greedysearch or is_sampling)
|
||||
):
|
||||
logger.info(f"Creating an initial run GPT2 decoder from {args.decoder_onnx}. ")
|
||||
|
||||
gpt2_init_decoder_onnx_filename = "gpt2_init_past_{}.onnx".format(
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@
|
|||
import os
|
||||
import unittest
|
||||
|
||||
import onnx
|
||||
import pytest
|
||||
import torch
|
||||
from parity_utilities import find_transformers_source
|
||||
|
|
@ -59,6 +60,19 @@ class TestBeamSearchGpt(unittest.TestCase):
|
|||
if os.path.exists(self.beam_search_onnx_path):
|
||||
os.remove(self.beam_search_onnx_path)
|
||||
|
||||
def check_for_init_decoder_attr(self, model_path: str):
    """Assert that the model at ``model_path`` carries an ``init_decoder`` attribute.

    The ONNX model is loaded from disk and every ``BeamSearch`` or
    ``GreedySearch`` node in its top-level graph is inspected. The check
    passes when at least one of those nodes has an attribute named
    ``init_decoder``; otherwise the test case assertion fails.
    """
    search_op_types = ("BeamSearch", "GreedySearch")
    model = onnx.load(model_path)

    # One flat scan over the attributes of all search nodes replaces the
    # nested loops with a found-flag.
    has_init_decoder = any(
        attribute.name == "init_decoder"
        for graph_node in model.graph.node
        if graph_node.op_type in search_op_types
        for attribute in graph_node.attribute
    )

    self.assertTrue(has_init_decoder)
|
||||
|
||||
def run_beam_search(self, extra_arguments: str, sentences=None, append_arguments=True, is_greedy=False):
|
||||
|
||||
if append_arguments:
|
||||
|
|
@ -74,6 +88,8 @@ class TestBeamSearchGpt(unittest.TestCase):
|
|||
# Test CPU
|
||||
result = run(arguments, sentences=self.sentences if sentences is None else sentences)
|
||||
self.assertTrue(result["parity"], f"ORT and PyTorch result is different on CPU for arguments {arguments}")
|
||||
# (CPU) Check for the presence of the "init_decoder" attribute
|
||||
self.check_for_init_decoder_attr(self.beam_search_onnx_path)
|
||||
|
||||
# Test GPU
|
||||
if self.enable_cuda:
|
||||
|
|
@ -82,6 +98,9 @@ class TestBeamSearchGpt(unittest.TestCase):
|
|||
result = run(arguments, sentences=self.sentences if sentences is None else sentences)
|
||||
self.assertTrue(result["parity"], f"ORT and PyTorch result is different on GPU for arguments {arguments}")
|
||||
|
||||
# (GPU) Check for the presence of the "init_decoder" attribute
|
||||
self.check_for_init_decoder_attr(self.beam_search_onnx_path)
|
||||
|
||||
os.remove(self.beam_search_onnx_path)
|
||||
|
||||
@pytest.mark.slow
|
||||
|
|
|
|||
Loading…
Reference in a new issue