mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-01 23:30:35 +00:00
Improve TopP sampling (#14192)
### Description Improve TopP sampling's filter kernel with cub::scan. It reduces TopP sampling latency from 3.67 to 0.92 for batch size 8 and vocabulary size 51k.
This commit is contained in:
parent
d92c663f28
commit
7a9a6bcebd
1 changed files with 119 additions and 126 deletions
|
|
@ -8,7 +8,6 @@
|
|||
#include <cub/device/device_segmented_radix_sort.cuh>
|
||||
#include "contrib_ops/cuda/transformers/generation_cuda_impl.h"
|
||||
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace cuda {
|
||||
|
|
@ -312,14 +311,14 @@ void LaunchUpdateGptKernel(const int32_t* old_mask_data,
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void GetTempStorageSize(const T *d_keys_in,
|
||||
const int* d_values_in,
|
||||
int* d_offsets,
|
||||
int num_items,
|
||||
int num_segments,
|
||||
cudaStream_t stream,
|
||||
bool is_descending,
|
||||
size_t& temp_storage_bytes) {
|
||||
void GetTempStorageSize(const T* d_keys_in,
|
||||
const int* d_values_in,
|
||||
int* d_offsets,
|
||||
int num_items,
|
||||
int num_segments,
|
||||
cudaStream_t stream,
|
||||
bool is_descending,
|
||||
size_t& temp_storage_bytes) {
|
||||
if (is_descending) {
|
||||
cub::DeviceSegmentedRadixSort::SortPairsDescending(nullptr,
|
||||
temp_storage_bytes,
|
||||
|
|
@ -352,24 +351,24 @@ void GetTempStorageSize(const T *d_keys_in,
|
|||
}
|
||||
|
||||
template void GetTempStorageSize(
|
||||
const float *d_keys_in,
|
||||
const int* d_values_in,
|
||||
int* d_offsets,
|
||||
int num_items,
|
||||
int num_segments,
|
||||
cudaStream_t stream,
|
||||
bool is_descending,
|
||||
size_t& temp_storage_bytes);
|
||||
const float* d_keys_in,
|
||||
const int* d_values_in,
|
||||
int* d_offsets,
|
||||
int num_items,
|
||||
int num_segments,
|
||||
cudaStream_t stream,
|
||||
bool is_descending,
|
||||
size_t& temp_storage_bytes);
|
||||
|
||||
template void GetTempStorageSize(
|
||||
const half *d_keys_in,
|
||||
const int* d_values_in,
|
||||
int* d_offsets,
|
||||
int num_items,
|
||||
int num_segments,
|
||||
cudaStream_t stream,
|
||||
bool is_descending,
|
||||
size_t& temp_storage_bytes);
|
||||
const half* d_keys_in,
|
||||
const int* d_values_in,
|
||||
int* d_offsets,
|
||||
int num_items,
|
||||
int num_segments,
|
||||
cudaStream_t stream,
|
||||
bool is_descending,
|
||||
size_t& temp_storage_bytes);
|
||||
|
||||
// TODO: merge to one kernel
|
||||
__global__ void SetupParamsKernel(int* d_values_in,
|
||||
|
|
@ -401,17 +400,17 @@ void LaunchSetupParamsKernel(int* d_values_in,
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void LaunchSortPairs(void *d_temp_storage,
|
||||
size_t temp_storage_bytes,
|
||||
const T *d_keys_in,
|
||||
T *d_keys_out,
|
||||
const int *d_values_in,
|
||||
int *d_values_out,
|
||||
int num_items,
|
||||
int num_segments,
|
||||
int *d_offsets,
|
||||
cudaStream_t stream,
|
||||
bool is_descending) {
|
||||
void LaunchSortPairs(void* d_temp_storage,
|
||||
size_t temp_storage_bytes,
|
||||
const T* d_keys_in,
|
||||
T* d_keys_out,
|
||||
const int* d_values_in,
|
||||
int* d_values_out,
|
||||
int num_items,
|
||||
int num_segments,
|
||||
int* d_offsets,
|
||||
cudaStream_t stream,
|
||||
bool is_descending) {
|
||||
if (is_descending) {
|
||||
cub::DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage,
|
||||
temp_storage_bytes,
|
||||
|
|
@ -443,31 +442,46 @@ void LaunchSortPairs(void *d_temp_storage,
|
|||
}
|
||||
}
|
||||
|
||||
template void LaunchSortPairs(void *d_temp_storage,
|
||||
template void LaunchSortPairs(void* d_temp_storage,
|
||||
size_t temp_storage_bytes,
|
||||
const float *d_keys_in,
|
||||
float *d_keys_out,
|
||||
const int *d_values_in,
|
||||
int *d_values_out,
|
||||
const float* d_keys_in,
|
||||
float* d_keys_out,
|
||||
const int* d_values_in,
|
||||
int* d_values_out,
|
||||
int num_items,
|
||||
int num_segments,
|
||||
int *d_offsets,
|
||||
int* d_offsets,
|
||||
cudaStream_t stream,
|
||||
bool is_descending);
|
||||
|
||||
template void LaunchSortPairs(void *d_temp_storage,
|
||||
template void LaunchSortPairs(void* d_temp_storage,
|
||||
size_t temp_storage_bytes,
|
||||
const half *d_keys_in,
|
||||
half *d_keys_out,
|
||||
const int *d_values_in,
|
||||
int *d_values_out,
|
||||
const half* d_keys_in,
|
||||
half* d_keys_out,
|
||||
const int* d_values_in,
|
||||
int* d_values_out,
|
||||
int num_items,
|
||||
int num_segments,
|
||||
int *d_offsets,
|
||||
int* d_offsets,
|
||||
cudaStream_t stream,
|
||||
bool is_descending);
|
||||
|
||||
template <typename T>
|
||||
// A stateful callback functor that maintains a running prefix to be applied
|
||||
// during consecutive scan operations.
|
||||
struct BlockPrefixCallbackOp {
|
||||
float running_total; // running prefix
|
||||
|
||||
__device__ BlockPrefixCallbackOp(float running_total) : running_total(running_total) {}
|
||||
// Callback operator to be entered by the first warp of threads in the block.
|
||||
// Thread-0 is responsible for returning a value for seeding the block-wide scan.
|
||||
__device__ float operator()(float block_aggregate) {
|
||||
float old_prefix = running_total;
|
||||
running_total += block_aggregate;
|
||||
return old_prefix;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, int kBlockSize>
|
||||
__global__ void FilterLogitsKernelCustom(float* d_sorted_logits_in,
|
||||
const int* d_sorted_indices,
|
||||
T* d_logits_in_out,
|
||||
|
|
@ -475,35 +489,27 @@ __global__ void FilterLogitsKernelCustom(float* d_sorted_logits_in,
|
|||
float filter_value,
|
||||
int batch_size,
|
||||
int vocab_size) {
|
||||
int index = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int vocab_idx = threadIdx.x;
|
||||
int batch_id = blockIdx.x;
|
||||
int offset = batch_id * vocab_size;
|
||||
|
||||
if (index >= batch_size * vocab_size) {
|
||||
return;
|
||||
}
|
||||
typedef cub::BlockScan<float, kBlockSize> BlockScan;
|
||||
__shared__ typename BlockScan::TempStorage temp_storage;
|
||||
BlockPrefixCallbackOp prefix_op(0);
|
||||
|
||||
int vocab_idx = index % vocab_size;
|
||||
int batch_id = index / vocab_size;
|
||||
int start_index = batch_id * vocab_size;
|
||||
for (int idx = vocab_idx; idx < vocab_size; idx += kBlockSize) {
|
||||
float sum = d_sorted_logits_in[offset + idx];
|
||||
BlockScan(temp_storage).ExclusiveSum(sum, sum, prefix_op);
|
||||
|
||||
int count = vocab_idx;
|
||||
float sum = 0.0f;
|
||||
while (count >= 0) {
|
||||
sum += d_sorted_logits_in[start_index];
|
||||
++start_index;
|
||||
--count;
|
||||
}
|
||||
|
||||
if (sum > top_p_threshold) {
|
||||
// Shift the indices to the right by one according to the custom implementation.
|
||||
int shifted_index = index + 1;
|
||||
if (shifted_index % vocab_size != 0) {
|
||||
int original_index = batch_id * vocab_size + d_sorted_indices[shifted_index];
|
||||
__syncthreads();
|
||||
if (sum >= top_p_threshold) {
|
||||
int original_index = offset + d_sorted_indices[offset + idx];
|
||||
d_logits_in_out[original_index] = (T)filter_value;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
template <typename T, int kBlockSize>
|
||||
__global__ void FilterLogitsKernel(float* d_sorted_logits_in,
|
||||
const int* d_sorted_indices,
|
||||
T* d_logits_in_out,
|
||||
|
|
@ -512,29 +518,25 @@ __global__ void FilterLogitsKernel(float* d_sorted_logits_in,
|
|||
int min_tokens_to_keep,
|
||||
int batch_size,
|
||||
int vocab_size) {
|
||||
int index = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int vocab_idx = threadIdx.x;
|
||||
int batch_id = blockIdx.x;
|
||||
int offset = batch_id * vocab_size;
|
||||
|
||||
if (index >= batch_size * vocab_size) {
|
||||
return;
|
||||
}
|
||||
typedef cub::BlockScan<float, kBlockSize> BlockScan;
|
||||
__shared__ typename BlockScan::TempStorage temp_storage;
|
||||
BlockPrefixCallbackOp prefix_op(0);
|
||||
|
||||
int vocab_idx = index % vocab_size;
|
||||
int batch_id = index / vocab_size;
|
||||
int start_index = batch_id * vocab_size;
|
||||
for (int idx = vocab_idx; idx < vocab_size; idx += kBlockSize) {
|
||||
float sum = d_sorted_logits_in[offset + idx];
|
||||
BlockScan(temp_storage).InclusiveSum(sum, sum, prefix_op);
|
||||
|
||||
int count = vocab_idx;
|
||||
float sum = 0.0f;
|
||||
// TODO: Optimization needed. e.g. use CUB::SCAN() for cumulative probabilities.
|
||||
while (count >= 0) {
|
||||
sum += d_sorted_logits_in[start_index];
|
||||
++start_index;
|
||||
--count;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if (sum <= top_p_threshold) {
|
||||
if (index % vocab_size + min_tokens_to_keep < vocab_size) {
|
||||
int original_index = batch_id * vocab_size + d_sorted_indices[index];
|
||||
d_logits_in_out[original_index] = (T)filter_value;
|
||||
if (sum <= top_p_threshold) {
|
||||
if (idx + min_tokens_to_keep < vocab_size) {
|
||||
int original_index = offset + d_sorted_indices[offset + idx];
|
||||
d_logits_in_out[original_index] = (T)filter_value;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -550,26 +552,25 @@ void LaunchFilterLogitsKernel(float* d_sorted_logits_in,
|
|||
int vocab_size,
|
||||
cudaStream_t stream,
|
||||
bool is_descending) {
|
||||
int total_elements = batch_size * vocab_size;
|
||||
constexpr int blockSize = 256;
|
||||
const int gridSize = (total_elements + blockSize - 1) / blockSize;
|
||||
constexpr int kBlockSize = 256;
|
||||
|
||||
if (is_descending) {
|
||||
FilterLogitsKernelCustom<<<gridSize, blockSize, 0, stream>>>(d_sorted_logits_in,
|
||||
d_sorted_indices,
|
||||
d_logits_in_out,
|
||||
top_p,
|
||||
filter_value,
|
||||
batch_size,
|
||||
vocab_size);
|
||||
FilterLogitsKernelCustom<T, kBlockSize><<<batch_size, kBlockSize, 0, stream>>>(d_sorted_logits_in,
|
||||
d_sorted_indices,
|
||||
d_logits_in_out,
|
||||
top_p,
|
||||
filter_value,
|
||||
batch_size,
|
||||
vocab_size);
|
||||
} else {
|
||||
FilterLogitsKernel<<<gridSize, blockSize, 0, stream>>>(d_sorted_logits_in,
|
||||
d_sorted_indices,
|
||||
d_logits_in_out,
|
||||
1 - top_p,
|
||||
filter_value,
|
||||
min_tokens_to_keep,
|
||||
batch_size,
|
||||
vocab_size);
|
||||
FilterLogitsKernel<T, kBlockSize><<<batch_size, kBlockSize, 0, stream>>>(d_sorted_logits_in,
|
||||
d_sorted_indices,
|
||||
d_logits_in_out,
|
||||
1 - top_p,
|
||||
filter_value,
|
||||
min_tokens_to_keep,
|
||||
batch_size,
|
||||
vocab_size);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -595,7 +596,6 @@ template void LaunchFilterLogitsKernel(float* d_sorted_logits_in,
|
|||
cudaStream_t stream,
|
||||
bool is_descending);
|
||||
|
||||
|
||||
// Ref: https://github.com/pytorch/pytorch/blob/release/1.13/aten/src/ATen/native/cuda/MultinomialKernel.cu
|
||||
template <typename scalar_t, typename accscalar_t>
|
||||
__global__ void sampleMultinomialOnce(int64_t* dest,
|
||||
|
|
@ -603,18 +603,17 @@ __global__ void sampleMultinomialOnce(int64_t* dest,
|
|||
int categories,
|
||||
scalar_t* sampled,
|
||||
scalar_t* dist,
|
||||
int stride_dist, // dist->stride(0)
|
||||
int stride_categories, // dist->stride(1)
|
||||
int stride_dist, // dist->stride(0)
|
||||
int stride_categories, // dist->stride(1)
|
||||
int* d_presence_mask) {
|
||||
extern __shared__ unsigned char my_smem[];
|
||||
extern __shared__ unsigned char my_smem[];
|
||||
__shared__ bool found;
|
||||
__shared__ unsigned foundPos;
|
||||
accscalar_t *smem = reinterpret_cast<accscalar_t *>(my_smem);
|
||||
accscalar_t* smem = reinterpret_cast<accscalar_t*>(my_smem);
|
||||
accscalar_t accZero = static_cast<accscalar_t>(0);
|
||||
scalar_t zero = static_cast<scalar_t>(0);
|
||||
for (int curDist = blockIdx.x;
|
||||
curDist < distributions; curDist += gridDim.x) {
|
||||
|
||||
// Assume sum = 1 in Top P sampling as the input is softmaxed.
|
||||
accscalar_t sum = 1;
|
||||
|
||||
|
|
@ -644,9 +643,7 @@ __global__ void sampleMultinomialOnce(int64_t* dest,
|
|||
for (int chunk = 0; chunk < chunks && !found; ++chunk) {
|
||||
// All threads in bounds load a value
|
||||
int cat = chunk * blockDim.x + threadIdx.x;
|
||||
accscalar_t dist_val = cat < categories ?
|
||||
static_cast<accscalar_t>(dist[curDist * stride_dist + cat * stride_categories]) / sum :
|
||||
accZero;
|
||||
accscalar_t dist_val = cat < categories ? static_cast<accscalar_t>(dist[curDist * stride_dist + cat * stride_categories]) / sum : accZero;
|
||||
smem[threadIdx.x] = dist_val;
|
||||
__syncthreads();
|
||||
// Perform an inclusive prefix sum of the shared memory contents
|
||||
|
|
@ -666,12 +663,12 @@ __global__ void sampleMultinomialOnce(int64_t* dest,
|
|||
static_cast<scalar_t>(smem[threadIdx.x] + prevHighProb);
|
||||
scalar_t prevBucket = static_cast<scalar_t>(
|
||||
threadIdx.x == 0 ? prevHighProb
|
||||
: smem[threadIdx.x - 1] + prevHighProb);
|
||||
: smem[threadIdx.x - 1] + prevHighProb);
|
||||
bool inBucket =
|
||||
(cat < categories) &&
|
||||
(!(sample >= curBucket) &&
|
||||
(sample >= prevBucket) &&
|
||||
(dist_val > zero));
|
||||
(sample >= prevBucket) &&
|
||||
(dist_val > zero));
|
||||
if (inBucket) {
|
||||
// We're done; we have the sample
|
||||
// Torch indices are 1-based
|
||||
|
|
@ -684,7 +681,7 @@ __global__ void sampleMultinomialOnce(int64_t* dest,
|
|||
}
|
||||
if (threadIdx.x == 0) {
|
||||
if (found) {
|
||||
dest[curDist] = foundPos;
|
||||
dest[curDist] = foundPos;
|
||||
} else {
|
||||
// This should address a rare bug where we don't select a valid index. This likely occurs when
|
||||
// due to floating point arithmetic rounding errors, our cumulative sum does not add up to 1, but
|
||||
|
|
@ -721,8 +718,7 @@ void TorchMultinomialKernelLauncher(float* d_input,
|
|||
int batch_size,
|
||||
int vocab_size,
|
||||
int* d_presence_mask,
|
||||
cudaStream_t stream)
|
||||
{
|
||||
cudaStream_t stream) {
|
||||
// Store the props in class variables
|
||||
int device;
|
||||
cudaGetDevice(&device);
|
||||
|
|
@ -731,7 +727,7 @@ void TorchMultinomialKernelLauncher(float* d_input,
|
|||
|
||||
int numSM = props.multiProcessorCount;
|
||||
int maxThreads = props.maxThreadsPerBlock;
|
||||
int warp_size = 32; //at::cuda::warp_size();
|
||||
int warp_size = 32; // at::cuda::warp_size();
|
||||
int requiredWarps = (vocab_size + warp_size - 1) / warp_size;
|
||||
int requiredThreads = std::min(maxThreads, requiredWarps * warp_size);
|
||||
int requiredShared = requiredThreads * sizeof(float);
|
||||
|
|
@ -748,11 +744,8 @@ void TorchMultinomialKernelLauncher(float* d_input,
|
|||
vocab_size,
|
||||
1,
|
||||
d_presence_mask);
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
Loading…
Reference in a new issue