Improve TopP sampling (#14192)

### Description
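Improve the TopP sampling filter kernels by computing the cumulative sum with
cub::BlockScan instead of a serial per-thread loop. This reduces TopP sampling
latency from 3.67 to 0.92 (time unit not stated) for batch size 8 and a
vocabulary size of 51k.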
This commit is contained in:
Yufeng Li 2023-01-11 08:40:17 -08:00 committed by GitHub
parent d92c663f28
commit 7a9a6bcebd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -8,7 +8,6 @@
#include <cub/device/device_segmented_radix_sort.cuh>
#include "contrib_ops/cuda/transformers/generation_cuda_impl.h"
namespace onnxruntime {
namespace contrib {
namespace cuda {
@ -312,14 +311,14 @@ void LaunchUpdateGptKernel(const int32_t* old_mask_data,
}
template <typename T>
void GetTempStorageSize(const T *d_keys_in,
const int* d_values_in,
int* d_offsets,
int num_items,
int num_segments,
cudaStream_t stream,
bool is_descending,
size_t& temp_storage_bytes) {
void GetTempStorageSize(const T* d_keys_in,
const int* d_values_in,
int* d_offsets,
int num_items,
int num_segments,
cudaStream_t stream,
bool is_descending,
size_t& temp_storage_bytes) {
if (is_descending) {
cub::DeviceSegmentedRadixSort::SortPairsDescending(nullptr,
temp_storage_bytes,
@ -352,24 +351,24 @@ void GetTempStorageSize(const T *d_keys_in,
}
template void GetTempStorageSize(
const float *d_keys_in,
const int* d_values_in,
int* d_offsets,
int num_items,
int num_segments,
cudaStream_t stream,
bool is_descending,
size_t& temp_storage_bytes);
const float* d_keys_in,
const int* d_values_in,
int* d_offsets,
int num_items,
int num_segments,
cudaStream_t stream,
bool is_descending,
size_t& temp_storage_bytes);
template void GetTempStorageSize(
const half *d_keys_in,
const int* d_values_in,
int* d_offsets,
int num_items,
int num_segments,
cudaStream_t stream,
bool is_descending,
size_t& temp_storage_bytes);
const half* d_keys_in,
const int* d_values_in,
int* d_offsets,
int num_items,
int num_segments,
cudaStream_t stream,
bool is_descending,
size_t& temp_storage_bytes);
// TODO: merge to one kernel
__global__ void SetupParamsKernel(int* d_values_in,
@ -401,17 +400,17 @@ void LaunchSetupParamsKernel(int* d_values_in,
}
template <typename T>
void LaunchSortPairs(void *d_temp_storage,
size_t temp_storage_bytes,
const T *d_keys_in,
T *d_keys_out,
const int *d_values_in,
int *d_values_out,
int num_items,
int num_segments,
int *d_offsets,
cudaStream_t stream,
bool is_descending) {
void LaunchSortPairs(void* d_temp_storage,
size_t temp_storage_bytes,
const T* d_keys_in,
T* d_keys_out,
const int* d_values_in,
int* d_values_out,
int num_items,
int num_segments,
int* d_offsets,
cudaStream_t stream,
bool is_descending) {
if (is_descending) {
cub::DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage,
temp_storage_bytes,
@ -443,31 +442,46 @@ void LaunchSortPairs(void *d_temp_storage,
}
}
template void LaunchSortPairs(void *d_temp_storage,
template void LaunchSortPairs(void* d_temp_storage,
size_t temp_storage_bytes,
const float *d_keys_in,
float *d_keys_out,
const int *d_values_in,
int *d_values_out,
const float* d_keys_in,
float* d_keys_out,
const int* d_values_in,
int* d_values_out,
int num_items,
int num_segments,
int *d_offsets,
int* d_offsets,
cudaStream_t stream,
bool is_descending);
template void LaunchSortPairs(void *d_temp_storage,
template void LaunchSortPairs(void* d_temp_storage,
size_t temp_storage_bytes,
const half *d_keys_in,
half *d_keys_out,
const int *d_values_in,
int *d_values_out,
const half* d_keys_in,
half* d_keys_out,
const int* d_values_in,
int* d_values_out,
int num_items,
int num_segments,
int *d_offsets,
int* d_offsets,
cudaStream_t stream,
bool is_descending);
template <typename T>
// A stateful callback functor that maintains a running prefix to be applied
// during consecutive scan operations: it remembers the sum of every tile
// scanned so far and hands that value to the next scan as its seed.
struct BlockPrefixCallbackOp {
  float running_total;  // sum of the aggregates of all tiles processed so far

  // Starts the running prefix at `running_total`.
  // `explicit` keeps a bare number from being silently converted into a functor;
  // `__host__ __device__` lets the functor also be exercised from host code.
  __host__ __device__ explicit BlockPrefixCallbackOp(float running_total) : running_total(running_total) {}

  // Callback operator to be entered by the first warp of threads in the block.
  // Thread-0 is responsible for returning a value for seeding the block-wide scan.
  //
  // block_aggregate: sum of the tile currently being scanned.
  // Returns the total accumulated before this tile, then folds the tile's
  // aggregate into the running total so the next call sees it.
  __host__ __device__ float operator()(float block_aggregate) {
    const float old_prefix = running_total;
    running_total += block_aggregate;
    return old_prefix;
  }
};
// Writes filter_value into d_logits_in_out, at positions looked up through
// d_sorted_indices, for entries whose running sum of d_sorted_logits_in passes
// top_p_threshold.
// NOTE(review): this listing shows both the serial-loop statements and the
// BlockScan statements of the kernel body (vocab_idx, batch_id and sum are each
// declared twice below) - read the two variants separately.
template <typename T, int kBlockSize>
__global__ void FilterLogitsKernelCustom(float* d_sorted_logits_in,
const int* d_sorted_indices,
T* d_logits_in_out,
@ -475,35 +489,27 @@ __global__ void FilterLogitsKernelCustom(float* d_sorted_logits_in,
float filter_value,
int batch_size,
int vocab_size) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
// BlockScan variant: blockIdx.x selects the row, threadIdx.x the starting column;
// offset is the flat index of the first element of that row.
int vocab_idx = threadIdx.x;
int batch_id = blockIdx.x;
int offset = batch_id * vocab_size;
if (index >= batch_size * vocab_size) {
return;
}
// Block-wide scan over float values for kBlockSize threads; temp_storage is the
// shared-memory scratch space it works in.
typedef cub::BlockScan<float, kBlockSize> BlockScan;
__shared__ typename BlockScan::TempStorage temp_storage;
// Carries the total of all previously scanned tiles into the next scan; starts at 0.
BlockPrefixCallbackOp prefix_op(0);
int vocab_idx = index % vocab_size;
int batch_id = index / vocab_size;
int start_index = batch_id * vocab_size;
// Each thread visits columns threadIdx.x, threadIdx.x + kBlockSize, ... of its row.
// NOTE(review): BlockScan and __syncthreads() are block-wide collectives, but when
// vocab_size is not a multiple of kBlockSize the threads with idx >= vocab_size skip
// the final iteration while the rest still execute it - confirm this is safe, or
// iterate a uniform number of tiles and guard only the load/store.
for (int idx = vocab_idx; idx < vocab_size; idx += kBlockSize) {
float sum = d_sorted_logits_in[offset + idx];
// Exclusive scan: sum becomes the total of all elements before idx in this row.
BlockScan(temp_storage).ExclusiveSum(sum, sum, prefix_op);
int count = vocab_idx;
float sum = 0.0f;
while (count >= 0) {
sum += d_sorted_logits_in[start_index];
++start_index;
--count;
}
if (sum > top_p_threshold) {
// Shift the indices to the right by one according to the custom implementation.
int shifted_index = index + 1;
if (shifted_index % vocab_size != 0) {
int original_index = batch_id * vocab_size + d_sorted_indices[shifted_index];
// Barrier between iterations; presumably protects temp_storage before the next
// scan reuses it.
__syncthreads();
if (sum >= top_p_threshold) {
// d_sorted_indices holds the within-row position of the sorted element.
int original_index = offset + d_sorted_indices[offset + idx];
d_logits_in_out[original_index] = (T)filter_value;
}
}
}
template <typename T>
// Writes filter_value into d_logits_in_out, at positions looked up through
// d_sorted_indices, for entries whose inclusive running sum of d_sorted_logits_in
// is <= top_p_threshold, except for the last min_tokens_to_keep positions of a row.
// NOTE(review): this listing shows both the serial-loop statements and the
// BlockScan statements of the kernel body (vocab_idx, batch_id and sum are each
// declared twice below) - read the two variants separately.
template <typename T, int kBlockSize>
__global__ void FilterLogitsKernel(float* d_sorted_logits_in,
const int* d_sorted_indices,
T* d_logits_in_out,
@ -512,29 +518,25 @@ __global__ void FilterLogitsKernel(float* d_sorted_logits_in,
int min_tokens_to_keep,
int batch_size,
int vocab_size) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
// BlockScan variant: blockIdx.x selects the row, threadIdx.x the starting column;
// offset is the flat index of the first element of that row.
int vocab_idx = threadIdx.x;
int batch_id = blockIdx.x;
int offset = batch_id * vocab_size;
if (index >= batch_size * vocab_size) {
return;
}
// Block-wide scan over float values for kBlockSize threads; temp_storage is the
// shared-memory scratch space it works in.
typedef cub::BlockScan<float, kBlockSize> BlockScan;
__shared__ typename BlockScan::TempStorage temp_storage;
// Carries the total of all previously scanned tiles into the next scan; starts at 0.
BlockPrefixCallbackOp prefix_op(0);
int vocab_idx = index % vocab_size;
int batch_id = index / vocab_size;
int start_index = batch_id * vocab_size;
// Each thread visits columns threadIdx.x, threadIdx.x + kBlockSize, ... of its row.
// NOTE(review): BlockScan and __syncthreads() are block-wide collectives, but when
// vocab_size is not a multiple of kBlockSize the threads with idx >= vocab_size skip
// the final iteration while the rest still execute it - confirm this is safe, or
// iterate a uniform number of tiles and guard only the load/store.
for (int idx = vocab_idx; idx < vocab_size; idx += kBlockSize) {
float sum = d_sorted_logits_in[offset + idx];
// Inclusive scan: sum becomes the total of all elements up to and including idx.
BlockScan(temp_storage).InclusiveSum(sum, sum, prefix_op);
int count = vocab_idx;
float sum = 0.0f;
// TODO: Optimization needed. e.g. use CUB::SCAN() for cumulative probabilities.
while (count >= 0) {
sum += d_sorted_logits_in[start_index];
++start_index;
--count;
}
// Barrier between iterations; presumably protects temp_storage before the next
// scan reuses it.
__syncthreads();
if (sum <= top_p_threshold) {
if (index % vocab_size + min_tokens_to_keep < vocab_size) {
int original_index = batch_id * vocab_size + d_sorted_indices[index];
d_logits_in_out[original_index] = (T)filter_value;
if (sum <= top_p_threshold) {
// The last min_tokens_to_keep positions of the row are never overwritten.
if (idx + min_tokens_to_keep < vocab_size) {
// d_sorted_indices holds the within-row position of the sorted element.
int original_index = offset + d_sorted_indices[offset + idx];
d_logits_in_out[original_index] = (T)filter_value;
}
}
}
}
@ -550,26 +552,25 @@ void LaunchFilterLogitsKernel(float* d_sorted_logits_in,
int vocab_size,
cudaStream_t stream,
bool is_descending) {
int total_elements = batch_size * vocab_size;
constexpr int blockSize = 256;
const int gridSize = (total_elements + blockSize - 1) / blockSize;
constexpr int kBlockSize = 256;
if (is_descending) {
FilterLogitsKernelCustom<<<gridSize, blockSize, 0, stream>>>(d_sorted_logits_in,
d_sorted_indices,
d_logits_in_out,
top_p,
filter_value,
batch_size,
vocab_size);
FilterLogitsKernelCustom<T, kBlockSize><<<batch_size, kBlockSize, 0, stream>>>(d_sorted_logits_in,
d_sorted_indices,
d_logits_in_out,
top_p,
filter_value,
batch_size,
vocab_size);
} else {
FilterLogitsKernel<<<gridSize, blockSize, 0, stream>>>(d_sorted_logits_in,
d_sorted_indices,
d_logits_in_out,
1 - top_p,
filter_value,
min_tokens_to_keep,
batch_size,
vocab_size);
FilterLogitsKernel<T, kBlockSize><<<batch_size, kBlockSize, 0, stream>>>(d_sorted_logits_in,
d_sorted_indices,
d_logits_in_out,
1 - top_p,
filter_value,
min_tokens_to_keep,
batch_size,
vocab_size);
}
}
@ -595,7 +596,6 @@ template void LaunchFilterLogitsKernel(float* d_sorted_logits_in,
cudaStream_t stream,
bool is_descending);
// Ref: https://github.com/pytorch/pytorch/blob/release/1.13/aten/src/ATen/native/cuda/MultinomialKernel.cu
template <typename scalar_t, typename accscalar_t>
__global__ void sampleMultinomialOnce(int64_t* dest,
@ -603,18 +603,17 @@ __global__ void sampleMultinomialOnce(int64_t* dest,
int categories,
scalar_t* sampled,
scalar_t* dist,
int stride_dist, // dist->stride(0)
int stride_categories, // dist->stride(1)
int stride_dist, // dist->stride(0)
int stride_categories, // dist->stride(1)
int* d_presence_mask) {
extern __shared__ unsigned char my_smem[];
extern __shared__ unsigned char my_smem[];
__shared__ bool found;
__shared__ unsigned foundPos;
accscalar_t *smem = reinterpret_cast<accscalar_t *>(my_smem);
accscalar_t* smem = reinterpret_cast<accscalar_t*>(my_smem);
accscalar_t accZero = static_cast<accscalar_t>(0);
scalar_t zero = static_cast<scalar_t>(0);
for (int curDist = blockIdx.x;
curDist < distributions; curDist += gridDim.x) {
// Assume sum = 1 in Top P sampling as the input is softmaxed.
accscalar_t sum = 1;
@ -644,9 +643,7 @@ __global__ void sampleMultinomialOnce(int64_t* dest,
for (int chunk = 0; chunk < chunks && !found; ++chunk) {
// All threads in bounds load a value
int cat = chunk * blockDim.x + threadIdx.x;
accscalar_t dist_val = cat < categories ?
static_cast<accscalar_t>(dist[curDist * stride_dist + cat * stride_categories]) / sum :
accZero;
accscalar_t dist_val = cat < categories ? static_cast<accscalar_t>(dist[curDist * stride_dist + cat * stride_categories]) / sum : accZero;
smem[threadIdx.x] = dist_val;
__syncthreads();
// Perform an inclusive prefix sum of the shared memory contents
@ -666,12 +663,12 @@ __global__ void sampleMultinomialOnce(int64_t* dest,
static_cast<scalar_t>(smem[threadIdx.x] + prevHighProb);
scalar_t prevBucket = static_cast<scalar_t>(
threadIdx.x == 0 ? prevHighProb
: smem[threadIdx.x - 1] + prevHighProb);
: smem[threadIdx.x - 1] + prevHighProb);
bool inBucket =
(cat < categories) &&
(!(sample >= curBucket) &&
(sample >= prevBucket) &&
(dist_val > zero));
(sample >= prevBucket) &&
(dist_val > zero));
if (inBucket) {
// We're done; we have the sample
// Torch indices are 1-based
@ -684,7 +681,7 @@ __global__ void sampleMultinomialOnce(int64_t* dest,
}
if (threadIdx.x == 0) {
if (found) {
dest[curDist] = foundPos;
dest[curDist] = foundPos;
} else {
// This should address a rare bug where we don't select a valid index. This likely occurs when
// due to floating point arithmetic rounding errors, our cumulative sum does not add up to 1, but
@ -721,8 +718,7 @@ void TorchMultinomialKernelLauncher(float* d_input,
int batch_size,
int vocab_size,
int* d_presence_mask,
cudaStream_t stream)
{
cudaStream_t stream) {
// Store the props in class variables
int device;
cudaGetDevice(&device);
@ -731,7 +727,7 @@ void TorchMultinomialKernelLauncher(float* d_input,
int numSM = props.multiProcessorCount;
int maxThreads = props.maxThreadsPerBlock;
int warp_size = 32; //at::cuda::warp_size();
int warp_size = 32; // at::cuda::warp_size();
int requiredWarps = (vocab_size + warp_size - 1) / warp_size;
int requiredThreads = std::min(maxThreads, requiredWarps * warp_size);
int requiredShared = requiredThreads * sizeof(float);
@ -748,11 +744,8 @@ void TorchMultinomialKernelLauncher(float* d_input,
vocab_size,
1,
d_presence_mask);
}
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime