mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-23 22:13:38 +00:00
[BeamSearch]optimize key cache reordering (#17771)
### Description <!-- Describe your changes. --> Replace onnxruntime::cuda::Transpose4DKernelParallelizeMultipleElementsPerThreadInInnermostDim() with custom transpose kernel in ReorderPastState(). The original implementation doesn't benefit from vectorized loading and coalesced accessing(write). and not fully utilize threads in the block. benchmarked with TNLGv4 model(batch=4, seq_len=4K) transpose kernel speed up: ~1.9X (392 μs -> 206 μs) overall reordering speedup: ~1.48X Latency: before:  after:  GPU matrix: before:  after:  ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> --------- Co-authored-by: Your Name <you@example.com>
This commit is contained in:
parent
e1a089c23c
commit
0e988239cc
3 changed files with 84 additions and 22 deletions
|
|
@ -1315,6 +1315,67 @@ template void BufferExpansionKernelLauncher(const int32_t* input,
|
|||
int chunk_size,
|
||||
cudaStream_t stream);
|
||||
|
||||
// Support head_size up to 128
|
||||
constexpr unsigned int kTileSize = 32;
|
||||
constexpr unsigned int kSeqTileSize = 16;
|
||||
|
||||
__global__ void ReorderPastStatesKernel(float4* out_buffer,
|
||||
const float4* in_buffer,
|
||||
int batch_size,
|
||||
int num_heads,
|
||||
int max_length,
|
||||
int chunked_head_size) {
|
||||
__shared__ float4 tile[kSeqTileSize][kTileSize + 1];
|
||||
|
||||
const int b = blockIdx.z;
|
||||
const int n = blockIdx.y;
|
||||
const int s_base = blockIdx.x * kSeqTileSize;
|
||||
const int s = s_base + threadIdx.y;
|
||||
const int base_offset = (b * num_heads + n) * max_length * chunked_head_size;
|
||||
|
||||
if (s < max_length) {
|
||||
const int in_offset = base_offset + s * chunked_head_size + threadIdx.x;
|
||||
tile[threadIdx.y][threadIdx.x] = in_buffer[in_offset];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
const int tidx = threadIdx.x + threadIdx.y * chunked_head_size;
|
||||
const int tidx_x = tidx % kSeqTileSize;
|
||||
const int tidx_y = tidx / kSeqTileSize;
|
||||
|
||||
const int s2 = s_base + tidx_x;
|
||||
|
||||
if (s2 < max_length) {
|
||||
const int out_offset = base_offset + tidx_y * max_length + s2;
|
||||
out_buffer[out_offset] = tile[tidx_x][tidx_y];
|
||||
}
|
||||
}
|
||||
|
||||
void ReorderPastStatesKernelLauncher(void* out_buffer,
|
||||
const void* in_buffer,
|
||||
int batch_size,
|
||||
int num_heads,
|
||||
int max_length,
|
||||
int head_size,
|
||||
int chunk_size,
|
||||
cudaStream_t stream) {
|
||||
//[B, N, max_length, H2(head_size/chunk_size), equv_chunk_size] -> [B, N, H2(head_size/chunk_size), max_length, equv_chunk_size]
|
||||
const int chunked_head_size = head_size / chunk_size;
|
||||
const dim3 block(chunked_head_size, kSeqTileSize);
|
||||
const dim3 grid((max_length + kSeqTileSize - 1) / kSeqTileSize, num_heads, batch_size);
|
||||
if (chunk_size == 4 || chunk_size == 8) {
|
||||
ReorderPastStatesKernel<<<grid, block, 0, stream>>>(reinterpret_cast<float4*>(out_buffer),
|
||||
reinterpret_cast<const float4*>(in_buffer),
|
||||
batch_size,
|
||||
num_heads,
|
||||
max_length,
|
||||
chunked_head_size);
|
||||
} else {
|
||||
ORT_THROW("ReorderPastStatesKernelLauncher only support float or half");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -213,6 +213,14 @@ void BufferExpansionKernelLauncher(const T* input,
|
|||
int chunk_size,
|
||||
cudaStream_t stream);
|
||||
|
||||
void ReorderPastStatesKernelLauncher(void* out_buffer,
|
||||
const void* in_buffer,
|
||||
int batch_size,
|
||||
int num_heads,
|
||||
int max_length,
|
||||
int head_size,
|
||||
int chunk_size,
|
||||
cudaStream_t stream);
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -56,19 +56,23 @@ namespace GenerationCudaDeviceHelper {
|
|||
// It might be better to forcefully require the same type since cast node generates
|
||||
// extra overhead.
|
||||
Status ReorderPastState(
|
||||
const void* cuda_device_prop,
|
||||
const void*,
|
||||
Tensor& past_state,
|
||||
Tensor& past_state_staging,
|
||||
Stream* stream) {
|
||||
ORT_ENFORCE(stream);
|
||||
cudaStream_t cuda_stream = reinterpret_cast<cudaStream_t>(stream->GetHandle());
|
||||
cublasHandle_t cublas_handle = static_cast<CudaStream*>(stream)->cublas_handle_;
|
||||
|
||||
const auto& past_state_shape = past_state.Shape();
|
||||
|
||||
const auto& past_state_dims = past_state_shape.GetDims();
|
||||
const bool packed_past = past_state_dims.size() == 5;
|
||||
|
||||
size_t batch_size = packed_past ? past_state_dims[1] : past_state_dims[0];
|
||||
size_t num_heads = packed_past ? past_state_dims[2] : past_state_dims[1];
|
||||
size_t max_length = packed_past ? past_state_dims[3] : past_state_dims[2];
|
||||
size_t head_size = packed_past ? past_state_dims[4] : past_state_dims[3];
|
||||
|
||||
// Copy the 'K' values into the temp staging buffer
|
||||
size_t past_state_size = packed_past ? past_state.SizeInBytes() / 2 : past_state.SizeInBytes();
|
||||
void* past_state_staging_buffer = past_state_staging.MutableDataRaw();
|
||||
|
|
@ -79,27 +83,16 @@ Status ReorderPastState(
|
|||
// [B, N, head_size / x, max_length, x], where x = 16 / sizeof(T)
|
||||
int64_t chunk_size = static_cast<int64_t>(16 / past_state.DataType()->Size());
|
||||
|
||||
std::vector<size_t> permutation_vector = {0, 1, 3, 2, 4};
|
||||
gsl::span<size_t> permutation(permutation_vector.data(), 5);
|
||||
cuda::ReorderPastStatesKernelLauncher(past_state.MutableDataRaw(),
|
||||
past_state_staging_buffer,
|
||||
static_cast<int>(batch_size),
|
||||
static_cast<int>(num_heads),
|
||||
static_cast<int>(max_length),
|
||||
static_cast<int>(head_size),
|
||||
static_cast<int>(chunk_size),
|
||||
cuda_stream);
|
||||
|
||||
// "Fake" the shapes of the input and output tensors of the Transpose operation to suit our need
|
||||
size_t offset = packed_past ? 1 : 0;
|
||||
TensorShape transpose_input_shape_override = {past_state_shape[offset],
|
||||
past_state_shape[offset + 1],
|
||||
past_state_shape[offset + 2],
|
||||
past_state_shape[offset + 3] / chunk_size,
|
||||
chunk_size};
|
||||
|
||||
TensorShape transpose_output_shape_override = {past_state_shape[offset], past_state_shape[offset + 1],
|
||||
past_state_shape[offset + 3] / chunk_size, past_state_shape[offset + 2],
|
||||
chunk_size};
|
||||
|
||||
// TODO(hasesh): Explore perf tuning for this Transpose operation
|
||||
return onnxruntime::cuda::Transpose::DoTranspose(*static_cast<const cudaDeviceProp*>(cuda_device_prop), cuda_stream,
|
||||
cublas_handle, permutation,
|
||||
past_state_staging, past_state,
|
||||
&transpose_input_shape_override,
|
||||
&transpose_output_shape_override);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status InitCacheIndir(Tensor& cache_indir, Stream* stream) {
|
||||
|
|
|
|||
Loading…
Reference in a new issue