fixing positions for beam search gpt2 (#12156)

* fixing positions for beam search gpt2
Co-authored-by: Tianlei Wu <tlwu@microsoft.com>
This commit is contained in:
Viswanath Boga 2022-07-14 13:31:59 -07:00 committed by GitHub
parent 9ebef91a6f
commit 05c31a036d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 79 additions and 83 deletions

View file

@ -197,7 +197,6 @@ void InitBeamState(transformers::IBeamSearchState<T>* beam_state,
// T5 does not need position, so next_positions is empty for T5.
if (!beam_state->next_positions.empty()) {
memset(beam_state->next_positions.data(), 0, beam_state->next_positions.size_bytes());
gsl::copy(sequence_lengths, beam_state->next_positions);
}
@ -274,13 +273,13 @@ Status ProcessLogits(const OrtValue& logits, //
// Get scores for candidates of next token: next_token_scores = log_softmax(next_token_logits, dim=-1)
gsl::span<T>& next_token_scores = beam_state->next_token_scores;
ORT_RETURN_IF_ERROR(
SoftmaxCPU<T>(
batch_beam_size, // rows
vocab_size, // elements per row
(input_length == 1 && logits_batch_size == batch_beam_size) ? logits_data : next_token_logits.data(),
next_token_scores.data(),
true,
thread_pool));
SoftmaxCPU<T>(
batch_beam_size, // rows
vocab_size, // elements per row
(input_length == 1 && logits_batch_size == batch_beam_size) ? logits_data : next_token_logits.data(),
next_token_scores.data(),
true,
thread_pool));
#ifdef DEBUG_BEAM_SEARCH
dumper->Print("next_token_scores after softmax", next_token_scores.data(), batch_size, num_beams, vocab_size);
@ -428,12 +427,12 @@ Status UpdateGptFeeds(
std::vector<OrtValue>& next_inputs,
int current_length,
OrtValue& position_ids,
bool increase_position,
gsl::span<const int32_t> beam_next_tokens,
gsl::span<const int32_t> beam_indices,
int num_beams,
int gpt_subgraph_first_past_input_idx,
int gpt_subgraph_first_present_output_idx,
const transformers::IConsoleDumper* dumper) {
int gpt_subgraph_first_present_output_idx) {
// last_outputs: logits, present_0, present_1, ...
// next_inputs: input_ids, position_ids, attention_mask, past_0, past_1, ...
ORT_UNUSED_PARAMETER(stream);
@ -454,10 +453,12 @@ Status UpdateGptFeeds(
}
next_inputs[0] = input_ids;
// Update position IDs
int32_t* position_data = position_ids.GetMutable<Tensor>()->MutableData<int32_t>();
for (int i = 0; i < batch_beam_size; i++) {
position_data[i]++;
if (increase_position) {
// Update position IDs
int32_t* position_data = position_ids.GetMutable<Tensor>()->MutableData<int32_t>();
for (int i = 0; i < batch_beam_size; i++) {
position_data[i]++;
}
}
next_inputs[1] = position_ids;
@ -477,14 +478,6 @@ Status UpdateGptFeeds(
}
next_inputs[2] = attention_mask;
#ifdef DEBUG_BEAM_SEARCH
dumper->Print("input_ids", input_ids);
dumper->Print("position_ids", position_ids);
dumper->Print("attention_mask", attention_mask);
#else
ORT_UNUSED_PARAMETER(dumper);
#endif
// Update past state
if (num_beams == 1) {
// feed present_* output to past_* inputs one by one
@ -725,12 +718,12 @@ template Status UpdateGptFeeds<float>(
std::vector<OrtValue>& next_inputs,
int current_length,
OrtValue& position_ids,
bool increase_position,
gsl::span<const int32_t> beam_next_tokens,
gsl::span<const int32_t> beam_indices,
int num_beams,
int gpt_subgraph_first_past_input_idx,
int gpt_subgraph_first_present_output_idx,
const transformers::IConsoleDumper* dumper);
int gpt_subgraph_first_present_output_idx);
template Status UpdateDecoderFeeds<float>(
AllocatorPtr allocator,
@ -751,28 +744,28 @@ template Status UpdateDecoderFeeds<float>(
template void ExpandInputs<int32_t>(const OrtValue& input, int num_beams, AllocatorPtr allocator, OrtValue& expanded);
template Status ExpandBuffer<int32_t>(
void* stream,
const OrtValue& input,
int num_beams,
AllocatorPtr allocator,
OrtValue& expanded,
bool only_copy_shape);
void* stream,
const OrtValue& input,
int num_beams,
AllocatorPtr allocator,
OrtValue& expanded,
bool only_copy_shape);
template Status ExpandBuffer<float>(
void* stream,
const OrtValue& input,
int num_beams,
AllocatorPtr allocator,
OrtValue& expanded,
bool only_copy_shape);
void* stream,
const OrtValue& input,
int num_beams,
AllocatorPtr allocator,
OrtValue& expanded,
bool only_copy_shape);
template Status ExpandBuffer<MLFloat16>(
void* stream,
const OrtValue& input,
int num_beams,
AllocatorPtr allocator,
OrtValue& expanded,
bool only_copy_shape);
void* stream,
const OrtValue& input,
int num_beams,
AllocatorPtr allocator,
OrtValue& expanded,
bool only_copy_shape);
} // namespace BeamSearchCpuDeviceHelper
} // namespace contrib

View file

@ -96,12 +96,12 @@ using UpdateGptFeedsFunc = std::function<Status(
std::vector<OrtValue>& next_inputs,
int current_length,
OrtValue& position_ids,
bool increase_position,
gsl::span<const int32_t> beam_next_tokens,
gsl::span<const int32_t> beam_indices,
int num_beams,
int gpt_subgraph_first_past_input_idx,
int gpt_subgraph_first_present_output_idx,
const transformers::IConsoleDumper* dumper)>;
int gpt_subgraph_first_present_output_idx)>;
// Create encoder inputs (for encoder-decoder model like T5).
using CreateEncoderInputsFunc = std::function<Status(
@ -142,7 +142,6 @@ using ExpandBufferFunc = std::function<Status(
bool only_copy_shape)>;
} // namespace BeamSearchDeviceHelper
// These are CPU specific device helper implementations
namespace BeamSearchCpuDeviceHelper {
Status TopK(
@ -208,12 +207,12 @@ Status UpdateGptFeeds(
std::vector<OrtValue>& next_inputs,
int current_length,
OrtValue& position_ids,
bool increase_position,
gsl::span<const int32_t> beam_next_tokens,
gsl::span<const int32_t> beam_indices,
int num_beams,
int gpt_subgraph_first_past_input_idx,
int gpt_subgraph_first_present_output_idx,
const transformers::IConsoleDumper* dumper);
int gpt_subgraph_first_present_output_idx);
// ---------------------------------------------------------------
// Functions for encoder-decoder model like T5

View file

@ -56,6 +56,7 @@ class BeamSearchGpt : public BeamSearchBase<T> {
std::vector<OrtValue>& next_inputs,
int current_length,
OrtValue& position_ids,
bool increase_position,
gsl::span<const int32_t> beam_next_tokens,
gsl::span<const int32_t> beam_indices);
@ -93,6 +94,7 @@ Status BeamSearchGpt<T>::UpdateFeeds(
std::vector<OrtValue>& next_inputs,
int current_length,
OrtValue& position_ids,
bool increase_position,
gsl::span<const int32_t> beam_next_tokens,
gsl::span<const int32_t> beam_indices) {
return update_feeds_func_(this->temp_space_allocator_,
@ -101,12 +103,12 @@ Status BeamSearchGpt<T>::UpdateFeeds(
next_inputs,
current_length,
position_ids,
increase_position,
beam_next_tokens,
beam_indices,
this->parameters_->num_beams,
gpt_subgraph_.GetFirstPastInputIndex(),
gpt_subgraph_.GetFirstPresentOutputIndex(),
this->GetConsoleDumper());
gpt_subgraph_.GetFirstPresentOutputIndex());
}
template <typename T>
@ -186,11 +188,7 @@ Status BeamSearchGpt<T>::Execute(const FeedsFetchesManager& feeds_fetches_manage
#ifdef DEBUG_BEAM_SEARCH
const IConsoleDumper* dumper = this->GetConsoleDumper();
dumper->Print("input_ids", feeds[0]);
dumper->Print("position_ids", feeds[1]);
dumper->Print("attention_mask", feeds[2]);
#endif
// Position ids for all iterations except the first. It uses the memory buffer owned by next_positions.
OrtValue position_ids;
int64_t dims[] = {parameters->BatchBeamSize(), 1};
@ -205,9 +203,19 @@ Status BeamSearchGpt<T>::Execute(const FeedsFetchesManager& feeds_fetches_manage
int iteration_counter = 0;
while (current_length < parameters->max_length) {
iteration_counter++;
#ifdef DEBUG_BEAM_SEARCH
auto cur_len = std::to_string(current_length);
dumper->Print("***CurrentLength", cur_len, true);
dumper->Print("iteration", iteration_counter, true);
dumper->Print("input_ids", feeds[0]);
dumper->Print("position_ids", feeds[1]);
dumper->Print("attention_mask", feeds[2]);
for (size_t i = 3; i < feeds.size(); i++) {
dumper->Print("past", static_cast<int>(i) - 3, true);
dumper->Print("", feeds[i]);
}
#endif
status = utils::ExecuteSubgraph(this->decoder_session_state_,
@ -241,8 +249,11 @@ Status BeamSearchGpt<T>::Execute(const FeedsFetchesManager& feeds_fetches_manage
// Prepare inputs for next round of subgraph call.
if (current_length < parameters->max_length) {
// For the first iteration, position_ids is initialized as sequence lengths. We can add it to feeds directly.
// For the remaining iterations, we need to increase position_ids first, then add it to feeds.
bool increase_position = (iteration_counter > 1);
ORT_RETURN_IF_ERROR(UpdateFeeds(fetches, feeds, current_length,
position_ids,
position_ids, increase_position,
beam_next_tokens.as_span<const int32_t>(),
beam_indices.as_span<const int32_t>()));
}

View file

@ -258,9 +258,8 @@ Status ProcessLogits(const OrtValue& logits, //
// The output is float, both for precision and for easy integration with the remaining parts.
float* Y_data = next_token_scores.data();
const CudaT* X_data = (input_length == 1 && logits_batch_size == batch_beam_size) ?
logits_data :
reinterpret_cast<const CudaT*>(next_token_logits.data());
bool is_single_token = (input_length == 1 && logits_batch_size == batch_beam_size);
const CudaT* X_data = is_single_token ? logits_data : reinterpret_cast<const CudaT*>(next_token_logits.data());
dispatch_blockwise_softmax_forward<CudaT, float, float, true>(
cuda_stream, Y_data, X_data, vocab_size, vocab_size, batch_size * num_beams);
@ -500,12 +499,12 @@ Status UpdateGptFeeds(
std::vector<OrtValue>& next_inputs,
int current_length,
OrtValue& position_ids,
bool increase_position,
gsl::span<const int32_t> beam_next_tokens,
gsl::span<const int32_t> beam_indices,
int num_beams,
int gpt_subgraph_first_past_input_idx,
int gpt_subgraph_first_present_output_idx,
const transformers::IConsoleDumper* dumper) {
int gpt_subgraph_first_present_output_idx) {
// Update input_ids with next tokens.
int batch_beam_size = static_cast<int>(beam_next_tokens.length());
int64_t dims[] = {batch_beam_size, 1};
@ -519,7 +518,7 @@ Status UpdateGptFeeds(
next_inputs[0] = input_ids;
// Update position IDs
int32_t* position_data = position_ids.GetMutable<Tensor>()->MutableData<int32_t>();
int32_t* position_data = increase_position ? position_ids.GetMutable<Tensor>()->MutableData<int32_t>() : nullptr;
next_inputs[1] = position_ids;
// Update attention mask
@ -538,14 +537,6 @@ Status UpdateGptFeeds(
next_inputs[2] = attention_mask;
#ifdef DEBUG_BEAM_SEARCH
dumper->Print("input_ids", input_ids);
dumper->Print("position_ids", position_ids);
dumper->Print("attention_mask", attention_mask);
#else
ORT_UNUSED_PARAMETER(dumper);
#endif
// Update past state
if (num_beams == 1) {
const int k = gpt_subgraph_first_past_input_idx - gpt_subgraph_first_present_output_idx;
@ -662,12 +653,12 @@ Status ExpandBuffer(void* stream,
for (int i = 0; i < batch_size; i++) {
for (int j = 0; j < num_beams; j++) {
CUDA_RETURN_IF_ERROR(
cudaMemcpyAsync(
target,
input_data + i * chunk_size,
sizeof(T) * chunk_size,
cudaMemcpyDeviceToDevice,
cuda_stream));
cudaMemcpyAsync(
target,
input_data + i * chunk_size,
sizeof(T) * chunk_size,
cudaMemcpyDeviceToDevice,
cuda_stream));
target += chunk_size;
}
}
@ -714,12 +705,12 @@ template Status UpdateGptFeeds<float>(
std::vector<OrtValue>& next_inputs,
int current_length,
OrtValue& position_ids,
bool increase_position,
gsl::span<const int32_t> beam_next_tokens,
gsl::span<const int32_t> beam_indices,
int num_beams,
int gpt_subgraph_first_past_input_idx,
int gpt_subgraph_first_present_output_idx,
const transformers::IConsoleDumper* dumper);
int gpt_subgraph_first_present_output_idx);
// Float16
template void InitBeamState<MLFloat16>(transformers::IBeamSearchState<MLFloat16>* beam_state,
@ -748,12 +739,12 @@ template Status UpdateGptFeeds<MLFloat16>(
std::vector<OrtValue>& next_inputs,
int current_length,
OrtValue& position_ids,
bool increase_position,
gsl::span<const int32_t> beam_next_tokens,
gsl::span<const int32_t> beam_indices,
int num_beams,
int gpt_subgraph_first_past_input_idx,
int gpt_subgraph_first_present_output_idx,
const transformers::IConsoleDumper* dumper);
int gpt_subgraph_first_present_output_idx);
template Status UpdateDecoderFeeds<float>(
AllocatorPtr allocator,

View file

@ -68,12 +68,12 @@ Status UpdateGptFeeds(
std::vector<OrtValue>& next_inputs,
int current_length,
OrtValue& position_ids,
bool increase_position,
gsl::span<const int32_t> beam_next_tokens,
gsl::span<const int32_t> beam_indices,
int num_beams,
int gpt_subgraph_first_past_input_idx,
int gpt_subgraph_first_present_output_idx,
const transformers::IConsoleDumper* dumper);
int gpt_subgraph_first_present_output_idx);
// ---------------------------------------------------------------
// Functions for encoder-decoder model like T5

View file

@ -248,9 +248,11 @@ __global__ void UpdateGptInputsKernel(const T* old_mask_data,
int j = index % current_length;
mask_data[index] = (j < current_length - 1) ? old_mask_data[i * (current_length - 1) + j] : static_cast<T>(1);
// Update sequence length (or next positions).
if (index < batch_beam_size) {
next_positions[index]++;
if (next_positions != nullptr) {
// Update sequence length (or next positions).
if (index < batch_beam_size) {
next_positions[index]++;
}
}
}
}