mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-03 23:49:44 +00:00
Prefix match in first iteration of beam search OP (#10231)
* Add BeamSearch op schema * Add ONNX conversion for beams search * remove attention_mask and change input order * add option to run baseline * add check data type NULL * applies VerifyNodeAndOpMatch to subgraph * update input_ids shape * Add node name for Cast node * expose API for topk * parse parameters * Add beam search scorer * output results * fix typo * use c++ template and format python * fix build pipeline errors * symbolic shape infer of input onnx * output scores * add kernel def hash * Handle vocab_mask; move CheckSubgraph * undo insert_cast_transformer.cc and fusion_utils.py * fix typo * fix merge * update doc * add repetition penalty * refactoring: add GptSubgraph class * move BeamSearchState from .h to .cc file * adjust logits processor order * add batch generation example * fix repetition penalty for dup words in sequence * Add test * Add no repeat ngram processor * refactoring: move logits processor to classes * fix build warning * show latency * use allocator in beam state * use allocator in sequences * fix build error * move next_positions to beam state * Changes for prefix matching * removing debugs * removing more debugs * clean up * clean up * cpu doc updated * Updated docs * updated prefix_vocab_mask dimension in convert script * changes to support bxs prefix_vocab_mask in beamsearchop kernel * doc update * OperatorKernels.md updated * matching docs from artifacts * minor change in logits processor * Addressing comments * Updated the prefix vocab mask usage properly Co-authored-by: Tianlei Wu <tlwu@microsoft.com>
This commit is contained in:
parent
1aa0789691
commit
ad9d2e2e89
8 changed files with 134 additions and 12 deletions
|
|
@ -361,7 +361,7 @@ This version of the operator has been available since version 1 of the 'com.micr
|
|||
<dd>The id of the padding token</dd>
|
||||
</dl>
|
||||
|
||||
#### Inputs (6 - 9)
|
||||
#### Inputs (6 - 10)
|
||||
|
||||
<dl>
|
||||
<dt><tt>input_ids</tt> : I</dt>
|
||||
|
|
@ -382,6 +382,8 @@ This version of the operator has been available since version 1 of the 'com.micr
|
|||
<dd>The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)</dd>
|
||||
<dt><tt>vocab_mask</tt> (optional) : M</dt>
|
||||
<dd>Mask of vocabulary. Words that are masked with 0 are not allowed to be generated, and words masked with 1 are allowed. Shape is (vocab_size)</dd>
|
||||
<dt><tt>prefix_vocab_mask</tt> (optional) : M</dt>
|
||||
<dd>Mask of vocabulary for the first step. Words that are masked with 0 are not allowed to be generated, and words masked with 1 are allowed. Shape is (batch_size, vocab_size)</dd>
|
||||
</dl>
|
||||
|
||||
#### Outputs (1 - 3)
|
||||
|
|
|
|||
|
|
@ -377,7 +377,7 @@ Do not modify directly.*
|
|||
|**Operator Domain:** *com.microsoft*||||
|
||||
|Attention|*in* input:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* mask_index:**M**<br> *in* past:**T**<br> *in* extra_add:**T**<br> *out* output:**T**<br> *out* present:**T**|1+|**T** = tensor(float)|
|
||||
|AttnLSTM|*in* X:**T**<br> *in* W:**T**<br> *in* R:**T**<br> *in* B:**T**<br> *in* sequence_lens:**T1**<br> *in* initial_h:**T**<br> *in* initial_c:**T**<br> *in* P:**T**<br> *in* QW:**T**<br> *in* MW:**T**<br> *in* V:**T**<br> *in* M:**T**<br> *in* memory_seq_lens:**T1**<br> *in* AW:**T**<br> *out* Y:**T**<br> *out* Y_h:**T**<br> *out* Y_c:**T**|1+|**T** = tensor(double), tensor(float)<br/> **T1** = tensor(int32)|
|
||||
|BeamSearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* num_beams:**I**<br> *in* num_return_sequences:**I**<br> *in* temperature:**T**<br> *in* length_penalty:**T**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**M**<br> *out* sequences:**I**<br> *out* sequences_scores:**T**<br> *out* scores:**T**|1+|**T** = tensor(float)|
|
||||
|BeamSearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* num_beams:**I**<br> *in* num_return_sequences:**I**<br> *in* temperature:**T**<br> *in* length_penalty:**T**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**M**<br> *in* prefix_vocab_mask:**M**<br> *out* sequences:**I**<br> *out* sequences_scores:**T**<br> *out* scores:**T**|1+|**T** = tensor(float)|
|
||||
|BiasGelu|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(float)|
|
||||
|BifurcationDetector|*in* src_tokens:**T**<br> *in* cur_tokens:**T**<br> *in* prev_suffix_match_idx:**T**<br> *in* pred_tokens:**T**<br> *out* tokens:**T**<br> *out* suffix_match_idx:**T**|1+|**T** = tensor(int64)|
|
||||
|CDist|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(double), tensor(float)|
|
||||
|
|
|
|||
|
|
@ -153,12 +153,14 @@ class BeamSearchImpl {
|
|||
Status GenerateNextToken(const OrtValue& logits,
|
||||
gsl::span<int64_t>& beam_next_tokens,
|
||||
gsl::span<int64_t>& beam_indices,
|
||||
BeamSearchState<T>& beam_state);
|
||||
BeamSearchState<T>& beam_state,
|
||||
int counter);
|
||||
|
||||
// Calculate scores from logits, then apply filtering and select next token for each beam.
|
||||
Status ProcessLogits(const OrtValue& logits, // logits output of subgraph
|
||||
BeamSearchState<T>& beam_state,
|
||||
AllocatorPtr& allocator);
|
||||
AllocatorPtr& allocator,
|
||||
int counter);
|
||||
|
||||
OpKernelContextInternal& context_;
|
||||
|
||||
|
|
@ -292,6 +294,30 @@ Status BeamSearchImpl<T>::CheckInputs(const OpKernelContextInternal& context) {
|
|||
parameters_->vocab_mask = vocab_mask->DataAsSpan<int32_t>();
|
||||
}
|
||||
|
||||
const Tensor* prefix_vocab_mask = context.Input<Tensor>(9);
|
||||
if (prefix_vocab_mask != nullptr) {
|
||||
// prefix_vocab_mask is optional
|
||||
const auto& vocab_mask_dims = prefix_vocab_mask->Shape().GetDims();
|
||||
if (vocab_mask_dims.size() != 2) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'prefix_vocab_mask' is expected to have 2 dimensions, got ",
|
||||
vocab_mask_dims.size());
|
||||
}
|
||||
|
||||
// prefix_vocab_mask first dimension should be same as the first dimension of input_ids
|
||||
if (static_cast<int>(vocab_mask_dims[0]) != static_cast<int>(dims[0])) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "input_ids and prefix_vocab_mask must have the same batch_size");
|
||||
}
|
||||
|
||||
// This check depends on parameters_->vocab_size, which must be set before this function is called.
|
||||
if (static_cast<int>(vocab_mask_dims[1]) != parameters_->vocab_size) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'prefix_vocab_mask' shape does not match with vocab_size, got ",
|
||||
vocab_mask_dims[0]);
|
||||
}
|
||||
|
||||
// store prefix vocab mask in parameters.
|
||||
parameters_->prefix_vocab_mask = prefix_vocab_mask->DataAsSpan<int32_t>();
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
@ -346,7 +372,8 @@ template <typename T>
|
|||
Status BeamSearchImpl<T>::ProcessLogits(
|
||||
const OrtValue& logits,
|
||||
BeamSearchState<T>& beam_state,
|
||||
AllocatorPtr& allocator) {
|
||||
AllocatorPtr& allocator,
|
||||
int counter) {
|
||||
const int64_t batch_beam_size = static_cast<int64_t>(parameters_->BatchBeamSize());
|
||||
const int& vocab_size = parameters_->vocab_size;
|
||||
|
||||
|
|
@ -394,7 +421,7 @@ Status BeamSearchImpl<T>::ProcessLogits(
|
|||
#endif
|
||||
|
||||
// Apply all score processors that update scores
|
||||
logits_processors_.Process(&(beam_state.sequences), next_token_scores);
|
||||
logits_processors_.Process(&(beam_state.sequences), next_token_scores, counter);
|
||||
|
||||
#ifdef DEBUG_BEAM_SEARCH
|
||||
DumpTensor("next_token_scores after logits processor", next_token_scores.data(), parameters_->batch_size, parameters_->num_beams, vocab_size);
|
||||
|
|
@ -486,9 +513,10 @@ Status BeamSearchImpl<T>::GenerateNextToken(
|
|||
const OrtValue& logits,
|
||||
gsl::span<int64_t>& beam_next_tokens,
|
||||
gsl::span<int64_t>& beam_indices,
|
||||
BeamSearchState<T>& beam_state) {
|
||||
BeamSearchState<T>& beam_state,
|
||||
int counter) {
|
||||
// Process logits to get next token scores
|
||||
ORT_RETURN_IF_ERROR(ProcessLogits(logits, beam_state, allocator_));
|
||||
ORT_RETURN_IF_ERROR(ProcessLogits(logits, beam_state, allocator_, counter));
|
||||
|
||||
gsl::span<T>& beam_scores = beam_scorer_->GetNextScores();
|
||||
// It is optional to clone beam_scores. Change it to use same buffer also works:
|
||||
|
|
@ -587,7 +615,9 @@ Status BeamSearchImpl<T>::Execute(const FeedsFetchesManager& ffm) {
|
|||
#endif
|
||||
|
||||
int current_length = parameters_->sequence_length;
|
||||
int iteration_counter = 0;
|
||||
while (current_length < parameters_->max_length) {
|
||||
iteration_counter++;
|
||||
#ifdef DEBUG_BEAM_SEARCH
|
||||
DumpString("***CurrentLength", std::to_string(current_length), true);
|
||||
#endif
|
||||
|
|
@ -600,7 +630,7 @@ Status BeamSearchImpl<T>::Execute(const FeedsFetchesManager& ffm) {
|
|||
const OrtValue& logits = fetches[0];
|
||||
gsl::span<int64_t> beam_next_tokens;
|
||||
gsl::span<int64_t> beam_indices;
|
||||
ORT_RETURN_IF_ERROR(GenerateNextToken(logits, beam_next_tokens, beam_indices, beam_state));
|
||||
ORT_RETURN_IF_ERROR(GenerateNextToken(logits, beam_next_tokens, beam_indices, beam_state, iteration_counter));
|
||||
|
||||
// When all batches are finished, stop earlier to avoid wasting computation.
|
||||
if (beam_scorer_->IsDone()) {
|
||||
|
|
|
|||
|
|
@ -28,7 +28,8 @@ struct BeamSearchParameters {
|
|||
int sequence_length; // deduce from second dimension of input_ids
|
||||
|
||||
gsl::span<const int32_t> vocab_mask;
|
||||
|
||||
gsl::span<const int32_t> prefix_vocab_mask;
|
||||
|
||||
// Parameters from outputs.
|
||||
bool output_scores; // whether scores existed in output
|
||||
|
||||
|
|
|
|||
|
|
@ -9,6 +9,10 @@ namespace onnxruntime {
|
|||
namespace contrib {
|
||||
namespace transformers {
|
||||
|
||||
// beam_search_iteration represents the current iteration counter of beam search
|
||||
// This value is used to apply processors as needed in specific iteration.
|
||||
static int beam_search_iteration;
|
||||
|
||||
template <typename T>
|
||||
gsl::span<T> NextTokenScores<T>::GetScores(int batch_beam_index) {
|
||||
assert(batch_beam_index >= 0 && batch_beam_index < batch_beam_size);
|
||||
|
|
@ -146,6 +150,41 @@ void VocabMaskLogitsProcessor<T>::Process(const ISequences* /*sequences*/,
|
|||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
PrefixVocabMaskLogitsProcessor<T>::PrefixVocabMaskLogitsProcessor(const gsl::span<const int32_t>& prefix_vocab_mask, int batch_size) : prefix_vocab_mask_(prefix_vocab_mask), batch_size_(batch_size) {
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void PrefixVocabMaskLogitsProcessor<T>::Process(const ISequences* /*sequences*/,
|
||||
NextTokenScores<T>& next_token_scores) {
|
||||
assert(!prefix_vocab_mask_.empty());
|
||||
|
||||
if (beam_search_iteration > 1) {
|
||||
return;
|
||||
}
|
||||
// next_token_scores shape (batch_size * num_beams, vocab_size)
|
||||
int num_beams = next_token_scores.batch_beam_size / batch_size_;
|
||||
assert(num_beams * batch_size_ == next_token_scores.batch_beam_size);
|
||||
|
||||
// Process prefix vocabulary mask and set tokens with mask value 0 to -inf.
|
||||
// prefix_vocab_mask shape is (batch_size, vocab_size).
|
||||
T* p = next_token_scores.scores.data();
|
||||
for (int i = 0; i < batch_size_; i++) {
|
||||
int prefix_vocab_mask_offset = i * next_token_scores.vocab_size;
|
||||
for (int j = 0; j < num_beams; j++) {
|
||||
for (int k = 0; k < next_token_scores.vocab_size; k++, p++) {
|
||||
if (prefix_vocab_mask_[prefix_vocab_mask_offset + k] == 0) {
|
||||
*p = std::numeric_limits<T>::lowest();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef DEBUG_BEAM_SEARCH
|
||||
DumpScores("PrefixVocabMaskLogitsProcessor", next_token_scores.scores);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void LogitsProcessorList<T>::Init(const BeamSearchParameters& parameters) {
|
||||
processor_list_.clear();
|
||||
|
|
@ -165,6 +204,11 @@ void LogitsProcessorList<T>::Init(const BeamSearchParameters& parameters) {
|
|||
processor_list_.push_back(vocab_mask_processor_.get());
|
||||
}
|
||||
|
||||
if (!parameters.prefix_vocab_mask.empty()) {
|
||||
prefix_vocab_mask_processor_ = std::make_unique<PrefixVocabMaskLogitsProcessor<T>>(parameters.prefix_vocab_mask, parameters.batch_size);
|
||||
processor_list_.push_back(prefix_vocab_mask_processor_.get());
|
||||
}
|
||||
|
||||
if (parameters.min_length > 0) {
|
||||
min_length_processor_ = std::make_unique<MinLengthLogitsProcessor<T>>(parameters.min_length, parameters.eos_token_id);
|
||||
processor_list_.push_back(min_length_processor_.get());
|
||||
|
|
@ -176,8 +220,10 @@ void LogitsProcessorList<T>::Init(const BeamSearchParameters& parameters) {
|
|||
|
||||
template <typename T>
|
||||
void LogitsProcessorList<T>::Process(const ISequences* sequences,
|
||||
gsl::span<T>& next_token_scores) {
|
||||
gsl::span<T>& next_token_scores,
|
||||
int counter) {
|
||||
NextTokenScores<T> input_scores = {next_token_scores, batch_beam_size_, vocab_size_};
|
||||
beam_search_iteration = counter;
|
||||
for (size_t i = 0; i < processor_list_.size(); i++) {
|
||||
processor_list_[i]->Process(sequences, input_scores);
|
||||
}
|
||||
|
|
@ -188,6 +234,7 @@ template class MinLengthLogitsProcessor<float>;
|
|||
template class RepetitionPenaltyLogitsProcessor<float>;
|
||||
template class NoRepeatNGramLogitsProcessor<float>;
|
||||
template class VocabMaskLogitsProcessor<float>;
|
||||
template class PrefixVocabMaskLogitsProcessor<float>;
|
||||
template class LogitsProcessorList<float>;
|
||||
|
||||
} // namespace transformers
|
||||
|
|
|
|||
|
|
@ -76,12 +76,25 @@ class VocabMaskLogitsProcessor : public ILogitsProcessor<T> {
|
|||
gsl::span<const int32_t> vocab_mask_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class PrefixVocabMaskLogitsProcessor : public ILogitsProcessor<T> {
|
||||
public:
|
||||
PrefixVocabMaskLogitsProcessor(const gsl::span<const int32_t>& vocab_mask, int batch_size);
|
||||
|
||||
void Process(const ISequences* sequences,
|
||||
NextTokenScores<T>& next_token_scores) override;
|
||||
|
||||
private:
|
||||
gsl::span<const int32_t> prefix_vocab_mask_;
|
||||
const int batch_size_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class LogitsProcessorList {
|
||||
public:
|
||||
LogitsProcessorList() = default ;
|
||||
void Init(const BeamSearchParameters& parameters);
|
||||
void Process(const ISequences* sequences, gsl::span<T>& next_token_scores);
|
||||
void Process(const ISequences* sequences, gsl::span<T>& next_token_scores, int counter);
|
||||
|
||||
private:
|
||||
int batch_beam_size_;
|
||||
|
|
@ -91,6 +104,7 @@ private:
|
|||
std::unique_ptr<RepetitionPenaltyLogitsProcessor<T>> repetition_penalty_processor_;
|
||||
std::unique_ptr<NoRepeatNGramLogitsProcessor<T>> no_repeat_ngram_processor_;
|
||||
std::unique_ptr<VocabMaskLogitsProcessor<T>> vocab_mask_processor_;
|
||||
std::unique_ptr<PrefixVocabMaskLogitsProcessor<T>> prefix_vocab_mask_processor_;
|
||||
std::unique_ptr<MinLengthLogitsProcessor<T>> min_length_processor_;
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -695,6 +695,7 @@ void RegisterTextGenerationSchemas() {
|
|||
"T", OpSchema::Optional)
|
||||
.Input(7, "repetition_penalty", "The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)", "T", OpSchema::Optional)
|
||||
.Input(8, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)", "M", OpSchema::Optional)
|
||||
.Input(9, "prefix_vocab_mask", "Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)", "M", OpSchema::Optional)
|
||||
.Output(0, "sequences", "Word IDs of generated sequences. Shape is (batch_size, num_return_sequences, max_sequence_length)", "I")
|
||||
.Output(1, "sequences_scores", "Final beam score of the generated sequences. Shape is (batch_size, num_return_sequences)", "T", OpSchema::Optional)
|
||||
.Output(2, "scores",
|
||||
|
|
|
|||
|
|
@ -128,6 +128,18 @@ def parse_arguments(argv=None):
|
|||
default=1,
|
||||
help='Positive. >1 to penalize and <1 to encorage.')
|
||||
|
||||
beam_search_group.add_argument('--vocab_size',
|
||||
type=int,
|
||||
required=False,
|
||||
default=-1,
|
||||
help="Vocab_size of the underlying model")
|
||||
|
||||
beam_search_group.add_argument('--prefix_vocab_mask',
|
||||
required=False,
|
||||
action='store_true',
|
||||
help="This vocab mask applies only to first iteration, enable if last word in query might need auto complete")
|
||||
beam_search_group.set_defaults(prefix_vocab_mask=False)
|
||||
|
||||
mixed_precision_option_group = parser.add_argument_group(
|
||||
"mixed precision conversion parameters that works when \"--precision fp16\" is specified")
|
||||
|
||||
|
|
@ -230,12 +242,18 @@ def convert_model(args):
|
|||
pad_token_id = config.eos_token_id
|
||||
vocab_size = config.vocab_size
|
||||
|
||||
# if vocab_size is given in parameters use that.
|
||||
if args.vocab_size != -1:
|
||||
vocab_size = args.vocab_size
|
||||
|
||||
model = onnx.load(args.gpt2_onnx)
|
||||
model.graph.name = "gpt2 subgraph"
|
||||
inputs = [
|
||||
"input_ids", "max_length", "min_length", "num_beams", "num_return_sequences", "temperature", "length_penalty",
|
||||
"repetition_penalty", "vocab_mask"
|
||||
]
|
||||
if args.prefix_vocab_mask:
|
||||
inputs.append("prefix_vocab_mask")
|
||||
|
||||
outputs = ["sequences"]
|
||||
if args.output_sequences_scores:
|
||||
|
|
@ -273,6 +291,10 @@ def convert_model(args):
|
|||
repetition_penalty, vocab_mask
|
||||
]
|
||||
|
||||
if args.prefix_vocab_mask:
|
||||
prefix_vocab_mask = helper.make_tensor_value_info('prefix_vocab_mask', TensorProto.INT32, ['batch_size', vocab_size])
|
||||
graph_inputs.append(prefix_vocab_mask)
|
||||
|
||||
# graph outputs
|
||||
sequences = helper.make_tensor_value_info('sequences', TensorProto.INT32,
|
||||
['batch_size', 'num_return_sequences', 'max_length'])
|
||||
|
|
@ -301,6 +323,11 @@ def convert_model(args):
|
|||
|
||||
|
||||
def test_model(args, use_vocab_mask: bool = False, sentences: List[str] = None):
|
||||
|
||||
if args.prefix_vocab_mask:
|
||||
print("Skipping parity test as prefix vocab mask is not implemented by Hugging Face")
|
||||
return
|
||||
|
||||
from transformers import GPT2Tokenizer, GPT2LMHeadModel
|
||||
|
||||
tokenizer = GPT2Tokenizer.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
|
||||
|
|
|
|||
Loading…
Reference in a new issue