Misc transformer fixes - 3 (#14320)

This commit is contained in:
Hariharan Seshadri 2023-01-20 13:57:57 -08:00 committed by GitHub
parent 72821a6113
commit 2d8ee5251c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 40 additions and 12 deletions

View file

@ -103,6 +103,9 @@ Status BeamSearch::SetupSubgraphExecutionInfo(const SessionState& session_state,
decoder_feeds_fetches_manager_ = gpt_subgraph_->GetFeedsFetchesManager();
} else if (attribute_name == "init_decoder") {
ORT_ENFORCE(init_run_gpt_subgraph_ == nullptr, "SetupSubgraphExecutionInfo should only be called once for each subgraph.");
// TODO (hasesh): If 'init_decoder' is present, then we update 'parameters_' again based on its subgraph (it would have been
// updated once for the 'decoder' attribute). In the future, find a way to update 'parameters_' only once, based on only one
// subgraph attribute.
auto res = gpt_details::CreateGptSubgraphAndUpdateParameters(node, session_state, attribute_name,
subgraph_session_state, parameters_);

View file

@ -119,6 +119,9 @@ Status GreedySearch::SetupSubgraphExecutionInfo(const SessionState& session_stat
decoder_feeds_fetches_manager_ = gpt_subgraph_->GetFeedsFetchesManager();
} else if (attribute_name == "init_decoder") {
ORT_ENFORCE(init_run_gpt_subgraph_ == nullptr, "SetupSubgraphExecutionInfo should only be called once for each subgraph.");
// TODO (hasesh): If 'init_decoder' is present, then we update 'parameters_' again based on its subgraph (it would have been
// updated once for the 'decoder' attribute). In the future, find a way to update 'parameters_' only once, based on only one
// subgraph attribute.
auto res = gpt_details::CreateGptSubgraphAndUpdateParameters(node, session_state, attribute_name,
subgraph_session_state, parameters_);
@ -171,8 +174,7 @@ Status GreedySearch::Compute(OpKernelContext* ctx) const {
if (has_init_decoder_) {
ORT_ENFORCE(init_run_decoder_session_state, "Subgraph SessionState was not found for 'decoder' attribute.");
ORT_ENFORCE(init_run_decoder_feeds_fetches_manager_, "CreateFeedsFetchesManager must be called prior to execution of graph.");
ORT_ENFORCE(init_run_gpt_subgraph_ && gpt_subgraph_
&& init_run_gpt_subgraph_->past_present_share_buffer_ == gpt_subgraph_->past_present_share_buffer_,
ORT_ENFORCE(init_run_gpt_subgraph_ && gpt_subgraph_ && init_run_gpt_subgraph_->past_present_share_buffer_ == gpt_subgraph_->past_present_share_buffer_,
"past_present_share_buffer mode must be same for init decoder and decoder subgraphes");
}

View file

@ -178,22 +178,22 @@ def parse_arguments(argv: Optional[List[str]] = None) -> argparse.Namespace:
output_group.set_defaults(run_shape_inference=False)
output_group.add_argument(
"-pvs",
"--pad_vocab_size",
"-dpvs",
"--disable_pad_vocab_size",
required=False,
action="store_true",
help="Pad logits MatMul weight to be a multiple of 8 along the dimension where dim value is the vocab size",
help="Do not pad logits MatMul weight to be a multiple of 8 along the dimension where dim value is the vocab size. The logits MatMul may hence be of poor performance for fp16 precision.",
)
output_group.set_defaults(pad_vocab_size=True)
output_group.set_defaults(disable_pad_vocab_size=False)
output_group.add_argument(
"-sgd",
"--separate_gpt2_decoder_for_init_run",
"-dsgd",
"--disable_separate_gpt2_decoder_for_init_run",
required=False,
action="store_true",
help="Have separate decoder subgraphs for initial and remaining runs. This allows for optimizations based on sequence lengths in each subgraph",
help="Do not create separate decoder subgraphs for initial and remaining runs. This does not allow for optimizations based on sequence lengths in each subgraph",
)
output_group.set_defaults(separate_gpt2_decoder_for_init_run=True)
output_group.set_defaults(disable_separate_gpt2_decoder_for_init_run=False)
output_group.add_argument(
"-i",
@ -1411,7 +1411,7 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati
# This can be expanded to other models/decoding strategies later
logits_matmul_weight_padded = False
if (
args.pad_vocab_size
not args.disable_pad_vocab_size
and args.precision == Precision.FLOAT16
and is_gpt2
and (is_beamsearch or is_greedysearch or is_sampling)
@ -1428,7 +1428,11 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati
gpt2_init_decoder_generated = False
gpt2_init_decoder_onnx_path = None
if args.separate_gpt2_decoder_for_init_run and is_gpt2 and (is_beamsearch or is_greedysearch or is_sampling):
if (
not args.disable_separate_gpt2_decoder_for_init_run
and is_gpt2
and (is_beamsearch or is_greedysearch or is_sampling)
):
logger.info(f"Creating an initial run GPT2 decoder from {args.decoder_onnx}. ")
gpt2_init_decoder_onnx_filename = "gpt2_init_past_{}.onnx".format(

View file

@ -9,6 +9,7 @@
import os
import unittest
import onnx
import pytest
import torch
from parity_utilities import find_transformers_source
@ -59,6 +60,19 @@ class TestBeamSearchGpt(unittest.TestCase):
if os.path.exists(self.beam_search_onnx_path):
os.remove(self.beam_search_onnx_path)
def check_for_init_decoder_attr(self, model_path: str):
    """Assert that the model at ``model_path`` carries an "init_decoder" attribute.

    Loads the ONNX model and inspects every "BeamSearch" or "GreedySearch"
    node in its top-level graph; the assertion passes when at least one of
    those nodes has an attribute named "init_decoder".
    """
    search_op_types = ("BeamSearch", "GreedySearch")
    graph = onnx.load(model_path).graph
    # True as soon as any search node exposes the attribute.
    has_init_decoder = any(
        attribute.name == "init_decoder"
        for graph_node in graph.node
        if graph_node.op_type in search_op_types
        for attribute in graph_node.attribute
    )
    self.assertTrue(has_init_decoder)
def run_beam_search(self, extra_arguments: str, sentences=None, append_arguments=True, is_greedy=False):
if append_arguments:
@ -74,6 +88,8 @@ class TestBeamSearchGpt(unittest.TestCase):
# Test CPU
result = run(arguments, sentences=self.sentences if sentences is None else sentences)
self.assertTrue(result["parity"], f"ORT and PyTorch result is different on CPU for arguments {arguments}")
# (CPU) Check for the presence of the "init_decoder" attribute
self.check_for_init_decoder_attr(self.beam_search_onnx_path)
# Test GPU
if self.enable_cuda:
@ -82,6 +98,9 @@ class TestBeamSearchGpt(unittest.TestCase):
result = run(arguments, sentences=self.sentences if sentences is None else sentences)
self.assertTrue(result["parity"], f"ORT and PyTorch result is different on GPU for arguments {arguments}")
# (GPU) Check for the presence of the "init_decoder" attribute
self.check_for_init_decoder_attr(self.beam_search_onnx_path)
os.remove(self.beam_search_onnx_path)
@pytest.mark.slow