From 7a9a6bcebd2fe987de5cb220efa0d377f1ab68f6 Mon Sep 17 00:00:00 2001 From: Yufeng Li Date: Wed, 11 Jan 2023 08:40:17 -0800 Subject: [PATCH] Improve TopP sampling (#14192) ### Description Improve TopP sampling's filter kernel with cub::scan. It reduces TopP sampling latency from 3.67 to 0.92 for batch size 8 and vocabulary size 51k. --- .../cuda/transformers/generation_cuda_impl.cu | 245 +++++++++--------- 1 file changed, 119 insertions(+), 126 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu index f35ef8a40b..bf98394d1c 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu @@ -8,7 +8,6 @@ #include #include "contrib_ops/cuda/transformers/generation_cuda_impl.h" - namespace onnxruntime { namespace contrib { namespace cuda { @@ -312,14 +311,14 @@ void LaunchUpdateGptKernel(const int32_t* old_mask_data, } template -void GetTempStorageSize(const T *d_keys_in, - const int* d_values_in, - int* d_offsets, - int num_items, - int num_segments, - cudaStream_t stream, - bool is_descending, - size_t& temp_storage_bytes) { +void GetTempStorageSize(const T* d_keys_in, + const int* d_values_in, + int* d_offsets, + int num_items, + int num_segments, + cudaStream_t stream, + bool is_descending, + size_t& temp_storage_bytes) { if (is_descending) { cub::DeviceSegmentedRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, @@ -352,24 +351,24 @@ void GetTempStorageSize(const T *d_keys_in, } template void GetTempStorageSize( - const float *d_keys_in, - const int* d_values_in, - int* d_offsets, - int num_items, - int num_segments, - cudaStream_t stream, - bool is_descending, - size_t& temp_storage_bytes); + const float* d_keys_in, + const int* d_values_in, + int* d_offsets, + int num_items, + int num_segments, + cudaStream_t stream, + bool is_descending, + size_t& temp_storage_bytes); template void GetTempStorageSize( - const half *d_keys_in, - const int* d_values_in, - int* d_offsets, - int num_items, - int num_segments, - cudaStream_t stream, - bool is_descending, - size_t& temp_storage_bytes); + const half* d_keys_in, + const int* d_values_in, + int* d_offsets, + int num_items, + int num_segments, + cudaStream_t stream, + bool is_descending, + size_t& temp_storage_bytes); // TODO: merge to one kernel __global__ void SetupParamsKernel(int* d_values_in, @@ -401,17 +400,17 @@ void LaunchSetupParamsKernel(int* d_values_in, } template -void LaunchSortPairs(void *d_temp_storage, - size_t temp_storage_bytes, - const T *d_keys_in, - T *d_keys_out, - const int *d_values_in, - int *d_values_out, - int num_items, - int num_segments, - int *d_offsets, - cudaStream_t stream, - bool is_descending) { +void LaunchSortPairs(void* d_temp_storage, + size_t temp_storage_bytes, + const T* d_keys_in, + T* d_keys_out, + const int* d_values_in, + int* d_values_out, + int num_items, + int num_segments, + int* d_offsets, + cudaStream_t stream, + bool is_descending) { if (is_descending) { cub::DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, @@ -443,31 +442,46 @@ void LaunchSortPairs(void *d_temp_storage, } } -template void LaunchSortPairs(void *d_temp_storage, +template void LaunchSortPairs(void* d_temp_storage, size_t temp_storage_bytes, - const float *d_keys_in, - float *d_keys_out, - const int *d_values_in, - int *d_values_out, + const float* d_keys_in, + float* d_keys_out, + const int* d_values_in, + int* d_values_out, int num_items, int num_segments, - int *d_offsets, + int* d_offsets, cudaStream_t stream, bool is_descending); -template void LaunchSortPairs(void *d_temp_storage, +template void LaunchSortPairs(void* d_temp_storage, size_t temp_storage_bytes, - const half *d_keys_in, - half *d_keys_out, - const int *d_values_in, - int *d_values_out, + const half* d_keys_in, + half* d_keys_out, + const int* d_values_in, + int* d_values_out, int num_items, int num_segments, - int *d_offsets, + int* d_offsets, cudaStream_t stream, bool is_descending); -template +// A stateful callback functor that maintains a running prefix to be applied +// during consecutive scan operations. +struct BlockPrefixCallbackOp { + float running_total; // running prefix + + __device__ BlockPrefixCallbackOp(float running_total) : running_total(running_total) {} + // Callback operator to be entered by the first warp of threads in the block. + // Thread-0 is responsible for returning a value for seeding the block-wide scan. + __device__ float operator()(float block_aggregate) { + float old_prefix = running_total; + running_total += block_aggregate; + return old_prefix; + } +}; + +template __global__ void FilterLogitsKernelCustom(float* d_sorted_logits_in, const int* d_sorted_indices, T* d_logits_in_out, @@ -475,35 +489,27 @@ __global__ void FilterLogitsKernelCustom(float* d_sorted_logits_in, float filter_value, int batch_size, int vocab_size) { - int index = blockIdx.x * blockDim.x + threadIdx.x; + int vocab_idx = threadIdx.x; + int batch_id = blockIdx.x; + int offset = batch_id * vocab_size; - if (index >= batch_size * vocab_size) { - return; - } + typedef cub::BlockScan BlockScan; + __shared__ typename BlockScan::TempStorage temp_storage; + BlockPrefixCallbackOp prefix_op(0); - int vocab_idx = index % vocab_size; - int batch_id = index / vocab_size; - int start_index = batch_id * vocab_size; + for (int idx = vocab_idx; idx < vocab_size; idx += kBlockSize) { + float sum = d_sorted_logits_in[offset + idx]; + BlockScan(temp_storage).ExclusiveSum(sum, sum, prefix_op); - int count = vocab_idx; - float sum = 0.0f; - while (count >= 0) { - sum += d_sorted_logits_in[start_index]; - ++start_index; - --count; - } - - if (sum > top_p_threshold) { - // Shift the indices to the right by one according to the custom implementation. - int shifted_index = index + 1; - if (shifted_index % vocab_size != 0) { - int original_index = batch_id * vocab_size + d_sorted_indices[shifted_index]; + __syncthreads(); + if (sum >= top_p_threshold) { + int original_index = offset + d_sorted_indices[offset + idx]; d_logits_in_out[original_index] = (T)filter_value; } } } -template +template __global__ void FilterLogitsKernel(float* d_sorted_logits_in, const int* d_sorted_indices, T* d_logits_in_out, @@ -512,29 +518,25 @@ __global__ void FilterLogitsKernel(float* d_sorted_logits_in, int min_tokens_to_keep, int batch_size, int vocab_size) { - int index = blockIdx.x * blockDim.x + threadIdx.x; + int vocab_idx = threadIdx.x; + int batch_id = blockIdx.x; + int offset = batch_id * vocab_size; - if (index >= batch_size * vocab_size) { - return; - } + typedef cub::BlockScan BlockScan; + __shared__ typename BlockScan::TempStorage temp_storage; + BlockPrefixCallbackOp prefix_op(0); - int vocab_idx = index % vocab_size; - int batch_id = index / vocab_size; - int start_index = batch_id * vocab_size; + for (int idx = vocab_idx; idx < vocab_size; idx += kBlockSize) { + float sum = d_sorted_logits_in[offset + idx]; + BlockScan(temp_storage).InclusiveSum(sum, sum, prefix_op); - int count = vocab_idx; - float sum = 0.0f; - // TODO: Optimization needed. e.g. use CUB::SCAN() for cumulative probabilities. - while (count >= 0) { - sum += d_sorted_logits_in[start_index]; - ++start_index; - --count; - } + __syncthreads(); - if (sum <= top_p_threshold) { - if (index % vocab_size + min_tokens_to_keep < vocab_size) { - int original_index = batch_id * vocab_size + d_sorted_indices[index]; - d_logits_in_out[original_index] = (T)filter_value; + if (sum <= top_p_threshold) { + if (idx + min_tokens_to_keep < vocab_size) { + int original_index = offset + d_sorted_indices[offset + idx]; + d_logits_in_out[original_index] = (T)filter_value; + } } } } @@ -550,26 +552,25 @@ void LaunchFilterLogitsKernel(float* d_sorted_logits_in, int vocab_size, cudaStream_t stream, bool is_descending) { - int total_elements = batch_size * vocab_size; - constexpr int blockSize = 256; - const int gridSize = (total_elements + blockSize - 1) / blockSize; + constexpr int kBlockSize = 256; + if (is_descending) { - FilterLogitsKernelCustom<<>>(d_sorted_logits_in, - d_sorted_indices, - d_logits_in_out, - top_p, - filter_value, - batch_size, - vocab_size); + FilterLogitsKernelCustom<<>>(d_sorted_logits_in, + d_sorted_indices, + d_logits_in_out, + top_p, + filter_value, + batch_size, + vocab_size); } else { - FilterLogitsKernel<<>>(d_sorted_logits_in, - d_sorted_indices, - d_logits_in_out, - 1 - top_p, - filter_value, - min_tokens_to_keep, - batch_size, - vocab_size); + FilterLogitsKernel<<>>(d_sorted_logits_in, + d_sorted_indices, + d_logits_in_out, + 1 - top_p, + filter_value, + min_tokens_to_keep, + batch_size, + vocab_size); } } @@ -595,7 +596,6 @@ template void LaunchFilterLogitsKernel(float* d_sorted_logits_in, cudaStream_t stream, bool is_descending); - // Ref: https://github.com/pytorch/pytorch/blob/release/1.13/aten/src/ATen/native/cuda/MultinomialKernel.cu template __global__ void sampleMultinomialOnce(int64_t* dest, @@ -603,18 +603,17 @@ __global__ void sampleMultinomialOnce(int64_t* dest, int categories, scalar_t* sampled, scalar_t* dist, - int stride_dist, // dist->stride(0) - int stride_categories, // dist->stride(1) + int stride_dist, // dist->stride(0) + int stride_categories, // dist->stride(1) int* d_presence_mask) { - extern __shared__ unsigned char my_smem[]; + extern __shared__ unsigned char my_smem[]; __shared__ bool found; __shared__ unsigned foundPos; - accscalar_t *smem = reinterpret_cast(my_smem); + accscalar_t* smem = reinterpret_cast(my_smem); accscalar_t accZero = static_cast(0); scalar_t zero = static_cast(0); for (int curDist = blockIdx.x; curDist < distributions; curDist += gridDim.x) { - // Assume sum = 1 in Top P sampling as the input is softmaxed. accscalar_t sum = 1; @@ -644,9 +643,7 @@ __global__ void sampleMultinomialOnce(int64_t* dest, for (int chunk = 0; chunk < chunks && !found; ++chunk) { // All threads in bounds load a value int cat = chunk * blockDim.x + threadIdx.x; - accscalar_t dist_val = cat < categories ? - static_cast(dist[curDist * stride_dist + cat * stride_categories]) / sum : - accZero; + accscalar_t dist_val = cat < categories ? static_cast(dist[curDist * stride_dist + cat * stride_categories]) / sum : accZero; smem[threadIdx.x] = dist_val; __syncthreads(); // Perform an inclusive prefix sum of the shared memory contents @@ -666,12 +663,12 @@ __global__ void sampleMultinomialOnce(int64_t* dest, static_cast(smem[threadIdx.x] + prevHighProb); scalar_t prevBucket = static_cast( threadIdx.x == 0 ? prevHighProb - : smem[threadIdx.x - 1] + prevHighProb); + : smem[threadIdx.x - 1] + prevHighProb); bool inBucket = (cat < categories) && (!(sample >= curBucket) && - (sample >= prevBucket) && - (dist_val > zero)); + (sample >= prevBucket) && + (dist_val > zero)); if (inBucket) { // We're done; we have the sample // Torch indices are 1-based @@ -684,7 +681,7 @@ __global__ void sampleMultinomialOnce(int64_t* dest, } if (threadIdx.x == 0) { if (found) { - dest[curDist] = foundPos; + dest[curDist] = foundPos; } else { // This should address a rare bug where we don't select a valid index. This likely occurs when // due to floating point arithmetic rounding errors, our cumulative sum does not add up to 1, but @@ -721,8 +718,7 @@ void TorchMultinomialKernelLauncher(float* d_input, int batch_size, int vocab_size, int* d_presence_mask, - cudaStream_t stream) -{ + cudaStream_t stream) { // Store the props in class variables int device; cudaGetDevice(&device); @@ -731,7 +727,7 @@ void TorchMultinomialKernelLauncher(float* d_input, int numSM = props.multiProcessorCount; int maxThreads = props.maxThreadsPerBlock; - int warp_size = 32; //at::cuda::warp_size(); + int warp_size = 32; // at::cuda::warp_size(); int requiredWarps = (vocab_size + warp_size - 1) / warp_size; int requiredThreads = std::min(maxThreads, requiredWarps * warp_size); int requiredShared = requiredThreads * sizeof(float); @@ -748,11 +744,8 @@ void TorchMultinomialKernelLauncher(float* d_input, vocab_size, 1, d_presence_mask); - } - - } // namespace cuda } // namespace contrib } // namespace onnxruntime