mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-23 22:13:38 +00:00
fixing positions for beam search gpt2 (#12156)
* fixing positions for beam search gpt2 Co-authored-by: Tianlei Wu <tlwu@microsoft.com>
This commit is contained in:
parent
9ebef91a6f
commit
05c31a036d
6 changed files with 79 additions and 83 deletions
|
|
@ -197,7 +197,6 @@ void InitBeamState(transformers::IBeamSearchState<T>* beam_state,
|
|||
|
||||
// T5 does not need position, so next_positions is empty for T5.
|
||||
if (!beam_state->next_positions.empty()) {
|
||||
memset(beam_state->next_positions.data(), 0, beam_state->next_positions.size_bytes());
|
||||
gsl::copy(sequence_lengths, beam_state->next_positions);
|
||||
}
|
||||
|
||||
|
|
@ -274,13 +273,13 @@ Status ProcessLogits(const OrtValue& logits, //
|
|||
// Get scores for candidates of next token: next_token_scores = log_softmax(next_token_logits, dim=-1)
|
||||
gsl::span<T>& next_token_scores = beam_state->next_token_scores;
|
||||
ORT_RETURN_IF_ERROR(
|
||||
SoftmaxCPU<T>(
|
||||
batch_beam_size, // rows
|
||||
vocab_size, // elements per row
|
||||
(input_length == 1 && logits_batch_size == batch_beam_size) ? logits_data : next_token_logits.data(),
|
||||
next_token_scores.data(),
|
||||
true,
|
||||
thread_pool));
|
||||
SoftmaxCPU<T>(
|
||||
batch_beam_size, // rows
|
||||
vocab_size, // elements per row
|
||||
(input_length == 1 && logits_batch_size == batch_beam_size) ? logits_data : next_token_logits.data(),
|
||||
next_token_scores.data(),
|
||||
true,
|
||||
thread_pool));
|
||||
|
||||
#ifdef DEBUG_BEAM_SEARCH
|
||||
dumper->Print("next_token_scores after softmax", next_token_scores.data(), batch_size, num_beams, vocab_size);
|
||||
|
|
@ -428,12 +427,12 @@ Status UpdateGptFeeds(
|
|||
std::vector<OrtValue>& next_inputs,
|
||||
int current_length,
|
||||
OrtValue& position_ids,
|
||||
bool increase_position,
|
||||
gsl::span<const int32_t> beam_next_tokens,
|
||||
gsl::span<const int32_t> beam_indices,
|
||||
int num_beams,
|
||||
int gpt_subgraph_first_past_input_idx,
|
||||
int gpt_subgraph_first_present_output_idx,
|
||||
const transformers::IConsoleDumper* dumper) {
|
||||
int gpt_subgraph_first_present_output_idx) {
|
||||
// last_outputs: logits, present_0, present_1, ...
|
||||
// next_inputs: input_ids, position_ids, attention_mask, past_0, past_1, ...
|
||||
ORT_UNUSED_PARAMETER(stream);
|
||||
|
|
@ -454,10 +453,12 @@ Status UpdateGptFeeds(
|
|||
}
|
||||
next_inputs[0] = input_ids;
|
||||
|
||||
// Update position IDs
|
||||
int32_t* position_data = position_ids.GetMutable<Tensor>()->MutableData<int32_t>();
|
||||
for (int i = 0; i < batch_beam_size; i++) {
|
||||
position_data[i]++;
|
||||
if (increase_position) {
|
||||
// Update position IDs
|
||||
int32_t* position_data = position_ids.GetMutable<Tensor>()->MutableData<int32_t>();
|
||||
for (int i = 0; i < batch_beam_size; i++) {
|
||||
position_data[i]++;
|
||||
}
|
||||
}
|
||||
next_inputs[1] = position_ids;
|
||||
|
||||
|
|
@ -477,14 +478,6 @@ Status UpdateGptFeeds(
|
|||
}
|
||||
next_inputs[2] = attention_mask;
|
||||
|
||||
#ifdef DEBUG_BEAM_SEARCH
|
||||
dumper->Print("input_ids", input_ids);
|
||||
dumper->Print("position_ids", position_ids);
|
||||
dumper->Print("attention_mask", attention_mask);
|
||||
#else
|
||||
ORT_UNUSED_PARAMETER(dumper);
|
||||
#endif
|
||||
|
||||
// Update past state
|
||||
if (num_beams == 1) {
|
||||
// feed present_* output to past_* inputs one by one
|
||||
|
|
@ -725,12 +718,12 @@ template Status UpdateGptFeeds<float>(
|
|||
std::vector<OrtValue>& next_inputs,
|
||||
int current_length,
|
||||
OrtValue& position_ids,
|
||||
bool increase_position,
|
||||
gsl::span<const int32_t> beam_next_tokens,
|
||||
gsl::span<const int32_t> beam_indices,
|
||||
int num_beams,
|
||||
int gpt_subgraph_first_past_input_idx,
|
||||
int gpt_subgraph_first_present_output_idx,
|
||||
const transformers::IConsoleDumper* dumper);
|
||||
int gpt_subgraph_first_present_output_idx);
|
||||
|
||||
template Status UpdateDecoderFeeds<float>(
|
||||
AllocatorPtr allocator,
|
||||
|
|
@ -751,28 +744,28 @@ template Status UpdateDecoderFeeds<float>(
|
|||
template void ExpandInputs<int32_t>(const OrtValue& input, int num_beams, AllocatorPtr allocator, OrtValue& expanded);
|
||||
|
||||
template Status ExpandBuffer<int32_t>(
|
||||
void* stream,
|
||||
const OrtValue& input,
|
||||
int num_beams,
|
||||
AllocatorPtr allocator,
|
||||
OrtValue& expanded,
|
||||
bool only_copy_shape);
|
||||
void* stream,
|
||||
const OrtValue& input,
|
||||
int num_beams,
|
||||
AllocatorPtr allocator,
|
||||
OrtValue& expanded,
|
||||
bool only_copy_shape);
|
||||
|
||||
template Status ExpandBuffer<float>(
|
||||
void* stream,
|
||||
const OrtValue& input,
|
||||
int num_beams,
|
||||
AllocatorPtr allocator,
|
||||
OrtValue& expanded,
|
||||
bool only_copy_shape);
|
||||
void* stream,
|
||||
const OrtValue& input,
|
||||
int num_beams,
|
||||
AllocatorPtr allocator,
|
||||
OrtValue& expanded,
|
||||
bool only_copy_shape);
|
||||
|
||||
template Status ExpandBuffer<MLFloat16>(
|
||||
void* stream,
|
||||
const OrtValue& input,
|
||||
int num_beams,
|
||||
AllocatorPtr allocator,
|
||||
OrtValue& expanded,
|
||||
bool only_copy_shape);
|
||||
void* stream,
|
||||
const OrtValue& input,
|
||||
int num_beams,
|
||||
AllocatorPtr allocator,
|
||||
OrtValue& expanded,
|
||||
bool only_copy_shape);
|
||||
|
||||
} // namespace BeamSearchCpuDeviceHelper
|
||||
} // namespace contrib
|
||||
|
|
|
|||
|
|
@ -96,12 +96,12 @@ using UpdateGptFeedsFunc = std::function<Status(
|
|||
std::vector<OrtValue>& next_inputs,
|
||||
int current_length,
|
||||
OrtValue& position_ids,
|
||||
bool increase_position,
|
||||
gsl::span<const int32_t> beam_next_tokens,
|
||||
gsl::span<const int32_t> beam_indices,
|
||||
int num_beams,
|
||||
int gpt_subgraph_first_past_input_idx,
|
||||
int gpt_subgraph_first_present_output_idx,
|
||||
const transformers::IConsoleDumper* dumper)>;
|
||||
int gpt_subgraph_first_present_output_idx)>;
|
||||
|
||||
// Create encoder inputs (for encoder-decoder model like T5).
|
||||
using CreateEncoderInputsFunc = std::function<Status(
|
||||
|
|
@ -142,7 +142,6 @@ using ExpandBufferFunc = std::function<Status(
|
|||
bool only_copy_shape)>;
|
||||
} // namespace BeamSearchDeviceHelper
|
||||
|
||||
|
||||
// These are CPU specific device helper implementations
|
||||
namespace BeamSearchCpuDeviceHelper {
|
||||
Status TopK(
|
||||
|
|
@ -208,12 +207,12 @@ Status UpdateGptFeeds(
|
|||
std::vector<OrtValue>& next_inputs,
|
||||
int current_length,
|
||||
OrtValue& position_ids,
|
||||
bool increase_position,
|
||||
gsl::span<const int32_t> beam_next_tokens,
|
||||
gsl::span<const int32_t> beam_indices,
|
||||
int num_beams,
|
||||
int gpt_subgraph_first_past_input_idx,
|
||||
int gpt_subgraph_first_present_output_idx,
|
||||
const transformers::IConsoleDumper* dumper);
|
||||
int gpt_subgraph_first_present_output_idx);
|
||||
|
||||
// ---------------------------------------------------------------
|
||||
// Functions for encoder-decoder model like T5
|
||||
|
|
|
|||
|
|
@ -56,6 +56,7 @@ class BeamSearchGpt : public BeamSearchBase<T> {
|
|||
std::vector<OrtValue>& next_inputs,
|
||||
int current_length,
|
||||
OrtValue& position_ids,
|
||||
bool increase_position,
|
||||
gsl::span<const int32_t> beam_next_tokens,
|
||||
gsl::span<const int32_t> beam_indices);
|
||||
|
||||
|
|
@ -93,6 +94,7 @@ Status BeamSearchGpt<T>::UpdateFeeds(
|
|||
std::vector<OrtValue>& next_inputs,
|
||||
int current_length,
|
||||
OrtValue& position_ids,
|
||||
bool increase_position,
|
||||
gsl::span<const int32_t> beam_next_tokens,
|
||||
gsl::span<const int32_t> beam_indices) {
|
||||
return update_feeds_func_(this->temp_space_allocator_,
|
||||
|
|
@ -101,12 +103,12 @@ Status BeamSearchGpt<T>::UpdateFeeds(
|
|||
next_inputs,
|
||||
current_length,
|
||||
position_ids,
|
||||
increase_position,
|
||||
beam_next_tokens,
|
||||
beam_indices,
|
||||
this->parameters_->num_beams,
|
||||
gpt_subgraph_.GetFirstPastInputIndex(),
|
||||
gpt_subgraph_.GetFirstPresentOutputIndex(),
|
||||
this->GetConsoleDumper());
|
||||
gpt_subgraph_.GetFirstPresentOutputIndex());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
|
@ -186,11 +188,7 @@ Status BeamSearchGpt<T>::Execute(const FeedsFetchesManager& feeds_fetches_manage
|
|||
|
||||
#ifdef DEBUG_BEAM_SEARCH
|
||||
const IConsoleDumper* dumper = this->GetConsoleDumper();
|
||||
dumper->Print("input_ids", feeds[0]);
|
||||
dumper->Print("position_ids", feeds[1]);
|
||||
dumper->Print("attention_mask", feeds[2]);
|
||||
#endif
|
||||
|
||||
// Position ids for all iterations except the first. It uses the memory buffer owned by next_positions.
|
||||
OrtValue position_ids;
|
||||
int64_t dims[] = {parameters->BatchBeamSize(), 1};
|
||||
|
|
@ -205,9 +203,19 @@ Status BeamSearchGpt<T>::Execute(const FeedsFetchesManager& feeds_fetches_manage
|
|||
int iteration_counter = 0;
|
||||
while (current_length < parameters->max_length) {
|
||||
iteration_counter++;
|
||||
|
||||
#ifdef DEBUG_BEAM_SEARCH
|
||||
auto cur_len = std::to_string(current_length);
|
||||
dumper->Print("***CurrentLength", cur_len, true);
|
||||
dumper->Print("iteration", iteration_counter, true);
|
||||
|
||||
dumper->Print("input_ids", feeds[0]);
|
||||
dumper->Print("position_ids", feeds[1]);
|
||||
dumper->Print("attention_mask", feeds[2]);
|
||||
for (size_t i = 3; i < feeds.size(); i++) {
|
||||
dumper->Print("past", static_cast<int>(i) - 3, true);
|
||||
dumper->Print("", feeds[i]);
|
||||
}
|
||||
#endif
|
||||
|
||||
status = utils::ExecuteSubgraph(this->decoder_session_state_,
|
||||
|
|
@ -241,8 +249,11 @@ Status BeamSearchGpt<T>::Execute(const FeedsFetchesManager& feeds_fetches_manage
|
|||
|
||||
// Prepare inputs for next round of subgraph call.
|
||||
if (current_length < parameters->max_length) {
|
||||
// For the first iteration, position_ids is initialized as sequence lengths. We can add it to feeds directly.
|
||||
// For the remaining iterations, we need to increase position_ids first, then add it to feeds.
|
||||
bool increase_position = (iteration_counter > 1);
|
||||
ORT_RETURN_IF_ERROR(UpdateFeeds(fetches, feeds, current_length,
|
||||
position_ids,
|
||||
position_ids, increase_position,
|
||||
beam_next_tokens.as_span<const int32_t>(),
|
||||
beam_indices.as_span<const int32_t>()));
|
||||
}
|
||||
|
|
|
|||
|
|
@ -258,9 +258,8 @@ Status ProcessLogits(const OrtValue& logits, //
|
|||
|
||||
// The output will be float for consideration of precision and easy integration with remaining parts.
|
||||
float* Y_data = next_token_scores.data();
|
||||
const CudaT* X_data = (input_length == 1 && logits_batch_size == batch_beam_size) ?
|
||||
logits_data :
|
||||
reinterpret_cast<const CudaT*>(next_token_logits.data());
|
||||
bool is_single_token = (input_length == 1 && logits_batch_size == batch_beam_size);
|
||||
const CudaT* X_data = is_single_token ? logits_data : reinterpret_cast<const CudaT*>(next_token_logits.data());
|
||||
|
||||
dispatch_blockwise_softmax_forward<CudaT, float, float, true>(
|
||||
cuda_stream, Y_data, X_data, vocab_size, vocab_size, batch_size * num_beams);
|
||||
|
|
@ -500,12 +499,12 @@ Status UpdateGptFeeds(
|
|||
std::vector<OrtValue>& next_inputs,
|
||||
int current_length,
|
||||
OrtValue& position_ids,
|
||||
bool increase_position,
|
||||
gsl::span<const int32_t> beam_next_tokens,
|
||||
gsl::span<const int32_t> beam_indices,
|
||||
int num_beams,
|
||||
int gpt_subgraph_first_past_input_idx,
|
||||
int gpt_subgraph_first_present_output_idx,
|
||||
const transformers::IConsoleDumper* dumper) {
|
||||
int gpt_subgraph_first_present_output_idx) {
|
||||
// Update input_ids with next tokens.
|
||||
int batch_beam_size = static_cast<int>(beam_next_tokens.length());
|
||||
int64_t dims[] = {batch_beam_size, 1};
|
||||
|
|
@ -519,7 +518,7 @@ Status UpdateGptFeeds(
|
|||
next_inputs[0] = input_ids;
|
||||
|
||||
// Update position IDs
|
||||
int32_t* position_data = position_ids.GetMutable<Tensor>()->MutableData<int32_t>();
|
||||
int32_t* position_data = increase_position ? position_ids.GetMutable<Tensor>()->MutableData<int32_t>() : nullptr;
|
||||
next_inputs[1] = position_ids;
|
||||
|
||||
// Update attention mask
|
||||
|
|
@ -538,14 +537,6 @@ Status UpdateGptFeeds(
|
|||
|
||||
next_inputs[2] = attention_mask;
|
||||
|
||||
#ifdef DEBUG_BEAM_SEARCH
|
||||
dumper->Print("input_ids", input_ids);
|
||||
dumper->Print("position_ids", position_ids);
|
||||
dumper->Print("attention_mask", attention_mask);
|
||||
#else
|
||||
ORT_UNUSED_PARAMETER(dumper);
|
||||
#endif
|
||||
|
||||
// Update past state
|
||||
if (num_beams == 1) {
|
||||
const int k = gpt_subgraph_first_past_input_idx - gpt_subgraph_first_present_output_idx;
|
||||
|
|
@ -662,12 +653,12 @@ Status ExpandBuffer(void* stream,
|
|||
for (int i = 0; i < batch_size; i++) {
|
||||
for (int j = 0; j < num_beams; j++) {
|
||||
CUDA_RETURN_IF_ERROR(
|
||||
cudaMemcpyAsync(
|
||||
target,
|
||||
input_data + i * chunk_size,
|
||||
sizeof(T) * chunk_size,
|
||||
cudaMemcpyDeviceToDevice,
|
||||
cuda_stream));
|
||||
cudaMemcpyAsync(
|
||||
target,
|
||||
input_data + i * chunk_size,
|
||||
sizeof(T) * chunk_size,
|
||||
cudaMemcpyDeviceToDevice,
|
||||
cuda_stream));
|
||||
target += chunk_size;
|
||||
}
|
||||
}
|
||||
|
|
@ -714,12 +705,12 @@ template Status UpdateGptFeeds<float>(
|
|||
std::vector<OrtValue>& next_inputs,
|
||||
int current_length,
|
||||
OrtValue& position_ids,
|
||||
bool increase_position,
|
||||
gsl::span<const int32_t> beam_next_tokens,
|
||||
gsl::span<const int32_t> beam_indices,
|
||||
int num_beams,
|
||||
int gpt_subgraph_first_past_input_idx,
|
||||
int gpt_subgraph_first_present_output_idx,
|
||||
const transformers::IConsoleDumper* dumper);
|
||||
int gpt_subgraph_first_present_output_idx);
|
||||
|
||||
// Float16
|
||||
template void InitBeamState<MLFloat16>(transformers::IBeamSearchState<MLFloat16>* beam_state,
|
||||
|
|
@ -748,12 +739,12 @@ template Status UpdateGptFeeds<MLFloat16>(
|
|||
std::vector<OrtValue>& next_inputs,
|
||||
int current_length,
|
||||
OrtValue& position_ids,
|
||||
bool increase_position,
|
||||
gsl::span<const int32_t> beam_next_tokens,
|
||||
gsl::span<const int32_t> beam_indices,
|
||||
int num_beams,
|
||||
int gpt_subgraph_first_past_input_idx,
|
||||
int gpt_subgraph_first_present_output_idx,
|
||||
const transformers::IConsoleDumper* dumper);
|
||||
int gpt_subgraph_first_present_output_idx);
|
||||
|
||||
template Status UpdateDecoderFeeds<float>(
|
||||
AllocatorPtr allocator,
|
||||
|
|
|
|||
|
|
@ -68,12 +68,12 @@ Status UpdateGptFeeds(
|
|||
std::vector<OrtValue>& next_inputs,
|
||||
int current_length,
|
||||
OrtValue& position_ids,
|
||||
bool increase_position,
|
||||
gsl::span<const int32_t> beam_next_tokens,
|
||||
gsl::span<const int32_t> beam_indices,
|
||||
int num_beams,
|
||||
int gpt_subgraph_first_past_input_idx,
|
||||
int gpt_subgraph_first_present_output_idx,
|
||||
const transformers::IConsoleDumper* dumper);
|
||||
int gpt_subgraph_first_present_output_idx);
|
||||
|
||||
// ---------------------------------------------------------------
|
||||
// Functions for encoder-decoder model like T5
|
||||
|
|
|
|||
|
|
@ -248,9 +248,11 @@ __global__ void UpdateGptInputsKernel(const T* old_mask_data,
|
|||
int j = index % current_length;
|
||||
mask_data[index] = (j < current_length - 1) ? old_mask_data[i * (current_length - 1) + j] : static_cast<T>(1);
|
||||
|
||||
// Update sequence length (or next positions).
|
||||
if (index < batch_beam_size) {
|
||||
next_positions[index]++;
|
||||
if (next_positions != nullptr) {
|
||||
// Update sequence length (or next positions).
|
||||
if (index < batch_beam_size) {
|
||||
next_positions[index]++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue